
Commit a420e73

guofeng97 and dddg617 committed
[Model] Implement Grace-POT (#171)
* Grace-POT
* Grace-POT
* [Model] Update Model
* nightly update
* [Model] Update Grace_POT
* nightly udpate
* nightly update
* nightly update

---------

Co-authored-by: dddg617 <996179900@qq.com>
Co-authored-by: dddg617 <75086617+dddg617@users.noreply.github.com>
1 parent 7d17435 commit a420e73

12 files changed: +894 −25 lines changed

CONTRIBUTING.md

+1
```diff
@@ -30,6 +30,7 @@ Contribution is always welcomed. Please feel free to open an issue or email to y
 - Xingyuan Ji (Gamma Lab)
 - Yuxin Guo (Gamma Lab)
 - Zihao Zhao (Gamma Lab)
+- Feng Guo (Gamma Lab)
 - Yuxuan Shan (BUPT)
 - Zeyao Ma (BUPT)
 - Yiming Jia (BUPT)
```

README.md

+1
```diff
@@ -406,6 +406,7 @@ CUDA_VISIBLE_DEVICES="1" TL_BACKEND="paddle" python gcn_trainer.py
 | [MVGRL [ICML 2020]](./examples/mvgrl) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | [InfoGraph [ICLR 2020]](./examples/infograph) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | [MERIT [IJCAI 2021]](./examples/merit) | :heavy_check_mark: | | :heavy_check_mark: | |
+| [GNN-POT [NeurIPS 2023]](./examples/grace_pot) | | :heavy_check_mark: | | |
 
 | Heterogeneous Graph Learning | TensorFlow | PyTorch | Paddle | MindSpore |
 | -------------------------------------------- | ------------------ | ------------------ | ------------------ | --------- |
```
examples/grace_pot/GRACE_POT_trainer.py

+290 (new file)
```python
import argparse
import os.path as osp
import os
# os.environ['TL_BACKEND'] = 'torch'
from time import perf_counter as t
import yaml
from yaml import SafeLoader
import numpy as np
import pickle
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss
from tensorlayerx.dataflow import random_split
from gammagl.layers.conv import GCNConv
from gammagl.datasets import Planetoid, Coauthor, Amazon
import gammagl.transforms as T

from gammagl.models.grace_pot import Grace_POT_Encoder, Grace_POT_Model
from eval_gracepot import log_regression, MulticlassEvaluator

# Cached adjacency bounds, shared across epochs (loaded once per drop rate).
A_upper_1 = None
A_upper_2 = None
A_lower_1 = None
A_lower_2 = None

class train_loss(WithLoss):
    def __init__(self, model, drop_edge_rate_1, drop_edge_rate_2, use_pot=False, pot_batch=-1, kappa=0.5):
        super(train_loss, self).__init__(backbone=model, loss_fn=None)
        self.drop_edge_rate_1 = drop_edge_rate_1
        self.drop_edge_rate_2 = drop_edge_rate_2
        self.use_pot = use_pot
        self.pot_batch = pot_batch
        self.kappa = kappa

    def forward(self, model, x, edge_index, epoch, data=None):
        # Two stochastically augmented views of the same graph.
        edge_index_1 = dropout_adj(edge_index, p=self.drop_edge_rate_1)[0]
        edge_index_2 = dropout_adj(edge_index, p=self.drop_edge_rate_2)[0]
        x_1, x_2 = x, x
        z1 = model(x_1, edge_index_1)
        z2 = model(x_2, edge_index_2)
        node_list = np.arange(z1.shape[0])
        np.random.shuffle(node_list)

        # Mini-batch the loss on large datasets to bound memory.
        # Note: relies on the module-level `args` parsed under __main__.
        batch_size = 4096 if args.dataset in ["PubMed", "Computers", "WikiCS"] else None

        if batch_size is not None:
            node_list_batch = get_batch(node_list, batch_size, epoch)

        # NCE loss
        if batch_size is not None:
            z11 = z1[node_list_batch]
            z22 = z2[node_list_batch]
            nce_loss = model.loss(z11, z22)
        else:
            nce_loss = model.loss(z1, z2)

        # POT loss
        if self.use_pot:
            # Select node_list_tmp, the nodes used to calculate pot_loss.
            if self.pot_batch != -1:
                if batch_size is None:
                    node_list_tmp = get_batch(node_list, self.pot_batch, epoch)
                else:
                    node_list_tmp = get_batch(node_list_batch, self.pot_batch, epoch)
            else:
                # Full POT batch.
                if batch_size is None:
                    node_list_tmp = node_list
                else:
                    node_list_tmp = node_list_batch

            z11 = tlx.gather(z1, tlx.convert_to_tensor(node_list_tmp))
            z22 = tlx.gather(z2, tlx.convert_to_tensor(node_list_tmp))

            global A_upper_1, A_upper_2, A_lower_1, A_lower_2
            if A_upper_1 is None or A_upper_2 is None:
                A_upper_1, A_lower_1 = get_A_bounds(args.dataset, self.drop_edge_rate_1, args.cache)
                A_upper_2, A_lower_2 = get_A_bounds(args.dataset, self.drop_edge_rate_2, args.cache)

            pot_loss_1 = model.pot_loss(z11, z22, data.x, data.edge_index, edge_index_1,
                                        local_changes=self.drop_edge_rate_1,
                                        node_list=node_list_tmp, A_upper=A_upper_1, A_lower=A_lower_1)
            pot_loss_2 = model.pot_loss(z22, z11, data.x, data.edge_index, edge_index_2,
                                        local_changes=self.drop_edge_rate_2,
                                        node_list=node_list_tmp, A_upper=A_upper_2, A_lower=A_lower_2)
            pot_loss = (pot_loss_1 + pot_loss_2) / 2
            # Convex combination of the two objectives, weighted by kappa.
            loss = (1 - self.kappa) * nce_loss + self.kappa * pot_loss
        else:
            loss = nce_loss

        return loss


def test(model, data, dataset, split):
    model.set_eval()
    z = model(data.x, data.edge_index)
    evaluator = MulticlassEvaluator()
    res = log_regression(z, dataset, evaluator, split='preloaded', num_epochs=3000, preload_split=split)
    return res

def get_dataset(path, name):
    assert name in ['Cora', 'CiteSeer', 'PubMed', 'Coauthor-CS', 'Coauthor-Phy', 'Computers', 'Photo']

    if name == 'Coauthor-CS':
        return Coauthor(root=path, name='cs', transform=T.NormalizeFeatures())

    if name == 'Coauthor-Phy':
        return Coauthor(root=path, name='physics', transform=T.NormalizeFeatures())

    if name == 'Computers':
        return Amazon(root=path, name='computers', transform=T.NormalizeFeatures())

    if name == 'Photo':
        return Amazon(root=path, name='photo', transform=T.NormalizeFeatures())

    return Planetoid(path, name, transform=T.NormalizeFeatures())  # public split

def generate_split(num_samples: int, train_ratio: float, val_ratio: float):
    train_len = int(num_samples * train_ratio)
    val_len = int(num_samples * val_ratio)
    test_len = num_samples - train_len - val_len

    train_set, test_set, val_set = random_split(tlx.arange(0, num_samples), (train_len, test_len, val_len))

    idx_train, idx_test, idx_val = train_set.indices, test_set.indices, val_set.indices
    train_mask = tlx.zeros((num_samples,)).to(tlx.bool)
    test_mask = tlx.zeros((num_samples,)).to(tlx.bool)
    val_mask = tlx.zeros((num_samples,)).to(tlx.bool)

    train_mask[idx_train] = True
    test_mask[idx_test] = True
    val_mask[idx_val] = True

    return train_mask, test_mask, val_mask

def get_batch(node_list, batch_size, epoch):
    # Cycle through fixed-size batches, indexed by epoch; the last batch
    # absorbs the remainder.
    num_nodes = len(node_list)
    num_batches = (num_nodes - 1) // batch_size + 1
    i = epoch % num_batches
    if (i + 1) * batch_size >= len(node_list):
        node_list_batch = node_list[i * batch_size:]
    else:
        node_list_batch = node_list[i * batch_size:(i + 1) * batch_size]
    return node_list_batch

def get_A_bounds(dataset, drop_rate, cache):
    upper_lower_file = osp.join(cache, f"{dataset}_{drop_rate}_upper_lower.pkl")
    if osp.exists(upper_lower_file):
        with open(upper_lower_file, 'rb') as file:
            A_upper, A_lower = pickle.load(file)
    else:
        A_upper, A_lower = None, None
    return A_upper, A_lower

def filter_adj(row, col, edge_attr, mask):
    mask = tlx.convert_to_tensor(mask, dtype=tlx.bool)
    return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]

def dropout_adj(
    edge_index,
    edge_attr=None,
    p=0.5,
    force_undirected=False,
    num_nodes=None,
    training=True,
):
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')

    if not training or p == 0.0:
        return edge_index, edge_attr

    row = edge_index[0]
    col = edge_index[1]

    # Keep each edge independently with probability 1 - p.
    mask = np.random.random(tlx.get_tensor_shape(row)) >= p

    if force_undirected:
        mask[row > col] = False

    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if force_undirected:
        # Re-symmetrize: emit each kept edge in both directions.
        edge_index = tlx.stack(
            [tlx.concat([row, col], 0),
             tlx.concat([col, row], 0)], axis=0)
        if edge_attr is not None:
            edge_attr = tlx.concat([edge_attr, edge_attr], 0)
    else:
        edge_index = tlx.stack([row, col])

    return edge_index, edge_attr


def main(args):
    if args.gpu_id >= 0:
        tlx.set_device(device='GPU', id=args.gpu_id)
    else:
        tlx.set_device(device='CPU')

    config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset]
    # Command-line overrides for hyperparameter tuning.
    if args.drop_1 != -1:
        config['drop_edge_rate_1'] = args.drop_1
    if args.drop_2 != -1:
        config['drop_edge_rate_2'] = args.drop_2
    if args.tau != -1:
        config['tau'] = args.tau
    if args.num_epochs != -1:
        config['num_epochs'] = args.num_epochs
    print(args)
    print(config)

    learning_rate = config['learning_rate']
    num_hidden = config['num_hidden']
    num_proj_hidden = config['num_proj_hidden']
    activation = {'relu': tlx.nn.ReLU, 'prelu': tlx.nn.PRelu}[config['activation']]
    base_model = {'GCNConv': GCNConv}[config['base_model']]
    num_layers = config['num_layers']

    drop_edge_rate_1 = config['drop_edge_rate_1']
    drop_edge_rate_2 = config['drop_edge_rate_2']
    tau = config['tau']
    num_epochs = config['num_epochs']
    weight_decay = config['weight_decay']
    use_pot = args.use_pot
    kappa = args.kappa
    pot_batch = args.pot_batch

    dataset = get_dataset(args.path, args.dataset)
    data = dataset[0]

    # Generate split.
    if args.dataset in ["Cora", "CiteSeer", "PubMed"]:
        split = data.train_mask, data.val_mask, data.test_mask
        print("Public Split")
    else:
        split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1)
        print("Random Split")

    encoder = Grace_POT_Encoder(dataset.num_features, num_hidden, activation,
                                base_model=base_model, k=num_layers)
    model = Grace_POT_Model(encoder, num_hidden, num_proj_hidden, tau, dataset=args.dataset, cached=args.cache)
    train_weights = model.trainable_weights
    optimizer = tlx.optimizers.Adam(lr=learning_rate, weight_decay=weight_decay)
    loss_func = train_loss(model, drop_edge_rate_1, drop_edge_rate_2, use_pot, pot_batch, kappa)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    # Timing
    start = t()
    prev = start
    for epoch in range(1, num_epochs + 1):
        model.set_train()
        loss = train_one_step(model, data.x, data.edge_index, epoch, data)
        now = t()
        print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}, '
              f'this epoch {now - prev:.4f}, total {now - start:.4f}')
        if epoch % 100 == 0:
            res = test(model, data, dataset, split)
            print(res)
        prev = now

    print("=== Final ===")
    res = test(model, data, dataset, split)
    print(res)


if __name__ == '__main__':
    # Parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default="./")
    parser.add_argument('--cache', type=str, default="./")
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--config', type=str, default='./config.yaml')
    parser.add_argument('--use_pot', action='store_true')  # whether to add the POT term to the loss
    parser.add_argument('--kappa', type=float, default=0.5)
    parser.add_argument('--pot_batch', type=int, default=-1)
    parser.add_argument('--drop_1', type=float, default=0.4)
    parser.add_argument('--drop_2', type=float, default=0.3)
    parser.add_argument('--tau', type=float, default=0.9)  # temperature of the NCE loss
    parser.add_argument('--num_epochs', type=int, default=-1)
    parser.add_argument('--save_file', type=str, default=".")
    parser.add_argument('--seed', type=int, default=12345)
    args = parser.parse_args()
    main(args)
```
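As a side note, the epoch-indexed batching in `get_batch` above cycles deterministically through the shuffled node list, with the final batch absorbing the remainder. Here is a standalone sketch of that same logic (plain NumPy, for illustration only):

```python
import numpy as np

def get_batch(node_list, batch_size, epoch):
    # Epoch e selects batch (e mod num_batches); the last batch takes the remainder.
    num_batches = (len(node_list) - 1) // batch_size + 1
    i = epoch % num_batches
    if (i + 1) * batch_size >= len(node_list):
        return node_list[i * batch_size:]
    return node_list[i * batch_size:(i + 1) * batch_size]

nodes = np.arange(10)
for epoch in range(4):
    print(epoch, get_batch(nodes, 4, epoch))
# 0 [0 1 2 3]
# 1 [4 5 6 7]
# 2 [8 9]       <- remainder batch
# 3 [0 1 2 3]   <- wraps around
```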

examples/grace_pot/README.md

+49
# GammaGL Implementation of GRACE-POT

This GammaGL example implements the model proposed in the paper [xxxx](https://arxiv.org/).

Author's code:

## Example Implementor

This example was implemented by Siyuan Zhang.

## Datasets

##### Unsupervised Node Classification Datasets:

'Cora', 'PubMed', and 'Photo'
| Dataset | # Nodes | # Edges | # Classes |
| ------- | ------- | ------- | --------- |
| Cora    | 2,708   | 10,556  | 7         |
| PubMed  | 19,717  | 88,651  | 3         |
| Photo   | 7,650   | 238,162 | 8         |
## How to run examples

First, create the directories where datasets and computed bounds will be saved:

```bash
mkdir ~/datasets
mkdir ~/datasets/bounds
```
Then, go into the model's directory. To set the parameters, modify the configuration file in that directory ("config.yaml" for GRACE). The following commands run each variant (using Cora as the example dataset):

```bash
# original GRACE
python GRACE_POT_trainer.py --dataset Cora --gpu_id 0
# GRACE + POT
python GRACE_POT_trainer.py --dataset Cora --gpu_id 0 --use_pot --kappa 0.4
```
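Here `--kappa` sets the weight of the POT term: the trainer combines the two objectives as `loss = (1 - kappa) * nce_loss + kappa * pot_loss`, so `kappa = 0` reduces the objective to the plain GRACE loss, and larger values lean more on the POT bound.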
The results are appended to "res/{dataset_name}_base_temp.csv" and "res/{dataset_name}_pot_temp.csv", respectively. You can also set the "save_file" parameter to specify the file in which to save results. Mini-batching is used to reduce memory usage; the batch size can be modified in the code.
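For reference, a per-dataset entry in config.yaml would look roughly like the sketch below, based on the keys the trainer reads (`learning_rate`, `num_hidden`, and so on). The values shown are illustrative placeholders, not the shipped defaults:

```yaml
# Sketch only: keys match what GRACE_POT_trainer.py reads; values are placeholders.
Cora:
  learning_rate: 0.0005
  num_hidden: 128
  num_proj_hidden: 128
  activation: relu        # 'relu' or 'prelu'
  base_model: GCNConv
  num_layers: 2
  drop_edge_rate_1: 0.4
  drop_edge_rate_2: 0.3
  tau: 0.9
  num_epochs: 200
  weight_decay: 0.00001
```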
## Performance

| Dataset | Metrics | GRACE (Author's Code) | GRACE-POT (Author's Code) | GRACE (GammaGL) | GRACE-POT (GammaGL) |
|:-------:|:-------:|:---------------------:|:-------------------------:|:---------------:|:-------------------:|
| Cora    | Mi-F1   | 78.2                  | 79.2                      | 78.2            | 82.2                |
| Cora    | Ma-F1   | 76.8                  | 77.8                      | 77.1            | 81.3                |
| PubMed  | Mi-F1   | 81.6                  | 82.0                      | 81.6            | 82.0                |
| PubMed  | Ma-F1   | 81.7                  | 82.4                      | 80.5            | 80.1                |
| Photo   | Mi-F1   | 91.2                  | 91.8                      | 89.8            | 90.0                |
| Photo   | Ma-F1   | 89.2                  | 90.0                      | 88.5            | 87.9                |
