Skip to content

Commit d741516

Browse files
author
gyzhou2000
committed
update ngsim dataset
1 parent a95f092 commit d741516

File tree

2 files changed

+32
-32
lines changed

2 files changed

+32
-32
lines changed

gammagl/datasets/ngsim.py

+31-31
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,42 @@
88

99

1010
class NGSIM_US_101(InMemoryDataset):
11-
1211
r"""
13-
The NGSIM US-101 dataset from the "NGSIM: Next Generation Simulation"
14-
<https://ops.fhwa.dot.gov/trafficanalysistools/ngsim.htm>`_ project,
15-
containing detailed vehicle trajectory data from the US-101 highway in
16-
Los Angeles, California.
12+
The NGSIM US-101 dataset from the "NGSIM: Next Generation Simulation"
13+
<https://ops.fhwa.dot.gov/trafficanalysistools/ngsim.htm>`_ project,
14+
containing detailed vehicle trajectory data from the US-101 highway in
15+
Los Angeles, California.
1716
18-
Parameters
19-
----------
20-
root: str
21-
Root directory where the dataset should be saved.
22-
name: str, optional
23-
The name of the dataset (:obj:`"train", "val", "test"`).
24-
transform: callable, optional
25-
A function/transform that takes in an
26-
:obj:`gammagl.data.Graph` object and returns a transformed
27-
version. The data object will be transformed before every access.
28-
(default: :obj:`None`)
29-
pre_transform: callable, optional
30-
A function/transform that takes in
31-
an :obj:`gammagl.data.Graph` object and returns a
32-
transformed version. The data object will be transformed before
33-
being saved to disk. (default: :obj:`None`)
34-
force_reload (bool, optional): Whether to re-process the dataset.
35-
(default: :obj:`False`)
17+
Parameters
18+
----------
19+
root: str
20+
Root directory where the dataset should be saved.
21+
name: str, optional
22+
The name of the dataset (:obj:`"train", "val", "test"`).
23+
transform: callable, optional
24+
A function/transform that takes in an
25+
:obj:`gammagl.data.Graph` object and returns a transformed
26+
version. The data object will be transformed before every access.
27+
(default: :obj:`None`)
28+
pre_transform: callable, optional
29+
A function/transform that takes in
30+
an :obj:`gammagl.data.Graph` object and returns a
31+
transformed version. The data object will be transformed before
32+
being saved to disk. (default: :obj:`None`)
33+
force_reload (bool, optional): Whether to re-process the dataset.
34+
(default: :obj:`False`)
3635
37-
"""
36+
"""
3837

3938
url = 'https://github.com/gjy1221/NGSIM-US-101/raw/main/data'
4039

4140
def __init__(self, root: str = None, name: str = None,
4241
transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None,
4342
force_reload: bool = False):
44-
self.name = name.lower()
43+
self.name = osp.join('ngsim', name.lower())
44+
self.split = name.lower()
4545
super().__init__(root, transform, pre_transform, force_reload=force_reload)
46-
self.data_path = f'{root}/processed/{name}'
46+
self.data_path = osp.join(self.processed_dir, name)
4747
self.data_names = os.listdir('{}'.format(self.data_path))
4848

4949
def __len__(self):
@@ -57,23 +57,23 @@ def __getitem__(self, index):
5757

5858
@property
5959
def raw_dir(self) -> str:
60-
return osp.join(self.root, 'raw', self.name)
60+
return osp.join(self.root, 'ngsim', 'raw', self.split)
6161

6262
@property
6363
def processed_dir(self) -> str:
64-
return osp.join(self.root, 'processed')
64+
return osp.join(self.root, 'ngsim', 'processed')
6565

6666
@property
6767
def raw_file_names(self) -> List[str]:
68-
return [f'/{self.name.lower()}.zip']
68+
return [f'{self.split.lower()}.zip']
6969

7070
@property
7171
def processed_file_names(self) -> str:
7272
return tlx.BACKEND + '_data.pt'
7373

7474
def download(self):
75-
print(self.root)
76-
path = download_url(self.url + self.raw_file_names[0], self.raw_dir)
75+
# print(self.root)
76+
path = download_url(f'{self.url}/{self.raw_file_names[0]}', self.raw_dir)
7777
with zipfile.ZipFile(path, 'r') as zip_ref:
7878
zip_ref.extractall(self.processed_dir)
7979

gammagl/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@
114114
'Graphormer',
115115
'FusedGATModel',
116116
'hid_net',
117-
'HEAT'
117+
'HEAT',
118118
'GNNLFHFModel'
119119
]
120120

0 commit comments

Comments
 (0)