Commit d54dca8

[Model] Implement model AM-GCN (#184)
* [Model] Implement model AM-GCN
* [Model] Update AMGCN
1 parent d187b83 commit d54dca8

File tree (5 files changed: +263 −0 lines)

- README.md
- examples/amgcn/amgcn_trainer.py
- examples/amgcn/readme.md
- gammagl/models/__init__.py
- gammagl/models/sfgcn.py

README.md (+1 line)

```diff
@@ -397,6 +397,7 @@ CUDA_VISIBLE_DEVICES="1" TL_BACKEND="paddle" python gcn_trainer.py
 | [CAGCN [NeurIPS 2021]](./examples/cagcn) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
 | [DR-GST [WWW 2022]](./examples/drgst) | :heavy_check_mark: | :heavy_check_mark: | | |
 | [Specformer [ICLR 2023]](./examples/specformer) | | :heavy_check_mark: | :heavy_check_mark: | |
+| [AM-GCN [KDD 2020]](./examples/amgcn) | | :heavy_check_mark: | | |

 | Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
 | ---------------------------------------------- | ------------------ | ------------------ | ------------------ | --------- |
```

examples/amgcn/amgcn_trainer.py (new file, +178 lines)

```python
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@File    :   amgcn_trainer.py
@Time    :   2021/11/02 22:05:55
@Author  :   hanhui
"""

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import scipy.sparse as sp
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid
from gammagl.models import SFGCNModel
from gammagl.utils import add_self_loops, mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss
from sklearn.metrics.pairwise import cosine_similarity as cos


def knn(feat, num_node, k):
    # Build the kNN feature graph from pairwise cosine similarity of node
    # features; each node keeps its k+1 most similar nodes (itself included).
    adj = np.zeros((num_node, num_node), dtype=np.int64)
    dist = cos(tlx.to_device(feat, "CPU"))
    col = np.argpartition(dist, -(k + 1), axis=1)[:, -(k + 1):].flatten()
    adj[np.arange(num_node).repeat(k + 1), col] = 1
    adj = sp.coo_matrix(adj)
    return adj


def nll_loss_func(output, target):
    # Negative log-likelihood over the model's log-softmax outputs.
    return -tlx.reduce_mean(tlx.gather(output, [range(tlx.get_tensor_shape(target)[0]), target]))


def common_loss(emb1, emb2):
    # Consistency constraint: the cosine-similarity matrices of the two
    # common embeddings should agree.
    emb1 = emb1 - tlx.reduce_mean(emb1, axis=0, keepdims=True)
    emb2 = emb2 - tlx.reduce_mean(emb2, axis=0, keepdims=True)
    emb1 = tlx.l2_normalize(emb1, axis=1)
    emb2 = tlx.l2_normalize(emb2, axis=1)
    cov1 = tlx.matmul(emb1, tlx.transpose(emb1))
    cov2 = tlx.matmul(emb2, tlx.transpose(emb2))
    cost = tlx.reduce_mean((cov1 - cov2) ** 2)
    return cost


def loss_dependence(emb1, emb2, dim):
    # HSIC-based disparity constraint between a specific embedding and the
    # corresponding common embedding.
    R = tlx.eye(dim) - (1 / dim) * tlx.ones(shape=(dim, dim))
    K1 = tlx.matmul(emb1, tlx.transpose(emb1))
    K2 = tlx.matmul(emb2, tlx.transpose(emb2))
    RK1 = tlx.matmul(R, K1)
    RK2 = tlx.matmul(R, K2)
    HSIC = tlx.matmul(RK1, RK2)
    HSIC = tlx.reduce_sum(tlx.diag(HSIC))
    return HSIC


class SemiSpvzLoss(WithLoss):
    def __init__(self, net):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=None)

    def forward(self, data, y):
        logits, att, emb1, com1, com2, emb2, emb = self.backbone_network(data['x'], data['edge_index_s'], data['edge_index_f'])
        loss_class = nll_loss_func(tlx.gather(logits, data['train_idx']), tlx.gather(data['y'], data['train_idx']))
        loss_dep = (loss_dependence(emb1, com1, data['num_nodes']) + loss_dependence(emb2, com2, data['num_nodes'])) / 2
        loss_com = common_loss(com1, com2)
        loss = loss_class + data['beta'] * loss_dep + data['theta'] * loss_com

        return loss


def calculate_acc(logits, y, metrics):
    """
    Args:
        logits: node logits
        y: node labels
        metrics: tensorlayerx.metrics

    Returns:
        rst
    """
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def main(args):
    # load dataset
    if str.lower(args.dataset) not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = Planetoid(args.dataset_path, args.dataset)
    graph = dataset[0]
    # feature graph: kNN over node features; topology graph: original edges plus self-loops
    edge_index_f = knn(graph.x, graph.num_nodes, args.k)
    edge_index_f = tlx.convert_to_tensor([edge_index_f.row, edge_index_f.col], dtype=tlx.int64)
    edge_index_s, _ = add_self_loops(graph.edge_index, num_nodes=graph.num_nodes, n_loops=args.self_loops)

    # for mindspore, node indices (rather than boolean masks) must be passed in
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)

    net = SFGCNModel(num_feat=dataset.num_node_features,
                     num_class=dataset.num_classes,
                     num_hidden1=args.hidden1,
                     num_hidden2=args.hidden2,
                     dropout=args.drop_rate)

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index_s": edge_index_s,
        "edge_index_f": edge_index_f,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes": graph.num_nodes,
        "beta": args.beta,
        "theta": args.theta
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, graph.y)
        net.set_eval()
        logits, att, emb1, com1, com2, emb2, emb = net(data['x'], data['edge_index_s'], data['edge_index_f'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(train_loss.item())
              + " val acc: {:.4f}".format(val_acc))

        # save best model on evaluation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    if tlx.BACKEND == 'torch':
        net.to(data['x'].device)
    net.set_eval()
    logits, att, emb1, com1, com2, emb2, emb = net(data['x'], data['edge_index_s'], data['edge_index_f'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument("--hidden1", type=int, default=32, help="dimension of the first hidden layer")
    parser.add_argument("--hidden2", type=int, default=16, help="dimension of the second hidden layer")
    parser.add_argument("--drop_rate", type=float, default=0.5, help="dropout rate")
    parser.add_argument("--beta", type=float, default=0.000005, help="weight of the HSIC dependence loss")
    parser.add_argument("--theta", type=float, default=0.001, help="weight of the common (consistency) loss")
    parser.add_argument("--l2_coef", type=float, default=5e-4, help="l2 loss coefficient")
    parser.add_argument('--dataset', type=str, default='cora', help='dataset')
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    parser.add_argument("--self_loops", type=int, default=1, help="number of graph self-loops")
    parser.add_argument("--k", type=int, default=7, help="number of neighbors in the kNN feature graph")
    args = parser.parse_args()

    main(args)
```
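For reference, the objective assembled in `SemiSpvzLoss` corresponds to the AM-GCN training loss. Writing $Z_1, Z_2$ for the specific embeddings, $C_1, C_2$ for the common embeddings, and $n$ for the number of nodes (my transcription of the code above, not part of the commit):

$$\mathcal{L} = \mathcal{L}_{\text{class}} + \beta \cdot \frac{\mathrm{HSIC}(Z_1, C_1) + \mathrm{HSIC}(Z_2, C_2)}{2} + \theta \cdot \mathcal{L}_{\text{com}}$$

where `loss_dependence` computes the unnormalized HSIC estimate $\mathrm{HSIC}(Z, C) = \mathrm{tr}(R K_Z R K_C)$ with $K_Z = Z Z^{\top}$, $K_C = C C^{\top}$, $R = I - \tfrac{1}{n}\mathbf{1}\mathbf{1}^{\top}$, and `common_loss` is the mean squared difference of the cosine-similarity matrices of the centered common embeddings.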

examples/amgcn/readme.md (new file, +31 lines)

# AM-GCN: Adaptive Multi-channel Graph Convolutional Networks

- Paper link: [https://dl.acm.org/doi/10.1145/3394486.3403177](https://dl.acm.org/doi/10.1145/3394486.3403177)
- Author's code repo: [https://github.com/BUPT-GAMMA/AM-GCN](https://github.com/BUPT-GAMMA/AM-GCN). Note that the original code for the paper is implemented in PyTorch.

# Dataset Statistics

| Dataset  | # Nodes | # Edges | # Classes |
| -------- | ------- | ------- | --------- |
| Cora     | 2,708   | 10,556  | 7         |
| Citeseer | 3,327   | 9,228   | 6         |
| Pubmed   | 19,717  | 88,651  | 3         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).

Results
-------

```bash
# available datasets: "cora", "citeseer", "pubmed"
TL_BACKEND=torch python amgcn_trainer.py --dataset cora --lr 0.0005 --k 6 --hidden1 512 --hidden2 32 --drop_rate 0.5 --beta 1e-10 --theta 0.0001
TL_BACKEND=torch python amgcn_trainer.py --dataset citeseer --lr 0.0005 --k 7 --hidden1 768 --hidden2 256 --drop_rate 0.5 --beta 5e-10 --theta 0.001
TL_BACKEND=torch python amgcn_trainer.py --dataset pubmed --lr 0.0005 --k 6 --hidden1 512 --hidden2 128 --drop_rate 0.5 --beta 5e-10 --theta 0.001
```

| Dataset  | Paper | Our(th)    |
| -------- | ----- | ---------- |
| cora     |       | 79.5(±0.3) |
| citeseer | 73.1  | 71.7(±1.2) |
| pubmed   |       | 64.4(±0.8) |
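The `--k` flag above controls how many nearest neighbors the trainer keeps when building the kNN feature graph. A minimal standalone sketch of that construction, mirroring the indexing in `knn()` from `amgcn_trainer.py` (the toy sizes and random features here are hypothetical):

```python
import numpy as np
import scipy.sparse as sp
from sklearn.metrics.pairwise import cosine_similarity as cos

num_node, k = 6, 2                  # hypothetical toy sizes
feat = np.random.rand(num_node, 4)  # hypothetical random node features
sim = cos(feat)                     # pairwise cosine similarity
# each node keeps its k+1 most similar nodes (itself included)
col = np.argpartition(sim, -(k + 1), axis=1)[:, -(k + 1):].flatten()
adj = np.zeros((num_node, num_node), dtype=np.int64)
adj[np.arange(num_node).repeat(k + 1), col] = 1
coo = sp.coo_matrix(adj)
edge_index = np.stack([coo.row, coo.col])
print(edge_index.shape)             # (2, number of kNN edges)
```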

gammagl/models/__init__.py (+2 lines)

```diff
@@ -47,6 +47,7 @@
 from .cagcn import CAGCNModel
 from .cogsl import CoGSLModel
 from .specformer import Specformer, SpecLayer
+from .sfgcn import SFGCNModel

 __all__ = [
     'GCNModel',
@@ -100,6 +101,7 @@
     'CAGCNModel',
     'CoGSLModel',
     'Specformer',
+    'SFGCNModel'
 ]

 classes = __all__
```
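With the registration above, the new model is importable from the package root, which is exactly how `amgcn_trainer.py` pulls it in. A quick sketch (the Cora-sized arguments here are illustrative; hidden sizes match the trainer's defaults):

```python
from gammagl.models import SFGCNModel

# e.g. Cora: 1433 input features, 7 classes
net = SFGCNModel(num_feat=1433, num_class=7,
                 num_hidden1=32, num_hidden2=16, dropout=0.5)
```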

gammagl/models/sfgcn.py (new file, +51 lines)

```python
import tensorlayerx as tlx
import tensorlayerx.nn as nn
from gammagl.models import GCNModel


class Attention(nn.Module):
    # Channel-level attention: scores each channel embedding and returns
    # their attention-weighted combination.
    def __init__(self, in_size, hidden_size=16):
        super(Attention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_features=in_size, out_features=hidden_size),
            nn.Tanh(),
            nn.Linear(in_features=hidden_size, out_features=1)
        )

    def forward(self, x):
        w = self.project(x)
        beta = tlx.softmax(w, axis=1)
        return tlx.reduce_sum(beta * x, axis=1), beta


class SFGCNModel(nn.Module):
    def __init__(self, num_feat, num_class, num_hidden1, num_hidden2, dropout):
        super(SFGCNModel, self).__init__(name="SFGCN")

        # two specific GCNs (topology graph, feature graph) and one common GCN
        self.SGCN1 = GCNModel(num_feat, num_hidden1, num_hidden2, dropout)
        self.SGCN2 = GCNModel(num_feat, num_hidden1, num_hidden2, dropout)
        self.CGCN = GCNModel(num_feat, num_hidden1, num_hidden2, dropout)

        self.dropout = dropout
        self.a = self._get_weights("a", shape=(num_hidden2, 1),
                                   init=tlx.initializers.xavier_uniform(gain=1.414))
        self.attention = Attention(num_hidden2)
        self.tanh = nn.Tanh()

        self.MLP = nn.Sequential(
            nn.Linear(in_features=num_hidden2, out_features=num_class),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x, edge_index_s, edge_index_f):
        emb1 = self.SGCN1(x, edge_index_s, None, None)  # specific embedding, topology graph
        com1 = self.CGCN(x, edge_index_s, None, None)   # common embedding, topology graph
        emb2 = self.SGCN2(x, edge_index_f, None, None)  # specific embedding, feature graph
        com2 = self.CGCN(x, edge_index_f, None, None)   # common embedding, feature graph
        Xcom = (com1 + com2) / 2

        # attention over the three channels
        emb = tlx.stack([emb1, emb2, Xcom], axis=1)
        emb, att = self.attention(emb)
        output = self.MLP(emb)

        return output, att, emb1, com1, com2, emb2, emb
```
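In symbols (my reading of the code above), the channel attention computes, for each node with stacked channel embeddings $z_T$ (topology), $z_F$ (feature), and $z_C$ (common):

$$w_c = \mathbf{q}^{\top}\tanh(W z_c + b), \qquad \boldsymbol{\beta} = \mathrm{softmax}(w_T, w_F, w_C), \qquad z = \beta_T z_T + \beta_F z_F + \beta_C z_C$$

where $W$ and $b$ come from the first `Linear` layer, $\mathbf{q}$ from the second, and the softmax runs over the three channels (axis 1 of the stacked tensor).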
