
Commit c8d05e2

Merge pull request #205 from UNAOUN/add-glnn
Add model glnn
2 parents 2403822 + 9c9f644

File tree

10 files changed (+672 −4 lines changed)


README.md (+3 −1)

```diff
@@ -436,7 +436,9 @@ Now, GammaGL supports about 60 models, we welcome everyone to use or contribute
 | [GGD [NeurIPS 2022]](./examples/ggd) | | :heavy_check_mark: | | :heavy_check_mark: |
 | [LTD [WSDM 2022]](./examples/ltd) | | :heavy_check_mark: | | :heavy_check_mark: |
 | [Graphormer [NeurIPS 2021]](./examples/graphormer) | | :heavy_check_mark: | | :heavy_check_mark: |
-| [HiD-Net [AAAI 2023]](./examples/hid_net) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [HiD-Net [AAAI 2023]](./examples/hid_net) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [FusedGAT [MLSys 2022]](./examples/fusedgat) | | :heavy_check_mark: | | |
+| [GLNN [ICLR 2022]](./examples/glnn) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 
 
 | Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
```

docs/source/api/gammagl.utils.rst (+2 −1)

```diff
@@ -26,4 +26,5 @@ gammagl.utils
 gammagl.utils.negative_sampling
 gammagl.utils.to_scipy_sparse_matrix
 gammagl.utils.read_embeddings
-gammagl.utils.homophily
+gammagl.utils.homophily
+gammagl.utils.get_train_val_test_split
```
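The newly documented `gammagl.utils.get_train_val_test_split` is likely what powers the ratio-based Amazon splits used in `train_student.py` below. A minimal usage sketch, assuming it takes a graph plus train/val ratios and returns boolean node masks; the exact signature is an assumption, not confirmed by this diff:

```python
# Assumed signature: graph plus split ratios in, boolean node masks out.
from gammagl.datasets import Amazon
from gammagl.utils import get_train_val_test_split

dataset = Amazon('./data', 'computers')
graph = dataset[0]

# Assumption: returns one boolean mask per split, each of length num_nodes.
train_mask, val_mask, test_mask = get_train_val_test_split(
    graph, train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)
```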

examples/glnn/readme.md (+41, new file)

# Graph-less Neural Networks (GLNN)

- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)

# Dataset Statistics

| Dataset   | # Nodes | # Edges | # Classes |
| --------- | ------- | ------- | --------- |
| Cora      | 2,708   | 10,556  | 7         |
| Citeseer  | 3,327   | 9,228   | 6         |
| Pubmed    | 19,717  | 88,651  | 3         |
| Computers | 13,752  | 491,722 | 10        |
| Photo     | 7,650   | 238,162 | 8         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid) and [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon).

# Results

- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"

```bash
TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
```
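`train_student.py` loads the teacher's logits from `./{dataset}_{teacher}_logits.npy` (e.g. `cora_SAGE_logits.npy`), which `train_teacher.py` is expected to have produced, so run the teacher first for each dataset/teacher pair.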
| Dataset   | Paper      | Our(tf)    | Our(th)    | Our(pd)    | Our(ms)    |
| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
| Cora      | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
| Citeseer  | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
| Pubmed    | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
| Photo     | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |

- Reported numbers are test accuracy (%), averaged over 5 runs.
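For reference, the distillation objective implemented in `train_student.py` below mixes the hard-label cross-entropy with the KL divergence from the teacher's predicted distribution to the student's:

$$
\mathcal{L} = \lambda \,\mathcal{L}_{\mathrm{CE}}\big(\hat{y}_{\mathrm{MLP}},\, y\big) + (1 - \lambda)\,\mathrm{KL}\big(p_{\mathrm{teacher}} \,\|\, p_{\mathrm{MLP}}\big)
$$

where $\lambda$ is the `--lamb` flag; with the default `--lamb 0` the student trains purely on the teacher's soft targets.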

examples/glnn/train.conf.yaml (+140, new file)

```yaml
global:
  num_layers: 2
  hidden_dim: 128
  learning_rate: 0.01

cora:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.005
    dropout_ratio: 0.6

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

citeseer:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.001
    dropout_ratio: 0.1

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

pubmed:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0
    dropout_ratio: 0.4

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

computers:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.001
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

photo:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01
```
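`get_training_config` in `train_student.py` merges the `global` block with the dataset- and teacher-specific block, and the specific keys win. For example, `--dataset cora --teacher GCN` resolves to:

```yaml
# Effective config for cora + GCN after merging (model keys override global)
num_layers: 2        # from global
hidden_dim: 64       # GCN overrides global's 128
learning_rate: 0.01  # from global
dropout_ratio: 0.8
weight_decay: 0.001
```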

examples/glnn/train_student.py (+168, new file)

```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import yaml
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid, Amazon
from gammagl.models import MLP
from gammagl.utils import mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, teacher_logits):
        student_logits = self.backbone_network(data['x'])
        # distill on all nodes indexed by t_idx; `args` is a module-level global set in __main__
        train_y = tlx.gather(data['y'], data['t_idx'])
        train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
        train_student_logits = tlx.gather(student_logits, data['t_idx'])
        loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
        return loss


def get_training_config(config_path, model_name, dataset):
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    dataset_specific_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    if model_specific_config is not None:
        # model-specific keys override the global defaults
        specific_config = dict(dataset_specific_config, **model_specific_config)
    else:
        specific_config = dataset_specific_config

    specific_config["model_name"] = model_name
    return specific_config


def calculate_acc(logits, y, metrics):
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def kl_divergence(teacher_logits, student_logits):
    # convert logits to probabilities
    teacher_probs = tlx.softmax(teacher_logits)
    student_probs = tlx.softmax(student_logits)
    # compute KL(teacher || student); the epsilon guards against log(0)
    kl_div = tlx.reduce_sum(teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)), axis=-1)
    return tlx.reduce_mean(kl_div)


def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
    # hard-label cross-entropy, weighted by lamb
    loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
    # soft-label distillation term, weighted by (1 - lamb)
    loss_t = kl_divergence(teacher_logits, student_logits)
    return lamb * loss_l + (1 - lamb) * loss_t


def train_student(args):
    # load datasets
    if args.dataset.lower() not in ['cora', 'pubmed', 'citeseer', 'computers', 'photo']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.dataset in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, args.dataset)
    elif args.dataset == 'computers':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200 / 13752, val_ratio=(200 / 13752) * 1.5)
    elif args.dataset == 'photo':
        dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160 / 7650, val_ratio=(160 / 7650) * 1.5)
    graph = dataset[0]

    # load teacher logits from the .npy file produced by train_teacher.py
    teacher_logits = tlx.files.load_npy_to_any(path=r'./', name=f'{args.dataset}_{args.teacher}_logits.npy')
    teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)

    # for mindspore, it should be passed into node indices
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)
    # the student is distilled on all nodes (train + test + val)
    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)

    net = MLP(in_channels=dataset.num_node_features,
              hidden_channels=conf["hidden_dim"],
              out_channels=dataset.num_classes,
              num_layers=conf["num_layers"],
              act=tlx.nn.ReLU(),
              norm=None,
              dropout=float(conf["dropout_ratio"]))

    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, cal_mlp_loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "t_idx": t_idx
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, teacher_logits)
        net.set_eval()
        logits = net(data['x'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save best model on evaluation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')

    # evaluate the best checkpoint on the test split
    net.load_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameters setting
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to model configuration")
    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
    parser.add_argument("--lamb", type=float, default=0, help="weight balancing the hard-label loss against the teacher distillation loss")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()

    conf = {}
    if args.model_config_path is not None:
        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
    conf = dict(args.__dict__, **conf)

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)
```
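`train_student.py` expects the teacher's logits at `./{dataset}_{teacher}_logits.npy`. The matching `train_teacher.py` is among this commit's ten files but not shown in this excerpt; a minimal sketch of the save step it needs, assuming the trained teacher's logits are available as a tensor (the helper name here is hypothetical, and the actual script may differ):

```python
# Hypothetical sketch -- the real train_teacher.py in this commit may differ.
import numpy as np
import tensorlayerx as tlx

def save_teacher_logits(logits, dataset, teacher):
    # logits: tensor of shape [num_nodes, num_classes] from the trained teacher;
    # np.save output is readable by the tlx.files.load_npy_to_any call above.
    np.save(f'./{dataset}_{teacher}_logits.npy', tlx.convert_to_numpy(logits))
```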
