import argparse
import os.path as osp
import os
# os.environ['TL_BACKEND'] = 'torch'
from time import perf_counter as t
import yaml
from yaml import SafeLoader
import numpy as np
import pickle
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss
from tensorlayerx.dataflow import random_split
from gammagl.layers.conv import GCNConv
from gammagl.datasets import Planetoid, Coauthor, Amazon
import gammagl.transforms as T

from gammagl.models.grace_pot import Grace_POT_Encoder, Grace_POT_Model
from eval_gracepot import log_regression, MulticlassEvaluator

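# Upper/lower bounds of the perturbed adjacency used by the POT loss;
# loaded lazily from the cache on the first training step (see get_A_bounds).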
A_upper_1 = None
A_upper_2 = None
A_lower_1 = None
A_lower_2 = None

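# Loss wrapper for TrainOneStep: builds two edge-dropped views of the graph,
# computes the InfoNCE contrastive loss and, if enabled, the POT loss,
# then mixes them as (1 - kappa) * nce + kappa * pot.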
class train_loss(WithLoss):
    def __init__(self, model, drop_edge_rate_1, drop_edge_rate_2, use_pot=False, pot_batch=-1, kappa=0.5):
        super(train_loss, self).__init__(backbone=model, loss_fn=None)
        self.drop_edge_rate_1 = drop_edge_rate_1
        self.drop_edge_rate_2 = drop_edge_rate_2
        self.use_pot = use_pot
        self.pot_batch = pot_batch
        self.kappa = kappa

    def forward(self, model, x, edge_index, epoch, data=None):
        # two augmented views obtained by independent edge dropout
        edge_index_1 = dropout_adj(edge_index, p=self.drop_edge_rate_1)[0]
        edge_index_2 = dropout_adj(edge_index, p=self.drop_edge_rate_2)[0]
        x_1, x_2 = x, x
        z1 = model(x_1, edge_index_1)
        z2 = model(x_2, edge_index_2)
        node_list = np.arange(z1.shape[0])
        np.random.shuffle(node_list)

        # on larger datasets, restrict the contrastive loss to a batch of nodes
        batch_size = 4096 if args.dataset in ["PubMed", "Computers", "WikiCS"] else None

        if batch_size is not None:
            node_list_batch = get_batch(node_list, batch_size, epoch)

        # nce loss
        if batch_size is not None:
            z11 = z1[node_list_batch]
            z22 = z2[node_list_batch]
            nce_loss = model.loss(z11, z22)
        else:
            nce_loss = model.loss(z1, z2)

        # pot loss
        if self.use_pot:
            # get node_list_tmp, the nodes to calculate pot_loss
            if self.pot_batch != -1:
                if batch_size is None:
                    node_list_tmp = get_batch(node_list, self.pot_batch, epoch)
                else:
                    node_list_tmp = get_batch(node_list_batch, self.pot_batch, epoch)
            else:
                # full pot batch
                if batch_size is None:
                    node_list_tmp = node_list
                else:
                    node_list_tmp = node_list_batch

            z11 = tlx.gather(z1, tlx.convert_to_tensor(node_list_tmp))
            z22 = tlx.gather(z2, tlx.convert_to_tensor(node_list_tmp))

            global A_upper_1, A_upper_2, A_lower_1, A_lower_2
            if A_upper_1 is None or A_upper_2 is None:
                A_upper_1, A_lower_1 = get_A_bounds(args.dataset, self.drop_edge_rate_1, args.cache)
                A_upper_2, A_lower_2 = get_A_bounds(args.dataset, self.drop_edge_rate_2, args.cache)

            pot_loss_1 = model.pot_loss(z11, z22, data.x, data.edge_index, edge_index_1, local_changes=self.drop_edge_rate_1,
                                        node_list=node_list_tmp, A_upper=A_upper_1, A_lower=A_lower_1)
            pot_loss_2 = model.pot_loss(z22, z11, data.x, data.edge_index, edge_index_2, local_changes=self.drop_edge_rate_2,
                                        node_list=node_list_tmp, A_upper=A_upper_2, A_lower=A_lower_2)
            pot_loss = (pot_loss_1 + pot_loss_2) / 2
            loss = (1 - self.kappa) * nce_loss + self.kappa * pot_loss
        else:
            loss = nce_loss

        return loss

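# Linear-probe evaluation: embed the full graph with the encoder in eval mode and
# score the embeddings with the logistic-regression probe from eval_gracepot.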
def test(model, data, dataset, split):
    model.set_eval()
    z = model(data.x, data.edge_index)
    evaluator = MulticlassEvaluator()
    res = log_regression(z, dataset, evaluator, split='preloaded', num_epochs=3000, preload_split=split)
    return res

def get_dataset(path, name):
    assert name in ['Cora', 'CiteSeer', 'PubMed', 'Coauthor-CS', 'Coauthor-Phy', 'Computers', 'Photo']
    name = 'dblp' if name == 'DBLP' else name

    if name == 'Coauthor-CS':
        return Coauthor(root=path, name='cs', transform=T.NormalizeFeatures())

    if name == 'Coauthor-Phy':
        return Coauthor(root=path, name='physics', transform=T.NormalizeFeatures())

    if name == 'Computers':
        return Amazon(root=path, name='computers', transform=T.NormalizeFeatures())

    if name == 'Photo':
        return Amazon(root=path, name='photo', transform=T.NormalizeFeatures())

    return Planetoid(path, name, transform=T.NormalizeFeatures())  # public split

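# Draw a random train/test/val split over the nodes and return it as boolean masks.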
def generate_split(num_samples: int, train_ratio: float, val_ratio: float):
    train_len = int(num_samples * train_ratio)
    val_len = int(num_samples * val_ratio)
    test_len = num_samples - train_len - val_len

    train_set, test_set, val_set = random_split(tlx.arange(0, num_samples), (train_len, test_len, val_len))

    idx_train, idx_test, idx_val = train_set.indices, test_set.indices, val_set.indices
    train_mask = tlx.zeros((num_samples,)).to(tlx.bool)
    test_mask = tlx.zeros((num_samples,)).to(tlx.bool)
    val_mask = tlx.zeros((num_samples,)).to(tlx.bool)

    train_mask[idx_train] = True
    test_mask[idx_test] = True
    val_mask[idx_val] = True

    return train_mask, test_mask, val_mask

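# Return the (epoch mod num_batches)-th chunk of node_list, so successive
# epochs cycle through all nodes in fixed-size batches.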
def get_batch(node_list, batch_size, epoch):
    num_nodes = len(node_list)
    num_batches = (num_nodes - 1) // batch_size + 1
    i = epoch % num_batches
    if (i + 1) * batch_size >= len(node_list):
        node_list_batch = node_list[i * batch_size:]
    else:
        node_list_batch = node_list[i * batch_size:(i + 1) * batch_size]
    return node_list_batch

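# Load the cached upper/lower adjacency bounds for this dataset and drop rate
# from a pickle, if present; otherwise return (None, None).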
def get_A_bounds(dataset, drop_rate, cache):
    upper_lower_file = osp.join(cache, f"{dataset}_{drop_rate}_upper_lower.pkl")
    if osp.exists(upper_lower_file):
        with open(upper_lower_file, 'rb') as file:
            A_upper, A_lower = pickle.load(file)
    else:
        A_upper, A_lower = None, None
    return A_upper, A_lower

def filter_adj(row, col, edge_attr, mask):
    mask = tlx.convert_to_tensor(mask, dtype=tlx.bool)
    return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]

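# Edge dropout in the spirit of PyG's dropout_adj: each edge is removed
# independently with probability p; returns the filtered edge_index and edge_attr.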
def dropout_adj(
    edge_index,
    edge_attr=None,
    p=0.5,
    force_undirected=False,
    num_nodes=None,
    training=True,
):
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')

    if not training or p == 0.0:
        return edge_index, edge_attr

    # row, col = edge_index
    row = edge_index[0]
    col = edge_index[1]

    # keep each edge with probability 1 - p
    mask = np.random.random(tlx.get_tensor_shape(row)) >= p

    if force_undirected:
        mask[row > col] = False

    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if force_undirected:
        edge_index = tlx.stack(
            [tlx.concat([row, col], 0),
             tlx.concat([col, row], 0)], axis=0)
        if edge_attr is not None:
            edge_attr = tlx.concat([edge_attr, edge_attr], 0)
    else:
        edge_index = tlx.stack([row, col])

    return edge_index, edge_attr


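# Driver: read hyper-parameters from the YAML config (with CLI overrides),
# build the GRACE-POT encoder and model, train with TrainOneStep, and report
# probe accuracy every 100 epochs and at the end.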
def main(args):
    if args.gpu_id >= 0:
        tlx.set_device(device='GPU', id=args.gpu_id)
    else:
        tlx.set_device(device='CPU')

    config = yaml.load(open(args.config), Loader=SafeLoader)[args.dataset]
    # for hyperparameter tuning
    if args.drop_1 != -1:
        config['drop_edge_rate_1'] = args.drop_1
    if args.drop_2 != -1:
        config['drop_edge_rate_2'] = args.drop_2
    if args.tau != -1:
        config['tau'] = args.tau
    if args.num_epochs != -1:
        config['num_epochs'] = args.num_epochs
    print(args)
    print(config)

    learning_rate = config['learning_rate']
    num_hidden = config['num_hidden']
    num_proj_hidden = config['num_proj_hidden']
    activation = ({'relu': tlx.nn.ReLU, 'prelu': tlx.nn.PRelu()})[config['activation']]
    base_model = ({'GCNConv': GCNConv})[config['base_model']]
    num_layers = config['num_layers']

    drop_edge_rate_1 = config['drop_edge_rate_1']
    drop_edge_rate_2 = config['drop_edge_rate_2']
    tau = config['tau']
    num_epochs = config['num_epochs']
    weight_decay = config['weight_decay']
    use_pot = args.use_pot
    kappa = args.kappa
    pot_batch = args.pot_batch

    dataset = get_dataset(args.path, args.dataset)
    data = dataset[0]

    # generate split
    if args.dataset in ["Cora", "CiteSeer", "PubMed"]:
        split = data.train_mask, data.val_mask, data.test_mask
        print("Public Split")
    else:
        split = generate_split(data.num_nodes, train_ratio=0.1, val_ratio=0.1)
        print("Random Split")

    encoder = Grace_POT_Encoder(dataset.num_features, num_hidden, activation,
                                base_model=base_model, k=num_layers)
    model = Grace_POT_Model(encoder, num_hidden, num_proj_hidden, tau, dataset=args.dataset, cached=args.cache)
    train_weights = model.trainable_weights
    optimizer = tlx.optimizers.Adam(lr=learning_rate, weight_decay=weight_decay)
    loss_func = train_loss(model, drop_edge_rate_1, drop_edge_rate_2, use_pot, pot_batch, kappa)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    # timing
    start = t()
    prev = start
    for epoch in range(1, num_epochs + 1):
        model.set_train()
        loss = train_one_step(model, data.x, data.edge_index, epoch, data)
        now = t()
        print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}, '
              f'this epoch {now - prev:.4f}, total {now - start:.4f}')
        if epoch % 100 == 0:
            res = test(model, data, dataset, split)
            print(res)
        prev = now

    print("=== Final ===")
    res = test(model, data, dataset, split)
    print(res)


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, default="./")
    parser.add_argument('--cache', type=str, default="./")
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--config', type=str, default='./config.yaml')
    parser.add_argument('--use_pot', default=True)  # whether to use pot in loss
    parser.add_argument('--kappa', type=float, default=0.5)
    parser.add_argument('--pot_batch', type=int, default=-1)
    parser.add_argument('--drop_1', type=float, default=0.4)
    parser.add_argument('--drop_2', type=float, default=0.3)
    parser.add_argument('--tau', type=float, default=0.9)  # temperature of nce loss
    parser.add_argument('--num_epochs', type=int, default=-1)
    parser.add_argument('--save_file', type=str, default=".")
    parser.add_argument('--seed', type=int, default=12345)
    args = parser.parse_args()
    main(args)