import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import argparse
import random
import numpy as np
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep
from sklearn.metrics import roc_auc_score
from gammagl.models import DHNModel
from gammagl.datasets import ACM4DHN
from gammagl.utils import k_hop_subgraph, find_all_simple_paths
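

# Node type vocabulary. Throughout this script a node's type is inferred from
# the sign of its id: non-negative ids are 'M' nodes, negative ids are 'A'
# nodes (see type_encoder below).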
type2idx = {
    'M': 0,
    'A': 1,
    # 'C': 2,
    # 'T': 3
}


def dist_encoder(src, dest, G, k_hop):
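    """Distance encoding of `dest` relative to `src`: over all simple paths
    between them in the subgraph `G` (an edge_index of shape [2, E]), count
    the nodes of each type on a path and keep the per-type minimum, truncated
    at k_hop + 1. Returns a one-hot encoding when --one_hot is set."""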
    if G.size(1) == 0:
        paths = []
    else:
        paths = find_all_simple_paths(G, src, dest, k_hop + 2)

    node_type = len(type2idx)
    cnt = [k_hop + 1] * node_type  # default truncation when the max SPD is exceeded
    for path in paths:
        res = [0] * node_type
        for i in path:
            if i >= 0:
                res[type2idx['M']] += 1
            else:
                res[type2idx['A']] += 1

        for k in range(node_type):
            cnt[k] = min(cnt[k], res[k])

    # Generate the one-hot encoding
    if args.one_hot:
        one_hot_list = [np.eye(k_hop + 2, dtype=np.float64)[cnt[i]]
                        for i in range(node_type)]
        return np.concatenate(one_hot_list)
    return cnt


def type_encoder(node):
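    """One-hot encoding of a node's type, inferred from the sign of its id."""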
    node_type = len(type2idx)
    res = [0] * node_type
    if node.item() >= 0:
        res[type2idx['M']] = 1.0
    else:
        res[type2idx['A']] = 1.0
    return res


def sample_neighbors(src, dst, nodes, num_sample):
    """Sample `num_sample` 1-hop neighbors of each node in `nodes`, with
    replacement. The node itself is always a candidate, so isolated nodes
    still yield samples."""
    sampled = []
    for node in nodes:
        if node.item() >= 0:
            # 'M' node (id >= 0): neighbors are the targets of its outgoing edges
            neighbors = list(tlx.convert_to_numpy(dst[src == node]))
        else:
            # 'A' node (id < 0): neighbors are the sources of its incoming edges
            neighbors = list(tlx.convert_to_numpy(src[dst == node]))
        neighbors.append(node.item())
        sampled.append(np.random.choice(neighbors, num_sample, replace=True).tolist())
    return tlx.convert_to_tensor(sampled)


def gen_fea_batch(G, root, fea_dict, k_hop):
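    """Assemble the feature block for `root`: its own encoding, followed by
    the encodings of its sampled 1-hop and 2-hop neighborhoods, concatenated
    into a single row."""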
    fea_batch = []
    src, dst = G

    # Root feature: an all-zero distance encoding plus the root's type encoding
    a = [0] * (k_hop + 2) * 4 + type_encoder(root)

    node_type = len(type2idx)
    # Distance encoding: a (k_hop + 2)-way one-hot per node type (2) and per
    # anchor node (2) -> 4 blocks, plus the type encoding itself
    num_fea = (k_hop + 2) * 4 + node_type
    fea_batch.append(np.asarray(a, dtype=np.float32).reshape(-1, num_fea))

    # 1-order neighbor sampling
    ns_1 = sample_neighbors(src, dst, root, args.num_neighbor)

    de_1 = [
        np.concatenate([fea_dict[ns_1[0][i].item()], np.asarray(type_encoder(ns_1[0][i]))], axis=0)
        for i in range(0, ns_1[0].shape[0])
    ]
    fea_batch.append(np.asarray(de_1, dtype=np.float32).reshape(1, -1))

    # 2-order neighbor sampling: one row of samples per 1-hop neighbor
    ns_2 = sample_neighbors(src, dst, ns_1[0], args.num_neighbor)

    de_2 = []
    for i in range(len(ns_2)):
        tmp = []
        for j in range(len(ns_2[0])):
            tmp.append(
                np.concatenate([fea_dict[ns_2[i][j].item()], np.asarray(type_encoder(ns_2[i][j]))], axis=0)
            )
        de_2.append(tmp)
    fea_batch.append(np.asarray(de_2, dtype=np.float32).reshape(1, -1))

    return np.concatenate(fea_batch, axis=1)


def subgraph_sampling_with_DE_node_pair(G, node_pair, k_hop=2):
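    """Extract the k-hop enclosing subgraph of a node pair (A, B), drop any
    direct A-B edges, distance-encode every node in it against the pair, and
    return the feature batches for both endpoints."""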
    [A, B] = node_pair

    # Make the M-A edges bidirectional by appending the reversed edge_index
    edge_index = tlx.concat([G['M', 'MA', 'A'].edge_index, reversed(G['M', 'MA', 'A'].edge_index)], axis=1)

    # Find the k-hop subgraph around A and B
    sub_G_for_AB = k_hop_subgraph([A, B], k_hop, edge_index)

    # Remove the direct edges between A and B with a boolean mask
    # Note: only the edges are removed, the nodes remain
    edge_index_np = tlx.convert_to_numpy(sub_G_for_AB[1])
    a, b = A.item(), B.item()
    remove_mask = ((edge_index_np[0] == a) & (edge_index_np[1] == b)) | \
                  ((edge_index_np[0] == b) & (edge_index_np[1] == a))
    sub_G_index = sub_G_for_AB[1][:, ~remove_mask]

    # Collect the nodes that appear in the subgraph
    sub_G_nodes = set(np.unique(tlx.convert_to_numpy(sub_G_for_AB[0]))) | set(
        np.unique(tlx.convert_to_numpy(sub_G_for_AB[1])))
    sub_G_nodes = tlx.convert_to_tensor(list(sub_G_nodes))

    # Distance encoding of every subgraph node relative to the pair (A, B)
    SPD_based_on_node_pair = {}
    for node in sub_G_nodes:
        tmpA = dist_encoder(A, node, sub_G_index, k_hop)
        tmpB = dist_encoder(B, node, sub_G_index, k_hop)
        SPD_based_on_node_pair[node.item()] = np.concatenate([tmpA, tmpB], axis=0)

    A_fea_batch = gen_fea_batch(sub_G_index, A, SPD_based_on_node_pair, k_hop)
    B_fea_batch = gen_fea_batch(sub_G_index, B, SPD_based_on_node_pair, k_hop)

    return A_fea_batch, B_fea_batch


def batch_data(G, batch_size):
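    """Yield batches of (A features, B features, labels). Each observed M-A
    edge contributes a positive pair and one negative pair whose tail is a
    uniformly sampled node, so every batch holds 2 * batch_size samples."""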
    edge_index = G['M', 'MA', 'A'].edge_index
    # Candidate pool for negative sampling: every node id that occurs in an edge
    nodes_list = np.unique(tlx.convert_to_numpy(edge_index)).tolist()

    num_batch = int(len(edge_index[0]) / batch_size)

    # Shuffle the order of the edges
    edge_index_np = tlx.convert_to_numpy(edge_index)
    permutation = np.random.permutation(edge_index_np.shape[1])  # randomly permuted column indices
    edge_index_np = edge_index_np[:, permutation]  # scramble edge_index with the permutation
    edge_index = tlx.convert_to_tensor(edge_index_np)

    for idx in range(num_batch):
        batch_edge = edge_index[:, idx * batch_size:(idx + 1) * batch_size]  # take out batch_size edges
        batch_label = [1.0] * batch_size

        batch_A_fea = []
        batch_B_fea = []
        batch_y = []

        for i, by in enumerate(batch_label):
            bx = batch_edge[:, i:i + 1]

            # Positive sample
            posA, posB = subgraph_sampling_with_DE_node_pair(G, bx, k_hop=args.k_hop)
            batch_A_fea.append(posA)
            batch_B_fea.append(posB)
            batch_y.append(np.asarray(by, dtype=np.float32))

            # Negative sample: keep the head of the edge, replace the tail
            neg_tmpB_id = random.choice(nodes_list)
            node_pair = tlx.convert_to_tensor([[bx[0].item()], [neg_tmpB_id]])

            negA, negB = subgraph_sampling_with_DE_node_pair(G, node_pair, k_hop=args.k_hop)
            batch_A_fea.append(negA)
            batch_B_fea.append(negB)
            batch_y.append(np.asarray(0.0, dtype=np.float32))

        yield np.asarray(np.squeeze(batch_A_fea)), np.asarray(np.squeeze(batch_B_fea)), np.asarray(
            batch_y).reshape(batch_size * 2, 1)


class Loss(tlx.model.WithLoss):
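    """Wrap the model so TrainOneStep can run forward pass and loss together."""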
    def __init__(self, net, loss_fn):
        super(Loss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        logits = self.backbone_network(data['n1'], data['n2'], data['label'])
        y = tlx.convert_to_tensor(y)
        loss = self._loss_fn(logits, y)
        return loss


class AUCMetric:
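    """ROC-AUC via scikit-learn. Labels and scores accumulate over every
    `update_state` call, so `result` reports a running AUC rather than a
    per-batch value."""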
    def __init__(self):
        self.true_labels = []
        self.predicted_scores = []

    def update_state(self, y_true, y_pred):
        self.true_labels.extend(y_true)
        self.predicted_scores.extend(y_pred)

    def result(self):
        return roc_auc_score(self.true_labels, self.predicted_scores)


def main(args):
    if args.dataset.lower() not in ['acm']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.dataset.lower() == 'acm':
        data = ACM4DHN(root=args.dataset_path, test_ratio=args.test_ratio)

    graph = data[0]

    G_train = graph['train']
    G_val = graph['val']
    G_test = graph['test']

    node_type = len(type2idx)
    num_fea = (args.k_hop + 2) * 4 + node_type

    model = DHNModel(num_fea, args.batch_size, args.num_neighbor, name="DHN")

    optim = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.drop_rate)
    train_weights = model.trainable_weights

    net_with_loss = Loss(model, loss_fn=tlx.losses.sigmoid_cross_entropy)
    net_with_train = TrainOneStep(net_with_loss, optim, train_weights)

    tra_auc_metric = AUCMetric()
    val_auc_metric = AUCMetric()
    test_auc_metric = AUCMetric()

    best_val_auc = 0
    for epoch in range(args.n_epoch):

        # train: each epoch consumes one freshly shuffled batch
        model.set_train()
        tra_batch_A_fea, tra_batch_B_fea, tra_batch_y = next(batch_data(G_train, args.batch_size))
        tra_out = model(tra_batch_A_fea, tra_batch_B_fea, tra_batch_y)

        data = {
            "n1": tra_batch_A_fea,
            "n2": tra_batch_B_fea,
            "label": tra_batch_y
        }

        tra_loss = net_with_train(data, tra_batch_y)
        tra_auc_metric.update_state(y_true=tra_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(tra_out)))
        tra_auc = tra_auc_metric.result()

        # val
        model.set_eval()
        val_batch_A_fea, val_batch_B_fea, val_batch_y = next(batch_data(G_val, args.batch_size))
        val_out = model(val_batch_A_fea, val_batch_B_fea, val_batch_y)

        val_auc_metric.update_state(y_true=val_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(val_out)))
        val_auc = val_auc_metric.result()

        print("Epoch [{:0>3d}]".format(epoch + 1)
              + " train loss: {:.4f}".format(tra_loss.item())
              + " val auc: {:.4f}".format(val_auc))

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            model.save_weights(args.best_model_path + model.name + ".npz", format='npz_dict')

    # test with the best checkpoint restored
    model.load_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
    test_batch_A_fea, test_batch_B_fea, test_batch_y = next(batch_data(G_test, args.batch_size))
    test_out = model(test_batch_A_fea, test_batch_B_fea, test_batch_y)

    test_auc_metric.update_state(y_true=test_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(test_out)))
    test_auc = test_auc_metric.result()
    print("Test auc: {:.4f}".format(test_auc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ratio", type=float, default=0.3, help="ratio used to split the dataset")
    # argparse's type=bool would treat any non-empty string (even "False") as
    # True, so parse the flag from its string form explicitly
    parser.add_argument("--one_hot", type=lambda x: str(x).lower() == 'true', default=True,
                        help="use one-hot encoding")
    parser.add_argument("--k_hop", type=int, default=2, help="hops of the generated subgraph")
    parser.add_argument("--num_neighbor", type=int, default=5, help="number of sampled neighbors")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs")
    parser.add_argument("--drop_rate", type=float, default=0.01,
                        help="weight decay passed to the Adam optimizer")
    parser.add_argument('--dataset', type=str, default='acm', help='dataset')
    parser.add_argument("--dataset_path", type=str, default=r"", help='path to the dataset')
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
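
# Example invocation (the script name and backend below are illustrative, not
# part of the source):
#   TL_BACKEND=torch python dhn_trainer.py --dataset acm --n_epoch 100 --gpu 0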