Skip to content

Commit 5c7ccfe

Browse files
committed
add yelp dataset
1 parent 64d3909 commit 5c7ccfe

File tree

6 files changed

+144
-20
lines changed

6 files changed

+144
-20
lines changed

examples/gnnlfhf/gnnlfhf_trainer.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3-
os.environ['TL_BACKEND'] = 'torch'
3+
# os.environ['TL_BACKEND'] = 'torch'
44
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
55
import sys
66
import argparse
@@ -116,8 +116,6 @@ def main(args):
116116
test_acc = calculate_acc(test_logits, test_y, metrics)
117117
print("Test acc: {:.4f}".format(test_acc))
118118

119-
return test_acc
120-
121119

122120
if __name__ == '__main__':
123121
# parameters setting
@@ -145,18 +143,4 @@ def main(args):
145143
else:
146144
tlx.set_device("CPU")
147145

148-
import numpy as np
149-
150-
number = []
151-
for i in range(5):
152-
acc = main(args)
153-
154-
number.append(acc)
155-
156-
print("实验结果:")
157-
print(np.mean(number))
158-
print(np.std(number))
159-
print(number)
160-
161-
# main(args)
162-
146+
main(args)

gammagl/data/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .heterograph import HeteroGraph
33
from .dataset import Dataset
44
from .batch import BatchGraph
5-
from .download import download_url
5+
from .download import download_url, download_google_url
66
from .in_memory_dataset import InMemoryDataset
77
from .extract import extract_zip, extract_tar
88
from .utils import global_config_init
@@ -14,6 +14,7 @@
1414
'HeteroGraph',
1515
'Dataset',
1616
'download_url',
17+
'download_google_url',
1718
'InMemoryDataset',
1819
'extract_zip',
1920
'extract_tar',

gammagl/data/download.py

+7
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,10 @@ def download_url(url: str, folder: str, log: bool = True,
6767
pbar.update(chunk_size)
6868

6969
return path
70+
71+
72+
def download_google_url(id: str, folder: str,
                        filename: str, log: bool = True):
    r"""Download the file behind a Google Drive ID into *folder*.

    Builds the direct-download URL for the given Drive file ID (with the
    ``confirm=t`` flag to skip the virus-scan interstitial) and delegates
    the actual transfer to :func:`download_url`.

    Parameters
    ----------
    id: str
        The Google Drive file ID.
    folder: str
        Destination folder for the downloaded file.
    filename: str
        Name to store the file under.
    log: bool, optional
        Whether to log the download progress. (default: :obj:`True`)
    """
    base = 'https://drive.usercontent.google.com/download'
    drive_url = f'{base}?id={id}&confirm=t'
    return download_url(drive_url, folder, log, filename)

gammagl/datasets/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .wikics import WikiCS
2020
from .blogcatalog import BlogCatalog
2121
from .molecule_net import MoleculeNet
22+
from .yelp import Yelp
2223

2324
__all__ = [
2425
'Amazon',
@@ -40,7 +41,8 @@
4041
'AMiner',
4142
'PolBlogs',
4243
'WikiCS',
43-
'MoleculeNet'
44+
'MoleculeNet',
45+
'Yelp'
4446
]
4547

4648
classes = __all__

gammagl/datasets/yelp.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import json
2+
import os.path as osp
3+
from typing import Callable, List, Optional
4+
5+
import numpy as np
6+
import scipy.sparse as sp
7+
import tensorlayerx as tlx
8+
9+
from gammagl.data import Graph, InMemoryDataset, download_google_url
10+
11+
12+
class Yelp(InMemoryDataset):
    r"""The Yelp dataset from the `"GraphSAINT: Graph Sampling Based
    Inductive Learning Method" <https://arxiv.org/abs/1907.04931>`_ paper,
    containing customer reviewers and their friendship.

    Parameters
    ----------
    root: str, optional
        Root directory where the dataset should be saved.
    transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in
        an :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset.
        (default: :obj:`False`)

    Tip
    ---
    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #tasks
        * - 716,847
          - 13,954,819
          - 300
          - 100
    """

    # Google Drive file IDs of the four raw files (see download()).
    adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA'
    feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM'
    class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU'
    role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_'

    def __init__(
        self,
        root: str = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        """Fetch the four raw files from Google Drive into ``raw_dir``."""
        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')
        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')
        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')
        download_google_url(self.role_id, self.raw_dir, 'role.json')

    def process(self) -> None:
        """Convert the raw files into a single :obj:`Graph` and save it."""
        # Adjacency is stored as raw CSR components inside a .npz archive.
        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
        adj = adj.tocoo()
        row = tlx.convert_to_tensor(adj.row, dtype=tlx.int64)
        col = tlx.convert_to_tensor(adj.col, dtype=tlx.int64)
        edge_index = tlx.stack([row, col], axis=0)

        x = np.load(osp.join(self.raw_dir, 'feats.npy'))
        # Take the node count from the NumPy array: tlx tensors do not
        # provide torch's ``.size(0)`` method on every backend.
        num_nodes = x.shape[0]
        x = tlx.convert_to_tensor(x, dtype=tlx.float32)

        # class_map.json maps node-id strings to (multi-)labels; nodes
        # absent from the map keep the placeholder label -1.
        ys = [-1] * num_nodes
        with open(osp.join(self.raw_dir, 'class_map.json')) as f:
            class_map = json.load(f)
        for key, item in class_map.items():
            ys[int(key)] = item
        y = tlx.convert_to_tensor(ys)

        # role.json holds the train ('tr') / val ('va') / test ('te')
        # node-index lists.
        with open(osp.join(self.raw_dir, 'role.json')) as f:
            role = json.load(f)

        # Build the split masks in NumPy first: in-place item assignment
        # on tensors is not supported by every tensorlayerx backend.
        train_mask = np.zeros(num_nodes, dtype=bool)
        train_mask[role['tr']] = True

        val_mask = np.zeros(num_nodes, dtype=bool)
        val_mask[role['va']] = True

        test_mask = np.zeros(num_nodes, dtype=bool)
        test_mask[role['te']] = True

        data = Graph(x=x, edge_index=edge_index, y=y,
                     train_mask=tlx.convert_to_tensor(train_mask),
                     val_mask=tlx.convert_to_tensor(val_mask),
                     test_mask=tlx.convert_to_tensor(test_mask))

        data = data if self.pre_transform is None else self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])

tests/datasets/test_yelp.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
4+
# os.environ['TL_BACKEND'] = 'paddle'
5+
import tensorlayerx as tlx
6+
from gammagl.datasets import Yelp
7+
8+
def test_yelp():
    # Renamed from ``yelp`` to ``test_yelp``: without the ``test_`` prefix
    # pytest never collects this function, so the dataset was never tested.
    dataset = Yelp()
    graph = dataset[0]
    assert len(dataset) == 1
    assert dataset.num_classes == 100
    assert dataset.num_node_features == 300
    assert graph.edge_index.shape[1] == 13954819
    assert graph.x.shape[0] == 716847

0 commit comments

Comments
 (0)