Commit 9f86d5e

Merge pull request #208 from gyzhou2000/gnnlfhf
[Model & dataset] add model gnnlfhf and dataset yelp
2 parents: c8d05e2 + 5c7ccfe

File tree

9 files changed: +531 −3 lines

examples/gnnlfhf/gnnlfhf_trainer.py (new file, +146 lines)

```python
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
import argparse
sys.path.insert(0, os.path.abspath('./'))
import tensorlayerx as tlx
from gammagl.datasets import Planetoid
from gammagl.utils import mask_to_index
from gammagl.models import GNNLFHFModel
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        logits = self.backbone_network(data['x'])
        train_logits = tlx.gather(logits, data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])
        loss = self._loss_fn(train_logits, train_y)

        l2_reg = sum((tlx.reduce_sum(param ** 2) for param in self.backbone_network.reg_params))
        loss = loss + data["reg_lambda"] / 2 * l2_reg

        return loss


def calculate_acc(logits, y, metrics):
    """
    Args:
        logits: node logits
        y: node labels
        metrics: tensorlayerx.metrics

    Returns:
        rst
    """
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def main(args):
    # load dataset
    if str.lower(args.dataset) not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = Planetoid(args.dataset_path, args.dataset)
    graph = dataset[0]

    # for mindspore, node indices should be passed in rather than boolean masks
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)

    net = GNNLFHFModel(in_channels=graph.num_features,
                       out_channels=dataset.num_classes,
                       hidden_dim=args.hidden_dim,
                       model_type=args.model_type,
                       model_form=args.model_form,
                       edge_index=graph.edge_index,
                       x=graph.x,
                       alpha=args.alpha,
                       mu=args.mu,
                       beta=args.beta,
                       niter=args.niter,
                       drop_rate=args.drop_rate,
                       num_layers=args.num_layers,
                       name="GNNLFHF")

    optimizer = tlx.optimizers.Adam(lr=args.lr)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index": graph.edge_index,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes": graph.num_nodes,
        "reg_lambda": args.reg_lambda
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, data['y'])
        net.set_eval()
        logits = net(data['x'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save best model on the validation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument("--hidden_dim", type=int, default=64, help="dimension of hidden layers")
    parser.add_argument("--drop_rate", type=float, default=0.8, help="drop_rate")
    parser.add_argument("--num_layers", type=int, default=2, help="number of layers")
    parser.add_argument("--reg_lambda", type=float, default=5e-3, help="reg_lambda")
    parser.add_argument('--dataset', type=str, default='cora', help='dataset')
    parser.add_argument("--model_type", type=str, default=r'GNN-LF', help="GNN-LF or GNN-HF")
    parser.add_argument("--model_form", type=str, default=r'closed', help="closed or iterative")
    parser.add_argument("--dataset_path", type=str, default=r'./', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--alpha", type=float, default=0.3, help="the value of alpha")
    parser.add_argument("--mu", type=float, default=0.1, help="the value of mu")
    parser.add_argument("--beta", type=float, default=0.1, help="the value of beta")
    parser.add_argument("--niter", type=int, default=20, help="the value of niter")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
```
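Note that `SemiSpvzLoss` adds an explicit L2 penalty on `net.reg_params` rather than using the optimizer's weight decay. A minimal standalone sketch of that penalty with the same `tlx` calls as the trainer (the parameter values are made up for illustration):

```python
import tensorlayerx as tlx

# Stand-in for net.reg_params: a single 2x2 weight matrix with dummy values.
params = [tlx.convert_to_tensor([[0.5, -1.0], [2.0, 0.0]])]
reg_lambda = 5e-3  # the trainer's default --reg_lambda

l2_reg = sum(tlx.reduce_sum(p ** 2) for p in params)  # 0.25 + 1.0 + 4.0 = 5.25
penalty = reg_lambda / 2 * l2_reg                     # 5e-3 / 2 * 5.25 = 0.0131...
print(float(penalty))
```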

examples/gnnlfhf/readme.md (new file, +50 lines)

# Interpreting and Unifying Graph Neural Networks with An Optimization Framework (GNNLFHF)

- Paper link: [https://arxiv.org/pdf/2101.11859](https://arxiv.org/pdf/2101.11859)
- Author's code repo: [https://github.com/zhumeiqiBUPT/GNN-LF-HF/tree/main](https://github.com/zhumeiqiBUPT/GNN-LF-HF/tree/main). Note that the original code is implemented with PyTorch for the paper.

# Dataset Statistics

| Dataset  | # Nodes | # Edges | # Classes |
|----------|---------|---------|-----------|
| Cora     | 2,708   | 10,556  | 7         |
| Citeseer | 3,327   | 9,228   | 6         |
| Pubmed   | 19,717  | 88,651  | 3         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

Results
-------

```bash
# available datasets: "cora", "citeseer", "pubmed"
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset cora --model_type GNN-LF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset cora --model_type GNN-LF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset cora --model_type GNN-HF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset cora --model_type GNN-HF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3

TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset citeseer --model_type GNN-LF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset citeseer --model_type GNN-LF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset citeseer --model_type GNN-HF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset citeseer --model_type GNN-HF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3

TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset pubmed --model_type GNN-LF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset pubmed --model_type GNN-LF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset pubmed --model_type GNN-HF --model_form closed --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
TL_BACKEND="torch" python gnnlfhf_trainer.py --dataset pubmed --model_type GNN-HF --model_form iterative --alpha 0.3 --mu 0.1 --beta 0.1 --niter 20 --lr 0.01 --hidden_dim 64 --drop_rate 0.8 --reg_lambda 5e-3
```

| Dataset  | Model         | Paper      | Ours (th)  |
| -------- | ------------- | ---------- | ---------- |
| cora     | GNN-LF-closed | 83.70±0.14 | 82.05±0.98 |
| cora     | GNN-LF-iter   | 83.53±0.24 | 81.81±0.65 |
| cora     | GNN-HF-closed | 83.96±0.22 | 82.48±1.18 |
| cora     | GNN-HF-iter   | 83.79±0.29 | 81.28±0.69 |
| citeseer | GNN-LF-closed | 71.98±0.33 | 70.51±1.08 |
| citeseer | GNN-LF-iter   | 71.92±0.24 | 71.11±1.38 |
| citeseer | GNN-HF-closed | 72.30±0.28 | 70.24±1.01 |
| citeseer | GNN-HF-iter   | 72.03±0.36 | 70.14±1.52 |
| pubmed   | GNN-LF-closed | 80.34±0.18 | 75.14±0.89 |
| pubmed   | GNN-LF-iter   | 80.33±0.20 | 76.68±0.58 |
| pubmed   | GNN-HF-closed | 80.41±0.25 | 76.36±0.71 |
| pubmed   | GNN-HF-iter   | 80.54±0.25 | 78.02±0.28 |
gammagl/data/__init__.py (+2 −1)

```diff
@@ -2,7 +2,7 @@
 from .heterograph import HeteroGraph
 from .dataset import Dataset
 from .batch import BatchGraph
-from .download import download_url
+from .download import download_url, download_google_url
 from .in_memory_dataset import InMemoryDataset
 from .extract import extract_zip, extract_tar
 from .utils import global_config_init
@@ -14,6 +14,7 @@
     'HeteroGraph',
     'Dataset',
     'download_url',
+    'download_google_url',
     'InMemoryDataset',
     'extract_zip',
     'extract_tar',
```
gammagl/data/download.py (+7)

```diff
@@ -67,3 +67,10 @@ def download_url(url: str, folder: str, log: bool = True,
             pbar.update(chunk_size)
 
     return path
+
+
+def download_google_url(id: str, folder: str,
+                        filename: str, log: bool = True):
+    r"""Downloads the content of a Google Drive ID to a specific folder."""
+    url = f'https://drive.usercontent.google.com/download?id={id}&confirm=t'
+    return download_url(url, folder, log, filename)
```
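A minimal usage sketch of the new helper, reusing the Google Drive ID that gammagl/datasets/yelp.py passes for `adj_full.npz` (the target folder here is illustrative):

```python
from gammagl.data import download_google_url

# Fetches https://drive.usercontent.google.com/download?id=...&confirm=t
# into ./raw/adj_full.npz and returns the local path.
path = download_google_url('1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA',
                           folder='./raw', filename='adj_full.npz')
print(path)
```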

gammagl/datasets/__init__.py (+3 −1)

```diff
@@ -19,6 +19,7 @@
 from .wikics import WikiCS
 from .blogcatalog import BlogCatalog
 from .molecule_net import MoleculeNet
+from .yelp import Yelp
 
 __all__ = [
     'Amazon',
@@ -40,7 +41,8 @@
     'AMiner',
     'PolBlogs',
     'WikiCS',
-    'MoleculeNet'
+    'MoleculeNet',
+    'Yelp'
 ]
 
 classes = __all__
```

gammagl/datasets/yelp.py (new file, +115 lines)

```python
import json
import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import scipy.sparse as sp
import tensorlayerx as tlx

from gammagl.data import Graph, InMemoryDataset, download_google_url


class Yelp(InMemoryDataset):
    r"""The Yelp dataset from the `"GraphSAINT: Graph Sampling Based
    Inductive Learning Method" <https://arxiv.org/abs/1907.04931>`_ paper,
    containing customer reviewers and their friendship.

    Parameters
    ----------
    root: str, optional
        Root directory where the dataset should be saved.
    transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset.
        (default: :obj:`False`)

    Tip
    ---
    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #tasks
        * - 716,847
          - 13,954,819
          - 300
          - 100
    """

    adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA'
    feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM'
    class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU'
    role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_'

    def __init__(
        self,
        root: str = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')
        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')
        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')
        download_google_url(self.role_id, self.raw_dir, 'role.json')

    def process(self) -> None:
        # rebuild the sparse adjacency matrix and convert it to edge_index
        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
        adj = adj.tocoo()
        row = tlx.convert_to_tensor(adj.row, dtype=tlx.int64)
        col = tlx.convert_to_tensor(adj.col, dtype=tlx.int64)
        edge_index = tlx.stack([row, col], axis=0)

        x = np.load(osp.join(self.raw_dir, 'feats.npy'))
        x = tlx.convert_to_tensor(x, dtype=tlx.float32)

        ys = [-1] * x.shape[0]
        with open(osp.join(self.raw_dir, 'class_map.json')) as f:
            class_map = json.load(f)
            for key, item in class_map.items():
                ys[int(key)] = item
        y = tlx.convert_to_tensor(ys)

        with open(osp.join(self.raw_dir, 'role.json')) as f:
            role = json.load(f)

        train_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool)
        train_mask[tlx.convert_to_tensor(role['tr'])] = True

        val_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool)
        val_mask[tlx.convert_to_tensor(role['va'])] = True

        test_mask = tlx.zeros((x.shape[0],), dtype=tlx.bool)
        test_mask[tlx.convert_to_tensor(role['te'])] = True

        data = Graph(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                     val_mask=val_mask, test_mask=test_mask)

        data = data if self.pre_transform is None else self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])
```
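A minimal usage sketch (the `root` path is illustrative): instantiating `Yelp` triggers `download()` and `process()` on first use, then serves the cached graph from `processed_paths[0]`.

```python
import tensorlayerx as tlx
from gammagl.datasets import Yelp

dataset = Yelp(root='./data/yelp')  # hypothetical root directory
graph = dataset[0]
print(graph.x.shape)   # (716847, 300) node features
print(graph.y.shape)   # (716847, 100) multi-label targets
# Split sizes come from role.json; cast the boolean mask before summing.
n_train = tlx.reduce_sum(tlx.cast(graph.train_mask, tlx.int64))
print(int(n_train))
```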

gammagl/models/__init__.py (+4 −1)

```diff
@@ -54,6 +54,8 @@
 from .graphormer import Graphormer
 from .fusedgat import FusedGATModel
 from .hid_net import Hid_net
+from .gnnlfhf import GNNLFHFModel
+
 __all__ = [
     'GCNModel',
     'GATModel',
@@ -110,7 +112,8 @@
     'SFGCNModel',
     'Graphormer',
     'FusedGATModel',
-    'hid_net'
+    'hid_net',
+    'GNNLFHFModel'
 ]
 
 classes = __all__
```
