Commit 340d5f5

LibertyZXY authored

[Model] add dhn (#207)

* [model] add dhn
* Modify the dhn code and use MessagePassing
* Add dataset and test code
* Add dataset and test code
* modify the code and add test code
* update
* update
* update

Co-authored-by: Guangyu Zhou <77875480+gyzhou2000@users.noreply.github.com>
Co-authored-by: gyzhou2000 <gyzhou2000@gmail.com>
Co-authored-by: gyzhou2000 <gyzhou@bupt.edu.cn>
1 parent 1a754c5 commit 340d5f5

File tree

14 files changed: +642 −6 lines changed

docs/source/api/gammagl.utils.rst (+2, −1)

```diff
@@ -27,4 +27,5 @@ gammagl.utils
    gammagl.utils.to_scipy_sparse_matrix
    gammagl.utils.read_embeddings
    gammagl.utils.homophily
-   gammagl.utils.get_train_val_test_split
+   gammagl.utils.get_train_val_test_split
+   gammagl.utils.find_all_simple_paths
```
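The newly documented `gammagl.utils.find_all_simple_paths` is the primitive the DHN distance encoder is built on. A minimal sketch of how the trainer below invokes it, under stated assumptions: the toy `edge_index` is hypothetical, and (as in the ACM data) movie nodes carry non-negative IDs while author nodes carry negative IDs; the fourth argument mirrors the `k_hop + 2` path-length cutoff used in `dist_encoder`.

```python
import tensorlayerx as tlx
from gammagl.utils import find_all_simple_paths

# Hypothetical toy graph: movies 0 and 1, authors -1 and -2.
edge_index = tlx.convert_to_tensor([[0, 0, 1], [-1, -2, -1]])
src = tlx.convert_to_tensor(0)
dest = tlx.convert_to_tensor(-1)

# All simple paths from src to dest of length at most k_hop + 2 = 4,
# matching the call find_all_simple_paths(G, src, dest, k_hop + 2) in dist_encoder.
paths = find_all_simple_paths(edge_index, src, dest, 4)
print(paths)
```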

examples/dhn/README.md (+17)

# Distance encoding based Heterogeneous graph neural Network (DHN)

- Paper link: [https://ieeexplore.ieee.org/document/10209229](https://ieeexplore.ieee.org/document/10209229)
- Author's code repo: [https://github.com/BUPT-GAMMA/HDE](https://github.com/BUPT-GAMMA/HDE)

## Dataset Statistics

| Dataset | # Nodes | # Edges |
|---------|---------|---------|
| acm     | 3908    | 4500    |

## Results

```bash
TL_BACKEND="torch" python dhn_trainer.py --test_ratio 0.3 --one_hot True --k_hop 2 --num_neighbor 5 --batch_size 32 --lr 0.001 --n_epoch 100 --drop_rate 0.01 --dataset 'acm'
```

| Dataset | Paper (AUC) | Our (th) (AUC) |
|---------|-------------|----------------|
| acm     | 95.07       | 95.54 ± 0.18   |
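The same run can also be set up programmatically. A minimal sketch following what `dhn_trainer.py` does in `main()` (the empty `root` mirrors the trainer's `--dataset_path` default):

```python
from gammagl.datasets import ACM4DHN

# Load the ACM link-prediction splits used by dhn_trainer.py.
data = ACM4DHN(root='', test_ratio=0.3)
graph = data[0]
G_train, G_val, G_test = graph['train'], graph['val'], graph['test']

# Each split exposes the movie-author edges consumed by batch_data().
print(G_train['M', 'MA', 'A'].edge_index.shape)
```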

examples/dhn/dhn_trainer.py (+339)

```python
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import argparse
import random
import numpy as np
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep
from sklearn.metrics import roc_auc_score
from gammagl.models import DHNModel
from gammagl.datasets import ACM4DHN
from gammagl.utils import k_hop_subgraph, find_all_simple_paths


type2idx = {
    'M': 0,
    'A': 1,
    # 'C': 2,
    # 'T': 3
}


def dist_encoder(src, dest, G, k_hop):
    if G.size(1) == 0:
        paths = []
    else:
        paths = find_all_simple_paths(G, src, dest, k_hop + 2)

    node_type = len(type2idx)
    cnt = [k_hop + 1] * node_type  # default truncation when the max SPD is exceeded
    for path in paths:
        res = [0] * node_type
        for i in path:
            if i >= 0:
                res[type2idx['M']] += 1
            else:
                res[type2idx['A']] += 1

        for k in range(node_type):
            cnt[k] = min(cnt[k], res[k])

    # Generate one-hot encoding
    if args.one_hot:
        one_hot_list = [np.eye(k_hop + 2, dtype=np.float64)[cnt[i]]
                        for i in range(node_type)]
        return np.concatenate(one_hot_list)
    return cnt


def type_encoder(node):
    node_type = len(type2idx)
    res = [0] * node_type
    if node.item() >= 0:
        res[type2idx['M']] = 1.0
    else:
        res[type2idx['A']] = 1.0
    return res


mini_batch = []
fea_batch = []


def gen_fea_batch(G, root, fea_dict, k_hop):
    fea_batch = []
    mini_batch.append(root)

    a = [0] * (k_hop + 2) * 4 + type_encoder(root)

    node_type = len(type2idx)
    num_fea = (k_hop + 2) * 4 + node_type
    fea_batch.append(np.asarray(a, dtype=np.float32).reshape(-1, num_fea))

    # 1st-order neighbor sampling
    ns_1 = []
    src, dst = G
    for node in mini_batch[-1]:
        if node.item() >= 0:
            neighbors_mask = src == node
        else:
            neighbors_mask = dst == node
        neighbors = list(tlx.convert_to_numpy(dst[neighbors_mask]))
        neighbors.append(node.item())
        random_choice_list = np.random.choice(neighbors, args.num_neighbor, replace=True)
        ns_1.append(random_choice_list.tolist())
    ns_1 = tlx.convert_to_tensor(ns_1)
    mini_batch.append(ns_1[0])

    de_1 = [
        np.concatenate([fea_dict[ns_1[0][i].item()], np.asarray(type_encoder(ns_1[0][i]))], axis=0)
        for i in range(0, ns_1[0].shape[0])
    ]

    fea_batch.append(np.asarray(de_1, dtype=np.float32).reshape(1, -1))

    # 2nd-order neighbor sampling
    ns_2 = []
    for node in mini_batch[-1]:
        if node.item() >= 0:
            neighbors_mask = src == node
        else:
            neighbors_mask = dst == node
        neighbors = list(tlx.convert_to_numpy(dst[neighbors_mask]))
        neighbors.append(node.item())
        random_choice_list = np.random.choice(neighbors, args.num_neighbor, replace=True)
        ns_2.append(random_choice_list.tolist())
    ns_2 = tlx.convert_to_tensor(ns_2)

    de_2 = []
    for i in range(len(ns_2)):
        tmp = []
        for j in range(len(ns_2[0])):
            tmp.append(
                np.concatenate([fea_dict[ns_2[i][j].item()], np.asarray(type_encoder(ns_2[i][j]))], axis=0)
            )
        de_2.append(tmp)

    fea_batch.append(np.asarray(de_2, dtype=np.float32).reshape(1, -1))

    return np.concatenate(fea_batch, axis=1)


def subgraph_sampling_with_DE_node_pair(G, node_pair, k_hop=2):
    [A, B] = node_pair

    edge_index = tlx.concat([G['M', 'MA', 'A'].edge_index, reversed(G['M', 'MA', 'A'].edge_index)], axis=1)

    # Find the k-hop subgraph around A and B
    sub_G_for_AB = k_hop_subgraph([A, B], k_hop, edge_index)

    # Remove the A-B edges using a Boolean index
    # Note: only the edges are removed, the nodes remain
    edge_index_np = tlx.convert_to_numpy(sub_G_for_AB[1])
    remove_indices = tlx.convert_to_tensor([
        ((edge_index_np[0, i] == A) & (edge_index_np[1, i] == B)) | (
            (edge_index_np[0, i] == B) & (edge_index_np[1, i] == A))
        for i in range(sub_G_for_AB[1].shape[1])
    ])
    remove_indices = tlx.convert_to_numpy(remove_indices)
    sub_G_index = sub_G_for_AB[1][:, ~remove_indices]

    # Collect the nodes that appear in the subgraph
    sub_G_nodes = set(np.unique(tlx.convert_to_numpy(sub_G_for_AB[0]))) | set(
        np.unique(tlx.convert_to_numpy(sub_G_for_AB[1])))
    sub_G_nodes = tlx.convert_to_tensor(list(sub_G_nodes))

    # Distance from every node in the subgraph to the node pair
    SPD_based_on_node_pair = {}
    for node in sub_G_nodes:
        tmpA = dist_encoder(A, node, sub_G_index, k_hop)
        tmpB = dist_encoder(B, node, sub_G_index, k_hop)

        SPD_based_on_node_pair[node.item()] = np.concatenate([tmpA, tmpB], axis=0)

    A_fea_batch = gen_fea_batch(sub_G_index, A, SPD_based_on_node_pair, k_hop)
    B_fea_batch = gen_fea_batch(sub_G_index, B, SPD_based_on_node_pair, k_hop)

    return A_fea_batch, B_fea_batch


def batch_data(G, batch_size):
    edge_index = G['M', 'MA', 'A'].edge_index
    nodes = set(tlx.convert_to_tensor(np.unique(tlx.convert_to_numpy(edge_index[0])))) | set(
        tlx.convert_to_tensor(np.unique(tlx.convert_to_numpy(edge_index[1]))))

    nodes_list = []
    for node in nodes:
        nodes_list.append(node.item())

    num_batch = int(len(edge_index[0]) / batch_size)

    # Shuffle the order of the edges
    edge_index_np = tlx.convert_to_numpy(edge_index)
    permutation = np.random.permutation(edge_index_np.shape[1])  # generate a random permutation of indices
    edge_index_np = edge_index_np[:, permutation]  # shuffle edge_index with the permutation
    edge_index = tlx.convert_to_tensor(edge_index_np)

    for idx in range(num_batch):
        batch_edge = edge_index[:, idx * batch_size:(idx + 1) * batch_size]  # take out batch_size edges
        batch_label = [1.0] * batch_size

        batch_A_fea = []
        batch_B_fea = []
        batch_x = []
        batch_y = []

        i = 0
        for by in batch_label:
            bx = batch_edge[:, i:i + 1]

            # Positive sample
            posA, posB = subgraph_sampling_with_DE_node_pair(G, bx, k_hop=args.k_hop)
            batch_A_fea.append(posA)
            batch_B_fea.append(posB)
            batch_y.append(np.asarray(by, dtype=np.float32))

            # Negative sample
            neg_tmpB_id = random.choice(nodes_list)
            node_pair = tlx.convert_to_tensor([[bx[0].item()], [neg_tmpB_id]])

            negA, negB = subgraph_sampling_with_DE_node_pair(G, node_pair, k_hop=args.k_hop)
            batch_A_fea.append(negA)
            batch_B_fea.append(negB)
            batch_y.append(np.asarray(0.0, dtype=np.float32))

        yield np.asarray(np.squeeze(batch_A_fea)), np.asarray(np.squeeze(batch_B_fea)), np.asarray(
            batch_y).reshape(batch_size * 2, 1)


class Loss(tlx.model.WithLoss):
    def __init__(self, net, loss_fn):
        super(Loss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        logits = self.backbone_network(data['n1'], data['n2'], data['label'])
        y = tlx.convert_to_tensor(y)
        loss = self._loss_fn(logits, y)
        return loss


class AUCMetric:
    def __init__(self):
        self.true_labels = []
        self.predicted_scores = []

    def update_state(self, y_true, y_pred):
        self.true_labels.extend(y_true)
        self.predicted_scores.extend(y_pred)

    def result(self):
        auc = roc_auc_score(self.true_labels, self.predicted_scores)
        return auc


def main(args):
    if str.lower(args.dataset) not in ['acm']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if str.lower(args.dataset) == 'acm':
        data = ACM4DHN(root=args.dataset_path, test_ratio=args.test_ratio)

    graph = data[0]

    G_train = graph['train']
    G_val = graph['val']
    G_test = graph['test']

    node_type = len(type2idx)
    num_fea = (args.k_hop + 2) * 4 + node_type

    model = DHNModel(num_fea, args.batch_size, args.num_neighbor, name="DHN")

    optim = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.drop_rate)
    train_weights = model.trainable_weights

    net_with_loss = Loss(model, loss_fn=tlx.losses.sigmoid_cross_entropy)
    net_with_train = TrainOneStep(net_with_loss, optim, train_weights)

    tra_auc_metric = AUCMetric()
    val_auc_metric = AUCMetric()
    test_auc_metric = AUCMetric()

    best_val_auc = 0
    for epoch in range(args.n_epoch):

        # train
        model.set_train()
        tra_batch_A_fea, tra_batch_B_fea, tra_batch_y = batch_data(G_train, args.batch_size).__next__()
        tra_out = model(tra_batch_A_fea, tra_batch_B_fea, tra_batch_y)

        data = {
            "n1": tra_batch_A_fea,
            "n2": tra_batch_B_fea,
            "label": tra_batch_y
        }

        tra_loss = net_with_train(data, tra_batch_y)
        tra_auc_metric.update_state(y_true=tra_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(tra_out)))
        tra_auc = tra_auc_metric.result()

        # val
        model.set_eval()
        val_batch_A_fea, val_batch_B_fea, val_batch_y = batch_data(G_val, args.batch_size).__next__()
        val_out = model(val_batch_A_fea, val_batch_B_fea, val_batch_y)

        val_auc_metric.update_state(y_true=val_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(val_out)))
        val_auc = val_auc_metric.result()

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train loss: {:.4f}".format(tra_loss.item())
              + " val auc: {:.4f}".format(val_auc))

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            model.save_weights(args.best_model_path + model.name + ".npz", format='npz_dict')

    model.load_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
    # test
    test_batch_A_fea, test_batch_B_fea, test_batch_y = batch_data(G_test, args.batch_size).__next__()
    test_out = model(test_batch_A_fea, test_batch_B_fea, test_batch_y)

    test_auc_metric.update_state(y_true=test_batch_y, y_pred=tlx.convert_to_numpy(tlx.sigmoid(test_out)))
    test_auc = test_auc_metric.result()
    print("Test auc: {:.4f}".format(test_auc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ratio", type=float, default=0.3, help="ratio for splitting the dataset")
    parser.add_argument("--one_hot", type=bool, default=True, help="use one-hot encoding")
    parser.add_argument("--k_hop", type=int, default=2, help="hops of the generated subgraph")
    parser.add_argument("--num_neighbor", type=int, default=5, help="number of sampled neighbors")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs")
    parser.add_argument("--drop_rate", type=float, default=0.01, help="drop_rate")
    parser.add_argument('--dataset', type=str, default='acm', help='dataset')
    parser.add_argument("--dataset_path", type=str, default=r"", help='dataset_path')
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
    parser.add_argument("--gpu", type=int, default=-1)

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
```
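A quick worked check of the feature width `num_fea = (k_hop + 2) * 4 + node_type` that appears in both `gen_fea_batch` and `main`: `dist_encoder` returns one one-hot vector of length `k_hop + 2` per node type, and every node is encoded relative to both endpoints A and B, so with the two node types 'M' and 'A' the distance part is `(k_hop + 2) * 2 * 2`; `type_encoder` then appends `node_type` more entries.

```python
# With the README defaults: k_hop = 2 and two node types ('M', 'A').
k_hop, node_type = 2, 2

# One-hot of length k_hop + 2, per node type, per endpoint (A and B).
dist_dim = (k_hop + 2) * node_type * 2  # 16, i.e. (k_hop + 2) * 4
num_fea = dist_dim + node_type          # 18
print(num_fea)                          # matches (k_hop + 2) * 4 + node_type
```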

gammagl/datasets/__init__.py (+3, −1)

```diff
@@ -23,6 +23,7 @@
 from .facebook import FacebookPagePage
 from .acm4heco import ACM4HeCo
 from .yelp import Yelp
+from .acm4dhn import ACM4DHN

 __all__ = [
     'ACM4HeCo',
@@ -48,7 +49,8 @@
     'MoleculeNet',
     'FacebookPagePage',
     'NGSIM_US_101',
-    'Yelp'
+    'Yelp',
+    'ACM4DHN'
 ]

 classes = __all__
```
