|
9 | 9 | import random
|
10 | 10 | from tensorpack import dataflow
|
11 | 11 |
|
| 12 | + |
12 | 13 | class PCN_pcd(data.Dataset):
|
13 |
| - def __init__(self, prefix="train"): |
| 14 | + def __init__(self, path, prefix="train"): |
14 | 15 | if prefix=="train":
|
15 |
| - self.file_path = '/data0/guodongyan/completion_dataset//ShapeNetCompletion/train' |
| 16 | + self.file_path = os.path.join(path,'train') |
16 | 17 | elif prefix=="val":
|
17 |
| - self.file_path = '/data0/guodongyan/completion_dataset//ShapeNetCompletion/val' |
| 18 | + self.file_path = os.path.join(path,'val') |
18 | 19 | elif prefix=="test":
|
19 |
| - self.file_path = '/data0/guodongyan/completion_dataset//ShapeNetCompletion/test' |
| 20 | + self.file_path = os.path.join(path,'test') |
20 | 21 | else:
|
21 | 22 | raise ValueError("ValueError prefix should be [train/val/test] ")
|
22 | 23 |
|
@@ -154,13 +155,13 @@ def __getitem__(self, index):
|
154 | 155 |
|
155 | 156 |
|
156 | 157 | class C3D_h5(data.Dataset):
|
157 |
| - def __init__(self, prefix="train"): |
| 158 | + def __init__(self, path, prefix="train"): |
158 | 159 | if prefix=="train":
|
159 |
| - self.file_path = '/data0/guodongyan/completion_dataset/c3d/shapenet/train' |
| 160 | + self.file_path = os.path.join(path,'train') |
160 | 161 | elif prefix=="val":
|
161 |
| - self.file_path = '/data0/guodongyan/completion_dataset/c3d/shapenet/val' |
| 162 | + self.file_path = os.path.join(path,'val') |
162 | 163 | elif prefix=="test":
|
163 |
| - self.file_path = '/data0/guodongyan/completion_dataset/c3d/shapenet/test' |
| 164 | + self.file_path = os.path.join(path,'test') |
164 | 165 | else:
|
165 | 166 | raise ValueError("ValueError prefix should be [train/val/test] ")
|
166 | 167 |
|
|
0 commit comments