import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

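# Data-free adversarial distillation (DFAD) on graph classification: a
# generator synthesizes graphs, the student learns to match a pretrained GIN
# teacher on them, and the generator is trained to maximize their disagreement.
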
import argparse
import numpy
import scipy.sparse as sp
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss
from gammagl.models import GINModel, DFADModel, DFADGenerator
from gammagl.loader import DataLoader
from gammagl.data import Graph
from gammagl.datasets import TUDataset
from train_teacher import SemiSpvzLoss

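# WithLoss wrappers consumed by TrainOneStep: GeneratorLoss maximizes the
# student-teacher discrepancy, StudentLoss minimizes it.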
class GeneratorLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(GeneratorLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, student_logits, teacher_logits):
        # Negate the distance so that gradient descent maximizes it.
        loss = -self._loss_fn(student_logits, teacher_logits)
        return loss

class StudentLoss(WithLoss):
    def __init__(self, net, loss_fn, batch_size):
        super(StudentLoss, self).__init__(backbone=net, loss_fn=loss_fn)
        self.loss_fn = loss_fn
        self.batch_size = batch_size

    def forward(self, data, label):
        # `label` holds the teacher's logits for the same batch.
        logits = self.backbone_network(data['x'], data['edge_index'], data['x'].shape[0], data['batch'])
        loss = self._loss_fn(logits, label)
        return loss

def dense_to_sparse(adj_mat):
    # Convert a dense adjacency matrix to a [2, num_edges] edge_index tensor.
    adj_mat = tlx.convert_to_numpy(adj_mat)
    adj_mat = sp.coo_matrix(adj_mat)
    row = tlx.convert_to_tensor(adj_mat.row, dtype=tlx.int64)
    col = tlx.convert_to_tensor(adj_mat.col, dtype=tlx.int64)
    res = tlx.stack((row, col))
    return res

def Data_construct(batch_size, edges_logits, nodes_logits):
    # Discretize generator outputs into graphs: one-hot node features from the
    # argmax of the node logits, edges wherever the edge probability exceeds 0.3.
    nodes_logits = tlx.softmax(nodes_logits, -1)
    max_indices = tlx.argmax(nodes_logits, axis=2, keepdim=True)
    nodes_logits = tlx.zeros_like(nodes_logits)
    for i in range(len(nodes_logits)):
        for j in range(len(nodes_logits[0])):
            index = max_indices[i, j]
            nodes_logits[i, j, index] = 1
    edges_logits = tlx.sigmoid(tlx.cast(edges_logits, dtype=tlx.float32))
    edges_logits = (edges_logits > 0.3).long()
    data_list = []
    s = len(nodes_logits)
    for i in range(s):
        edge = dense_to_sparse(edges_logits[i])
        x = nodes_logits[i]
        data = Graph(x=x, edge_index=edge)
        # draw(args, data, 'filename', i)
        data_list.append(data)
    G_data = DataLoader(data_list, batch_size=batch_size, shuffle=False)
    return G_data


def generate_graph(args, generator):
    # Sample latent noise and decode it into a loader of synthetic graphs.
    z = tlx.random_normal((args.batch_size, args.nz))
    adj, nodes_logits = generator(z)
    loader = Data_construct(args.batch_size, adj, nodes_logits)
    return loader


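# One DFAD round per epoch: `student_epoch` student steps minimizing the L1
# distance to the teacher's logits on generated graphs, then `g_epoch`
# generator steps maximizing that distance, followed by validation on real data.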
def train_student(args):
    dataset = TUDataset(args.dataset_path, args.dataset)

    # 10% validation, 10% test, remaining 80% for training.
    dataset_unit = len(dataset) // 10
    train_set = dataset[2 * dataset_unit:]
    val_set = dataset[:dataset_unit]
    test_set = dataset[dataset_unit: 2 * dataset_unit]

    train_loader = DataLoader(train_set, batch_size=args.batch_size)
    val_loader = DataLoader(val_set, batch_size=args.batch_size)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)

    teacher = GINModel(
        in_channels=max(dataset.num_features, 1),
        hidden_channels=args.t_hidden,
        out_channels=dataset.num_classes,
        num_layers=args.t_layers,
        name="GIN"
    )

    student = DFADModel(
        model_name=args.student,
        feature_dim=max(dataset.num_features, 1),
        hidden_dim=args.hidden_units,
        num_classes=dataset.num_classes,
        num_layers=args.num_layers,
        drop_rate=args.student_dropout
    )

    # Initialize the generator.
    generator = DFADGenerator([64, 128, 256], args.nz, args.vertexes, dataset.num_features, args.generator_dropout)

    optimizer_s = tlx.optimizers.Adam(lr=args.student_lr, weight_decay=args.student_l2_coef)
    optimizer_g = tlx.optimizers.Adam(lr=args.generator_lr, weight_decay=args.generator_l2_coef)

    optimizer = tlx.optimizers.Adam(lr=0.001, weight_decay=5e-4)
    train_weights = teacher.trainable_weights
    loss_func = SemiSpvzLoss(teacher, tlx.losses.softmax_cross_entropy_with_logits)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    # One pass over the training set builds the teacher's weights; the
    # pretrained checkpoint from train_teacher.py then overwrites them.
    teacher.set_train()
    for data in train_loader:
        train_loss = train_one_step(data, data.y)
    teacher.set_eval()
    teacher.load_weights("./teacher_" + args.dataset + ".npz", format='npz_dict')

    student_trainable_weights = student.trainable_weights
    generator_trainable_weights = generator.trainable_weights
    s_loss_fun = tlx.losses.L1Loss
    g_loss_fun = tlx.losses.L1Loss

    s_with_loss = StudentLoss(student, s_loss_fun, args.batch_size)
    g_with_loss = GeneratorLoss(generator, g_loss_fun)
    s_train_one_step = TrainOneStep(s_with_loss, optimizer_s, student_trainable_weights)
    g_train_one_step = TrainOneStep(g_with_loss, optimizer_g, generator_trainable_weights)

    epochs = args.n_epoch
    student_epoch = args.student_epoch
    g_epoch = args.g_epoch

    best_acc = 0
    for epoch in range(epochs):
        # Train the student to imitate the teacher on generated graphs.
        student.set_train()
        for _ in range(student_epoch):
            loader = generate_graph(args, generator)
            teacher.set_eval()
            s_loss = 0
            for data in loader:
                t_logits = teacher(data.x, data.edge_index, data.batch)
                s_loss += s_train_one_step(data, t_logits)
            # print('s_loss:', s_loss)
        student.set_eval()

        # Train the generator to maximize student-teacher disagreement.
        generator.set_train()
        for _ in range(g_epoch):
            z = tlx.random_normal((args.batch_size, args.nz))
            adj, nodes_logits = generator(z)
            loader = Data_construct(z.shape[0], adj, nodes_logits)
            g_loss = 0
            for data in loader:
                x, edge_index, num_nodes, batch = data.x, data.edge_index, data.num_nodes, data.batch
                student_logits = student(x, edge_index, num_nodes, batch)
                teacher_logits = teacher(x, edge_index, batch)
                student_logits = tlx.softmax(student_logits, -1)
                teacher_logits = tlx.softmax(teacher_logits, -1)
                g_loss += g_train_one_step(student_logits, teacher_logits)
            # print('g_loss:', g_loss)
        generator.set_eval()

        # Validate the student and keep the best checkpoint.
        total_correct = 0
        for data in val_loader:
            val_logits = student(data.x, data.edge_index, data.x.shape[0], data.batch)
            pred = tlx.argmax(val_logits, axis=-1)
            total_correct += int(numpy.sum(tlx.convert_to_numpy(pred == data['y']).astype(int)))
        val_acc = total_correct / len(val_set)

        if val_acc > best_acc:
            best_acc = val_acc
            student.save_weights(args.student + "_" + args.dataset + ".npz", format='npz_dict')
        print("Epoch [{:0>3d}] acc: {:.4f}".format(epoch + 1, val_acc))
182
+
183
+ total_correct = 0
184
+ for data in test_loader :
185
+ teacher_logits = teacher (data .x , data .edge_index , data .batch )
186
+ pred = tlx .argmax (teacher_logits , axis = - 1 )
187
+ total_correct += int ((numpy .sum (tlx .convert_to_numpy (pred == data ['y' ]).astype (int ))))
188
+ teacher_acc = total_correct / len (test_set )
189
+
190
+ student .load_weights (args .student + "_" + args .dataset + ".npz" , format = 'npz_dict' )
191
+ total_correct = 0
192
+ for data in test_loader :
193
+ student_logits = student (data .x , data .edge_index , data .x .shape [0 ], data .batch )
194
+ pred = tlx .argmax (student_logits , axis = - 1 )
195
+ total_correct += int ((numpy .sum (tlx .convert_to_numpy (pred == data ['y' ]).astype (int ))))
196
+ student_acc = total_correct / len (test_set )
197
+
198
+ print ('teacher_acc:' , teacher_acc )
199
+ print ('student_acc:' , student_acc )
200
+
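# Expected workflow: first produce the teacher checkpoint
# ./teacher_<dataset>.npz with train_teacher.py, then run this script
# (the invocation below is illustrative):
#   python dfad_train_student.py --dataset MUTAG --student gcn --gpu 0
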
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--student", type=str, default='gcn', help="student model")
    parser.add_argument("--student_lr", type=float, default=0.005, help="learning rate of student model")
    parser.add_argument("--generator_lr", type=float, default=0.005, help="learning rate of generator")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs")
    parser.add_argument("--student_epoch", type=int, default=5, help="student steps per epoch")
    parser.add_argument("--g_epoch", type=int, default=5, help="generator steps per epoch")
    parser.add_argument("--num_layers", type=int, default=5, help="number of layers of student model")
    parser.add_argument("--t_hidden", type=int, default=128, help="dimension of hidden layers of teacher model")
    parser.add_argument("--t_layers", type=int, default=5, help="number of layers of teacher model")
    parser.add_argument("--hidden_units", type=int, default=128, help="dimension of hidden layers of student model")
    parser.add_argument("--student_l2_coef", type=float, default=5e-4, help="l2 loss coefficient for student")
    parser.add_argument("--generator_l2_coef", type=float, default=5e-4, help="l2 loss coefficient for generator")
    parser.add_argument('--dataset', type=str, default='MUTAG', help='dataset (MUTAG/IMDB-BINARY/REDDIT-BINARY)')
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument('--vertexes', type=int, default=40, help='number of vertexes in generated graphs')
    parser.add_argument("--generator_dropout", type=float, default=0.)
    parser.add_argument("--student_dropout", type=float, default=0.)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--nz", type=int, default=32, help="dimension of the generator's noise vector")
    parser.add_argument("--gpu", type=int, default=0, help="gpu id; negative values run on cpu")
    args = parser.parse_args()

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)