|
| 1 | +import json |
| 2 | +import os.path as osp |
| 3 | +from typing import Callable, List, Optional |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import scipy.sparse as sp |
| 7 | +import tensorlayerx as tlx |
| 8 | + |
| 9 | +from gammagl.data import Graph, InMemoryDataset, download_google_url |
| 10 | + |
| 11 | + |
| 12 | +class Yelp(InMemoryDataset): |
| 13 | + r"""The Yelp dataset from the `"GraphSAINT: Graph Sampling Based |
| 14 | + Inductive Learning Method" <https://arxiv.org/abs/1907.04931>`_ paper, |
| 15 | + containing customer reviewers and their friendship. |
| 16 | + |
| 17 | + Parameters |
| 18 | + ---------- |
| 19 | + root: str, optional |
| 20 | + Root directory where the dataset should be saved. |
| 21 | + transform: callable, optional |
| 22 | + A function/transform that takes in an |
| 23 | + :obj:`gammagl.data.Graph` object and returns a transformed |
| 24 | + version. The data object will be transformed before every access. |
| 25 | + (default: :obj:`None`) |
| 26 | + pre_transform: callable, optional |
| 27 | + A function/transform that takes in |
| 28 | + an :obj:`gammagl.data.Graph` object and returns a |
| 29 | + transformed version. The data object will be transformed before |
| 30 | + being saved to disk. (default: :obj:`None`) |
| 31 | + force_reload (bool, optional): Whether to re-process the dataset. |
| 32 | + (default: :obj:`False`) |
| 33 | +
|
| 34 | + Tip |
| 35 | + --- |
| 36 | + .. list-table:: |
| 37 | + :widths: 10 10 10 10 10 |
| 38 | + :header-rows: 1 |
| 39 | + |
| 40 | + * - #nodes |
| 41 | + - #edges |
| 42 | + - #features |
| 43 | + - #tasks |
| 44 | + * - 716,847 |
| 45 | + - 13,954,819 |
| 46 | + - 300 |
| 47 | + - 100 |
| 48 | + """ |
| 49 | + |
| 50 | + adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA' |
| 51 | + feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM' |
| 52 | + class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU' |
| 53 | + role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_' |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + root: str = None, |
| 58 | + transform: Optional[Callable] = None, |
| 59 | + pre_transform: Optional[Callable] = None, |
| 60 | + force_reload: bool = False, |
| 61 | + ) -> None: |
| 62 | + super().__init__(root, transform, pre_transform, |
| 63 | + force_reload=force_reload) |
| 64 | + self.data, self.slices = self.load_data(self.processed_paths[0]) |
| 65 | + |
| 66 | + @property |
| 67 | + def raw_file_names(self) -> List[str]: |
| 68 | + return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json'] |
| 69 | + |
| 70 | + @property |
| 71 | + def processed_file_names(self) -> str: |
| 72 | + return 'data.pt' |
| 73 | + |
| 74 | + def download(self) -> None: |
| 75 | + download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz') |
| 76 | + download_google_url(self.feats_id, self.raw_dir, 'feats.npy') |
| 77 | + download_google_url(self.class_map_id, self.raw_dir, 'class_map.json') |
| 78 | + download_google_url(self.role_id, self.raw_dir, 'role.json') |
| 79 | + |
| 80 | + def process(self) -> None: |
| 81 | + f = np.load(osp.join(self.raw_dir, 'adj_full.npz')) |
| 82 | + adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape']) |
| 83 | + adj = adj.tocoo() |
| 84 | + row = tlx.convert_to_tensor(adj.row, dtype=tlx.int64) |
| 85 | + col = tlx.convert_to_tensor(adj.col, dtype=tlx.int64) |
| 86 | + edge_index = tlx.stack([row, col], axis=0) |
| 87 | + |
| 88 | + x = np.load(osp.join(self.raw_dir, 'feats.npy')) |
| 89 | + x = tlx.convert_to_tensor(x, dtype=tlx.float32) |
| 90 | + |
| 91 | + ys = [-1] * x.size(0) |
| 92 | + with open(osp.join(self.raw_dir, 'class_map.json')) as f: |
| 93 | + class_map = json.load(f) |
| 94 | + for key, item in class_map.items(): |
| 95 | + ys[int(key)] = item |
| 96 | + y = tlx.convert_to_tensor(ys) |
| 97 | + |
| 98 | + with open(osp.join(self.raw_dir, 'role.json')) as f: |
| 99 | + role = json.load(f) |
| 100 | + |
| 101 | + train_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool) |
| 102 | + train_mask[tlx.convert_to_tensor(role['tr'])] = True |
| 103 | + |
| 104 | + val_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool) |
| 105 | + val_mask[tlx.convert_to_tensor(role['va'])] = True |
| 106 | + |
| 107 | + test_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool) |
| 108 | + test_mask[tlx.convert_to_tensor(role['te'])] = True |
| 109 | + |
| 110 | + data = Graph(x=x, edge_index=edge_index, y=y, train_mask=train_mask, |
| 111 | + val_mask=val_mask, test_mask=test_mask) |
| 112 | + |
| 113 | + data = data if self.pre_transform is None else self.pre_transform(data) |
| 114 | + |
| 115 | + self.save_data(self.collate([data]), self.processed_paths[0]) |
0 commit comments