Skip to content

Commit 66bd9b2

Browse files
xy-Jigyzhou2000
andauthored
[Model] add dfad-gnn (#212)
* [Model] add dfad-gnn * update trainer * update train_student.py * update train_student.py * update train_student.py * add readme * update readme.md * update --------- Co-authored-by: gyzhou2000 <gyzhou2000@gmail.com>
1 parent 9e04783 commit 66bd9b2

File tree

5 files changed

+455
-1
lines changed

5 files changed

+455
-1
lines changed

examples/dfad_gnn/readme.md

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
# Data-Free Adversarial Knowledge Distillation for Graph Neural Networks

- Paper link: [https://arxiv.org/pdf/2205.03811](https://arxiv.org/pdf/2205.03811)
- Author's code repo: [https://anonymous.4open.science/r/DF-GNNs-EC75](https://anonymous.4open.science/r/DF-GNNs-EC75)

# Dataset Statistics

| Dataset | # Graphs | # Nodes | # Edges | # Features | # Classes |
| ------- | -------- | ------- | ------- | ---------- | --------- |
| MUTAG   | 188      | ~17.9   | ~39.6   | 7          | 2         |

Refer to [TUDataset](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.TUDataset.html).

# Results

```bash
TL_BACKEND="torch" python train_teacher.py --dataset MUTAG
TL_BACKEND="torch" python train_student.py --dataset MUTAG --student gcn
TL_BACKEND="torch" python train_student.py --dataset MUTAG --student gin
```

| Dataset | Student Model | Paper | Our(th) |
| ------- | ------------- | ----- | ------- |
| MUTAG   | gcn           | 76.2% | 88.2%   |
| MUTAG   | gin           | 90.8% | 88.2%   |
| MUTAG   | gat           | 79.5% | 88.2%   |
| MUTAG   | graphsage     | 79.1% | 88.2%   |

examples/dfad_gnn/train_student.py

+230
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
import os
2+
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3+
# os.environ['TL_BACKEND'] = 'torch'
4+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
5+
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
6+
7+
import argparse
8+
import tensorlayerx as tlx
9+
from gammagl.models import GINModel, DFADModel, DFADGenerator
10+
from gammagl.loader import DataLoader
11+
from gammagl.data import Graph
12+
from gammagl.datasets import TUDataset
13+
from tensorlayerx.model import TrainOneStep, WithLoss
14+
import numpy
15+
import scipy.sparse as sp
16+
from train_teacher import SemiSpvzLoss
17+
18+
class GeneratorLoss(WithLoss):
    """Adversarial objective for the graph generator.

    Negates the student/teacher distillation loss, so minimizing this
    objective pushes the generator toward graphs on which the student
    disagrees most with the teacher.
    """

    def __init__(self, net, loss_fn):
        super(GeneratorLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, student_logits, teacher_logits):
        # Descent on the negated loss == ascent on the distillation gap.
        return -self._loss_fn(student_logits, teacher_logits)
25+
26+
class StudentLoss(WithLoss):
    """Distillation objective for the student.

    Runs the student backbone on a batched generated graph and compares
    its logits against the teacher's logits via ``loss_fn``.
    """

    def __init__(self, net, loss_fn, batch_size):
        super(StudentLoss, self).__init__(backbone=net, loss_fn=loss_fn)
        self.loss_fn = loss_fn        # kept for parity with the original attribute set
        self.batch_size = batch_size  # NOTE(review): stored but never read here

    def forward(self, data, label):
        node_features = data['x']
        logits = self.backbone_network(node_features, data['edge_index'],
                                       node_features.shape[0], data['batch'])
        return self._loss_fn(logits, label)
36+
37+
def dense_to_sparse(adj_mat):
    """Convert a dense adjacency matrix into a (2, num_edges) edge-index tensor."""
    coo = sp.coo_matrix(tlx.convert_to_numpy(adj_mat))
    rows = tlx.convert_to_tensor(coo.row, dtype=tlx.int64)
    cols = tlx.convert_to_tensor(coo.col, dtype=tlx.int64)
    return tlx.stack((rows, cols))
44+
45+
def Data_construct(batch_size, edges_logits, nodes_logits):
    """Turn raw generator outputs into a DataLoader of synthetic graphs.

    Parameters
    ----------
    batch_size : int
        Batch size of the returned DataLoader.
    edges_logits : tensor
        Raw edge scores per graph; an edge is kept where sigmoid(score) > 0.3.
    nodes_logits : tensor
        Raw node-feature scores; each node's features become one-hot on the
        argmax entry. Assumed shape (num_graphs, num_nodes, num_features) —
        the axis=2 argmax below depends on it.

    Returns
    -------
    DataLoader
        Yields batched synthetic ``Graph`` objects, in order (shuffle=False).
    """
    # One-hot encode the winning feature for every node.
    # NOTE(review): item assignment on a tlx tensor is backend-dependent;
    # works under TL_BACKEND="torch" as the readme documents — confirm
    # before enabling other backends.
    nodes_logits = tlx.softmax(nodes_logits, -1)
    max_indices = tlx.argmax(nodes_logits, axis=2, keepdim=True)
    nodes_logits = tlx.zeros_like(nodes_logits)
    for i in range(len(nodes_logits)):
        for j in range(len(nodes_logits[0])):
            index = max_indices[i, j]
            nodes_logits[i, j, index] = 1
    # Threshold edge probabilities into a 0/1 adjacency matrix.
    # Fix: the original called torch-only `.long()`; tlx.cast keeps this
    # portable across tensorlayerx backends with identical results.
    edges_logits = tlx.sigmoid(tlx.cast(edges_logits, dtype=tlx.float32))
    edges_logits = tlx.cast(edges_logits > 0.3, dtype=tlx.int64)
    # Wrap each (features, adjacency) pair into a Graph.
    data_list = []
    for i in range(len(nodes_logits)):
        edge = dense_to_sparse(edges_logits[i])
        data_list.append(Graph(x=nodes_logits[i], edge_index=edge))
    return DataLoader(data_list, batch_size=batch_size, shuffle=False)
65+
66+
67+
def generate_graph(args, generator):
    """Sample a noise batch, run the generator, and wrap its output graphs in a DataLoader."""
    noise = tlx.random_normal((args.batch_size, args.nz))
    adj_logits, node_logits = generator(noise)
    return Data_construct(args.batch_size, adj_logits, node_logits)
72+
73+
74+
def train_student(args):
    """Data-free adversarial distillation (DFAD) for graph classification.

    Alternately (1) trains the student to match the frozen teacher's logits
    on generator-produced graphs and (2) trains the generator to maximize
    the student/teacher disagreement. Validation/checkpointing uses the real
    validation split; final accuracies are reported on the real test split.

    Expects the teacher checkpoint ``./teacher_<dataset>.npz`` produced by
    ``train_teacher.py``.
    """
    dataset = TUDataset(args.dataset_path,args.dataset)

    # 80/10/10 train/val/test split by slicing — same scheme as train_teacher.py.
    dataset_unit = len(dataset) // 10
    train_set = dataset[2 * dataset_unit:]
    val_set = dataset[:dataset_unit]
    test_set = dataset[dataset_unit: 2 * dataset_unit]

    train_loader = DataLoader(train_set, batch_size=args.batch_size)
    val_loader = DataLoader(val_set, batch_size=args.batch_size)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)

    # Teacher architecture must match the one trained by train_teacher.py,
    # or load_weights below will fail.
    teacher = GINModel(
        in_channels=max(dataset.num_features, 1),
        hidden_channels=args.t_hidden,
        out_channels=dataset.num_classes,
        num_layers=args.t_layers,
        name="GIN"
    )

    student = DFADModel(
        model_name=args.student,
        feature_dim=max(dataset.num_features, 1),
        hidden_dim=args.hidden_units,
        num_classes=dataset.num_classes,
        num_layers=args.num_layers,
        drop_rate=args.student_dropout
    )

    # initialize generator
    x_example = dataset[0].x  # NOTE(review): unused — presumably debugging leftover
    generator = DFADGenerator([64, 128, 256], args.nz, args.vertexes, dataset.num_features, args.generator_dropout)

    optimizer_s = tlx.optimizers.Adam(lr=args.student_lr, weight_decay=args.student_l2_coef)
    optimizer_g = tlx.optimizers.Adam(lr=args.generator_lr, weight_decay=args.generator_l2_coef)

    # One pass of teacher training on real data before loading the checkpoint.
    # NOTE(review): its result is immediately overwritten by load_weights
    # below — presumably this exists only to instantiate the teacher's weight
    # tensors so the checkpoint can be loaded; confirm whether a single
    # forward pass would suffice.
    optimizer = tlx.optimizers.Adam(lr=0.001, weight_decay=5e-4)
    train_weights = teacher.trainable_weights
    loss_func = SemiSpvzLoss(teacher, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    teacher.set_train()
    for data in train_loader:
        train_loss = train_one_step(data, data.y)
    teacher.set_eval()
    teacher.load_weights("./teacher_" + args.dataset + ".npz", format='npz_dict')

    student_trainable_weight = student.trainable_weights
    generator_trainable_weights = generator.trainable_weights
    # L1 distance between logits serves as the distillation loss (DFAD-style).
    s_loss_fun = tlx.losses.L1Loss
    g_loss_fun = tlx.losses.L1Loss

    s_with_loss = StudentLoss(student, s_loss_fun, args.batch_size)

    g_with_loss = GeneratorLoss(generator, g_loss_fun)
    s_train_one_step = TrainOneStep(s_with_loss, optimizer_s, student_trainable_weight)
    g_train_one_step = TrainOneStep(g_with_loss, optimizer_g, generator_trainable_weights)

    epochs = args.n_epoch
    student_epoch = args.student_epoch
    g_epoch = args.g_epoch

    best_acc = 0
    for epoch in range(epochs):
        student.set_train()
        for _ in range(student_epoch):
            # train student model: minimize L1 gap to the frozen teacher's
            # logits on freshly generated graphs
            loader = generate_graph(args, generator)
            teacher.set_eval()
            s_loss = 0
            for data in loader:
                t_logits = teacher(data.x, data.edge_index, data.batch)
                s_loss += s_train_one_step(data, t_logits)
            # print('s_loss:', s_loss)
        student.set_eval()

        # train generator: GeneratorLoss negates the L1 loss, so this step
        # maximizes the student/teacher disagreement
        generator.set_train()
        for _ in range(g_epoch):
            z = tlx.random_normal((args.batch_size, args.nz))
            adj, nodes_logits = generator(z)
            loader = Data_construct(z.shape[0], adj, nodes_logits)
            g_loss = 0
            for data in loader:
                x, edge_index, num_nodes, batch = data.x, data.edge_index, data.num_nodes, data.batch
                student_logits = student(x, edge_index, num_nodes, batch)
                teacher_logits = teacher(x, edge_index, batch)
                # Compare softmax probabilities rather than raw logits here.
                student_logits = tlx.nn.Softmax()(student_logits)
                teacher_logits = tlx.nn.Softmax()(teacher_logits)
                g_loss += g_train_one_step(student_logits, teacher_logits)
            # print('g_loss:', g_loss)
        generator.set_eval()

        # Student accuracy on the real validation split.
        total_correct = 0
        for data in val_loader:
            test_logits = student(data.x, data.edge_index, data.x.shape[0], data.batch)
            teacher_logits = teacher(data.x, data.edge_index, data.batch)  # NOTE(review): unused
            pred = tlx.argmax(test_logits, axis=-1)
            total_correct += int((numpy.sum(tlx.convert_to_numpy(pred == data['y']).astype(int))))
        test_acc = total_correct / len(val_set)  # despite the name, this is *validation* accuracy

        # Checkpoint the best-validation student.
        if test_acc > best_acc:
            best_acc = test_acc
            student.save_weights(args.student + "_" + args.dataset + ".npz", format='npz_dict')
        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " acc: {:.4f}".format(test_acc))

    # Final test accuracy of the teacher, as a reference point...
    total_correct = 0
    for data in test_loader:
        teacher_logits = teacher(data.x, data.edge_index, data.batch)
        pred = tlx.argmax(teacher_logits, axis=-1)
        total_correct += int((numpy.sum(tlx.convert_to_numpy(pred == data['y']).astype(int))))
    teacher_acc = total_correct / len(test_set)

    # ...and of the best-validation student checkpoint.
    student.load_weights(args.student + "_" + args.dataset + ".npz", format='npz_dict')
    total_correct = 0
    for data in test_loader:
        student_logits = student(data.x, data.edge_index, data.x.shape[0], data.batch)
        pred = tlx.argmax(student_logits, axis=-1)
        total_correct += int((numpy.sum(tlx.convert_to_numpy(pred == data['y']).astype(int))))
    student_acc = total_correct / len(test_set)

    print('teacher_acc:', teacher_acc)
    print('student_acc:', student_acc)
200+
201+
if __name__ == '__main__':
    # CLI entry point: parse hyperparameters, select device, run distillation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--student", type=str, default='gcn', help="student model")
    parser.add_argument("--student_lr", type=float, default=0.005, help="learning rate of student model")
    parser.add_argument("--generator_lr", type=float, default=0.005, help="learning rate of generator")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epoch")
    parser.add_argument("--student_epoch", type=int, default=5)
    parser.add_argument("--g_epoch", type=int, default=5)
    parser.add_argument("--num_layers", type=int, default=5)
    # Typo fixes in help text: "dimention" -> "dimension", "coeficient" -> "coefficient".
    parser.add_argument("--t_hidden", type=int, default=128, help="dimension of hidden layers of teacher model")
    parser.add_argument("--t_layers", type=int, default=5)
    parser.add_argument("--hidden_units", type=int, default=128, help="dimension of hidden layers")
    parser.add_argument("--student_l2_coef", type=float, default=5e-4, help="l2 loss coefficient for student")
    parser.add_argument("--generator_l2_coef", type=float, default=5e-4, help="l2 loss coefficient for generator")
    parser.add_argument('--dataset', type=str, default='MUTAG', help='dataset(MUTAG/IMDB-BINARY/REDDIT-BINARY)')
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    # Original help said 'dimension of domain labels' (copy-paste error);
    # the value is passed as the generator's node count — TODO confirm.
    parser.add_argument('--vertexes', type=int, default=40, help='number of nodes in each generated graph')
    parser.add_argument("--generator_dropout", type=float, default=0.)
    parser.add_argument("--student_dropout", type=float, default=0.)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--nz", type=int, default=32)
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    # Negative --gpu falls back to CPU.
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)

examples/dfad_gnn/train_teacher.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import os
2+
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3+
# os.environ['TL_BACKEND'] = 'torch'
4+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
5+
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
6+
7+
import argparse
8+
import tensorlayerx as tlx
9+
from gammagl.models import GINModel
10+
from gammagl.loader import DataLoader
11+
from gammagl.datasets import TUDataset
12+
from tensorlayerx.model import TrainOneStep, WithLoss
13+
import numpy
14+
15+
class SemiSpvzLoss(WithLoss):
    """Cross-entropy training objective for the teacher GIN.

    Note: the ``y`` argument supplied by TrainOneStep is ignored; the
    label is read from ``data['y']`` directly.
    """

    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        graph_logits = self.backbone_network(data['x'], data['edge_index'], data['batch'])
        return self._loss_fn(graph_logits, data['y'])
23+
24+
def train_teacher(args):
    """Train a GIN teacher for graph classification on a TUDataset.

    Saves the best-validation checkpoint to ``./teacher_<dataset>.npz``
    (consumed by train_student.py) and prints the test accuracy of that
    checkpoint at the end.
    """
    dataset = TUDataset(args.dataset_path,args.dataset)

    # 80/10/10 train/val/test split by slicing (no shuffling).
    dataset_unit = len(dataset) // 10
    train_set = dataset[2 * dataset_unit:]
    val_set = dataset[:dataset_unit]
    test_set = dataset[dataset_unit: 2 * dataset_unit]

    train_loader = DataLoader(train_set, batch_size=args.batch_size)
    val_loader = DataLoader(val_set, batch_size=args.batch_size)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)

    # max(..., 1) guards datasets whose graphs carry no node features.
    teacher = GINModel(
        in_channels=max(dataset.num_features, 1),
        hidden_channels=args.hidden_dim,
        out_channels=dataset.num_classes,
        num_layers=args.num_layers,
        name="GIN"
    )

    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    train_weights = teacher.trainable_weights
    loss_func = SemiSpvzLoss(teacher, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        teacher.set_train()

        for data in train_loader:
            train_loss = train_one_step(data, data.y)

        teacher.set_eval()

        # Validation accuracy: count correct graph-level predictions.
        total_correct = 0
        for data in val_loader:
            val_logits = teacher(data.x, data.edge_index, data.batch)
            pred = tlx.argmax(val_logits, axis=-1)
            total_correct += int((numpy.sum(tlx.convert_to_numpy(pred == data.y).astype(int))))
        val_acc = total_correct / len(val_set)

        # train_loss printed here is the loss of the last batch only.
        print("Epoch [{:0>3d}] ".format(epoch + 1) \
              + " train loss: {:.4f}".format(train_loss.item()) \
              + " val acc: {:.4f}".format(val_acc))

        # Checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            teacher.save_weights("./teacher_" + args.dataset + ".npz", format='npz_dict')

    # Evaluate the best checkpoint on the held-out test split.
    teacher.load_weights("./teacher_" + args.dataset + ".npz", format='npz_dict', skip=True)
    teacher.set_eval()
    total_correct = 0
    for data in test_loader:
        test_logits = teacher(data.x, data.edge_index, data.batch)
        pred = tlx.argmax(test_logits, axis=-1)
        total_correct += int((numpy.sum(tlx.convert_to_numpy(pred == data['y']).astype(int))))
    test_acc = total_correct / len(test_set)

    print("Test acc: {:.4f}".format(test_acc))
83+
84+
85+
if __name__ == '__main__':
    # CLI entry point: parse hyperparameters, select device, train the teacher.
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epoch", type=int, default=1000, help="number of epoch")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--num_layers", type=int, default=5)
    # Typo fixes in help text: "dimention" -> "dimension", "coeficient" -> "coefficient".
    parser.add_argument("--hidden_dim", type=int, default=128, help="dimension of hidden layers")
    parser.add_argument("--l2_coef", type=float, default=5e-4, help="l2 loss coefficient")
    parser.add_argument('--dataset', type=str, default='MUTAG', help='dataset(MUTAG/IMDB-BINARY/REDDIT-BINARY)')
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    # Negative --gpu falls back to CPU.
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_teacher(args)

gammagl/models/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from .hid_net import Hid_net
6060
from .gnnlfhf import GNNLFHFModel
6161
from .dna import DNAModel
62+
from .dfad import DFADModel, DFADGenerator
6263

6364
__all__ = [
6465
'HeCo',
@@ -123,7 +124,9 @@
123124
'hid_net',
124125
'HEAT',
125126
'GNNLFHFModel',
126-
'DNAModel'
127+
'DNAModel',
128+
'DFADModel',
129+
'DFADGenerator'
127130
]
128131

129132
classes = __all__

0 commit comments

Comments
 (0)