[Model] HeCo #204
Merged
86 commits
58d5c18
Add files via upload
tanjjjjrr a5b71fe
HeCo
tanjjjjrr 0e096eb
Merge pull request #4 from tanjjjjrr/tan-jiarui-HeCo-3
tanjjjjrr b90c488
new01 HeCo
tanjjjjrr d968cef
Delete gammagl/datasets/data_process.py
tanjjjjrr 7a99322
Update and rename ACM_data_process.py to acm4heco.py
tanjjjjrr c0dfd7a
Add files via upload
tanjjjjrr d760aaf
Delete examples/HeCo_train.py
tanjjjjrr 911eeee
Update HeCo_encoder.py
tanjjjjrr 012539d
Delete gammagl/layers/conv/gcn_forheco.py
tanjjjjrr 041d43f
Update HeCo_encoder.py
tanjjjjrr 6fb8795
Update and rename HeCo_train.py to HeCo_trainer.py
tanjjjjrr 9ccce17
Update HeCo_trainer.py
tanjjjjrr 3fe9752
Update HeCo_trainer.py
tanjjjjrr 193f338
Update acm4heco.py
tanjjjjrr eb12762
Update acm4heco.py
tanjjjjrr 8ab5a6e
Update HeCo_trainer.py
tanjjjjrr 597c829
Update acm4heco.py
tanjjjjrr a7e5551
Delete gammagl/models/contrast_learningHeCo.py
tanjjjjrr 217aaa9
Add files via upload
tanjjjjrr 67b20df
Update __init__.py
tanjjjjrr 253a0a0
Update HeCo_encoder.py
tanjjjjrr 00ef4a5
Update HeCo_encoder.py
tanjjjjrr 87c49a1
Update __init__.py
tanjjjjrr 8796d2a
Update acm4heco.py
tanjjjjrr fa24189
Update HeCo.py
tanjjjjrr 0af4b5e
Update acm4heco.py
tanjjjjrr 06315ef
Delete gammagl/layers/attention/meta_path_attention.py
tanjjjjrr 9d3ba0a
Delete gammagl/layers/attention/network_schema_attention.py
tanjjjjrr 4d6db04
Update HeCo_encoder.py
tanjjjjrr b9c05f1
Update HeCo_encoder.py
tanjjjjrr 0ce3806
Update HeCo_trainer.py
tanjjjjrr 80c8226
Update __init__.py
tanjjjjrr b7b6109
Update HeCo_encoder.py
tanjjjjrr ebcc60c
Update HeCo_trainer.py
tanjjjjrr b0114e8
Update README.md
tanjjjjrr 43b0326
Update HeCo_trainer.py
tanjjjjrr cdd05b7
Update acm4heco.py
tanjjjjrr 93476b1
Update README.md
tanjjjjrr 19f6c22
Add files via upload
tanjjjjrr e970659
Update README.md
tanjjjjrr c879fd5
Update README.md
tanjjjjrr 6d060d8
Update HeCo_trainer.py
tanjjjjrr 203a7ae
Update README.md
tanjjjjrr 473f54d
Update README.md
tanjjjjrr 28cb5cd
Update README.md
tanjjjjrr 08a6c4a
Update README.md
tanjjjjrr d917726
Update HeCo.py
tanjjjjrr 141d79c
Update HeCo_trainer.py
tanjjjjrr 7dfea08
Update HeCo.py
tanjjjjrr 8cf7465
Update acm4heco.py
tanjjjjrr 02babc2
Update HeCo_trainer.py
tanjjjjrr a320a68
Update test_acm.py
tanjjjjrr 3944643
Update test_acm.py
tanjjjjrr 23c5340
Update HeCo_trainer.py
tanjjjjrr 079cef3
Update acm4heco.py
tanjjjjrr 84f9a7f
Update acm4heco.py
tanjjjjrr 5993053
Update acm4heco.py
tanjjjjrr 7656d6b
Update HeCo_trainer.py
tanjjjjrr 2f5df89
Update HeCo_encoder.py
tanjjjjrr 8f74481
Update acm4heco.py
tanjjjjrr 0139e2f
Update HeCo_trainer.py
tanjjjjrr f2c47f7
Update HeCo_encoder.py
tanjjjjrr 213facc
Update HeCo_trainer.py
tanjjjjrr 9e23c13
Rename HeCo_trainer.py to heco_trainer.py
tanjjjjrr 5aedaea
Update and rename HeCo_encoder.py to heco_encoder.py
tanjjjjrr 343556e
Update and rename HeCo.py to heco.py
tanjjjjrr 4fbe63f
Update heco_trainer.py
tanjjjjrr d28e0f4
Delete examples/HeCo/README.md
tanjjjjrr 0d11d52
Delete examples/HeCo/heco_trainer.py
tanjjjjrr 50380ab
Add files via upload
tanjjjjrr e870ed0
Update __init__.py
tanjjjjrr 1da343d
Update __init__.py
tanjjjjrr 32b4d25
Update acm4heco.py
tanjjjjrr 93fd295
Update heco_trainer.py
tanjjjjrr 76d7d2f
Update acm4heco.py
tanjjjjrr cc0b907
Update heco_trainer.py
tanjjjjrr 234ee70
Update acm4heco.py
tanjjjjrr 211efdf
Update acm4heco.py
tanjjjjrr 7936b25
Update heco_trainer.py
tanjjjjrr 6de075d
Update heco_trainer.py
tanjjjjrr ce8e7a9
Update acm4heco.py
tanjjjjrr 54111d7
Update acm4heco.py
tanjjjjrr d2a14c8
Merge branch 'main' into tanjjjjrr-new01-HeCo
gyzhou2000 8b1506d
update
gyzhou2000 42e0524
Merge branch 'tanjjjjrr-new01-HeCo' of https://github.com/tanjjjjrr/G…
gyzhou2000
""" | ||
@File : HeCo_trainer.py | ||
@Time : | ||
@Author : tan jiarui | ||
""" | ||
import os | ||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0' | ||
# os.environ['TL_BACKEND'] = 'torch' | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR | ||
import numpy | ||
import random | ||
import argparse | ||
import warnings | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import tensorlayerx as tlx | ||
import tensorlayerx.nn as nn | ||
from sklearn.metrics import f1_score | ||
from sklearn.metrics import roc_auc_score | ||
from sklearn.preprocessing import OneHotEncoder | ||
from gammagl.models.HeCo import HeCo | ||
from tensorlayerx.model import WithLoss | ||
from gammagl.datasets.acm4heco import ACM4HeCo | ||
|
||
import scipy.sparse as sp | ||
#Mention: all 'str' in this code should be replaced with your own file directories | ||
class Contrast(nn.Module): | ||
def __init__(self, hidden_dim, tau, lam): | ||
super(Contrast, self).__init__() | ||
self.proj = nn.Sequential( | ||
nn.Linear(in_features=hidden_dim, out_features=hidden_dim, W_init='he_normal'), | ||
nn.ELU(), | ||
nn.Linear(in_features=hidden_dim, out_features=hidden_dim, W_init='he_normal') | ||
) | ||
self.tau = tau | ||
self.lam = lam | ||
def sim(self, z1, z2): | ||
z1_norm = tlx.l2_normalize(z1, axis=-1) | ||
z2_norm = tlx.l2_normalize(z2, axis=-1) | ||
z1_norm = tlx.reshape(tlx.reduce_mean(z1/z1_norm, axis=-1), (-1, 1)) | ||
z2_norm = tlx.reshape(tlx.reduce_mean(z2/z2_norm, axis=-1), (-1, 1)) | ||
dot_numerator = tlx.matmul(z1, tlx.transpose(z2)) | ||
dot_denominator = tlx.matmul(z1_norm, tlx.transpose(z2_norm)) | ||
sim_matrix = tlx.exp(dot_numerator / dot_denominator / self.tau) | ||
return sim_matrix | ||
|
||
def forward(self , z, pos): | ||
z_mp = z.get("z_mp") | ||
z_sc = z.get("z_sc") | ||
z_proj_mp = self.proj(z_mp) | ||
z_proj_sc = self.proj(z_sc) | ||
matrix_mp2sc = self.sim(z_proj_mp, z_proj_sc) | ||
matrix_sc2mp = tlx.transpose(matrix_mp2sc) | ||
|
||
matrix_mp2sc = matrix_mp2sc / (tlx.reshape(tlx.reduce_sum(matrix_mp2sc, axis=1), (-1, 1)) + 1e-8) | ||
lori_mp = -tlx.reduce_mean(tlx.log(tlx.reduce_sum(tlx.multiply(matrix_mp2sc, pos), axis=-1))) | ||
|
||
matrix_sc2mp = matrix_sc2mp / (tlx.reshape(tlx.reduce_sum(matrix_sc2mp, axis=1), (-1, 1)) + 1e-8) | ||
lori_sc = -tlx.reduce_mean(tlx.log(tlx.reduce_sum(tlx.multiply(matrix_sc2mp, pos), axis=-1))) | ||
return self.lam * lori_mp + (1 - self.lam) * lori_sc | ||
|
||
class Contrast_Loss(WithLoss): | ||
def __init__(self, net, loss_fn): | ||
super(Contrast_Loss, self).__init__(backbone=net, loss_fn=loss_fn) | ||
|
||
def forward(self, datas, pos): | ||
z = self.backbone_network(datas) | ||
loss = self._loss_fn(z, pos) | ||
return loss | ||
|
||
class LogReg(nn.Module): | ||
def __init__(self, ft_in, nb_classes): | ||
super(LogReg, self).__init__() | ||
self.fc = nn.Linear(in_features=ft_in, out_features=nb_classes, W_init='xavier_uniform', b_init='constant') | ||
|
||
def forward(self, seq): | ||
ret = self.fc(seq) | ||
return ret | ||
|
||
|
||
def evaluate(embeds, ratio, idx_train, idx_val, idx_test, label, nb_classes, dataset, lr, wd, isTest=True): | ||
hid_units = tlx.get_tensor_shape(embeds)[1] | ||
train_embs_list = [] | ||
val_embs_list = [] | ||
test_embs_list = [] | ||
label_train_list = [] | ||
label_val_list = [] | ||
label_test_list = [] | ||
for i in range(0, len(idx_train)): | ||
train_embs_list.append(embeds[idx_train[i]]) | ||
label_train_list.append(label[idx_train[i]]) | ||
for i in range(0, len(idx_val)): | ||
val_embs_list.append(embeds[idx_val[i]]) | ||
label_val_list.append(label[idx_val[i]]) | ||
for i in range(0, len(idx_test)): | ||
test_embs_list.append(embeds[idx_test[i]]) | ||
label_test_list.append(label[idx_test[i]]) | ||
train_embs = tlx.stack(train_embs_list, axis = 0) | ||
val_embs = tlx.stack(val_embs_list, axis = 0) | ||
test_embs = tlx.stack(test_embs_list, axis = 0) | ||
label_train = tlx.stack(label_train_list, axis = 0) | ||
label_val = tlx.stack(label_val_list, axis = 0) | ||
label_test = tlx.stack(label_test_list, axis = 0) | ||
train_lbls_idx = tlx.argmax(label_train, axis=-1) | ||
val_lbls_idx = tlx.argmax(label_val, axis=-1) | ||
test_lbls_idx = tlx.argmax(label_test, axis=-1) | ||
accs = [] | ||
micro_f1s = [] | ||
macro_f1s = [] | ||
macro_f1s_val = [] | ||
auc_score_list = [] | ||
#this is the training process for pytorch and paddle(all are recommended) | ||
for _ in range(50): | ||
log = LogReg(hid_units, nb_classes) | ||
#print(lr) | ||
optimizer = tlx.optimizers.Adam(lr=lr, weight_decay=float(wd)) #Adam method | ||
loss = tlx.losses.softmax_cross_entropy_with_logits | ||
log_with_loss = tlx.model.WithLoss(log, loss) | ||
train_one_step = tlx.model.TrainOneStep(log_with_loss, optimizer, log.trainable_weights) | ||
val_accs = [] | ||
test_accs = [] | ||
val_micro_f1s = [] | ||
test_micro_f1s = [] | ||
val_macro_f1s = [] | ||
test_macro_f1s = [] | ||
logits_list = [] | ||
for iter_ in range(200):#set this parameter: 'acm'=200 | ||
log.set_train() | ||
train_one_step(train_embs, train_lbls_idx) | ||
logits = log(val_embs) | ||
preds = tlx.argmax(logits, axis = 1) | ||
acc_val = 0 | ||
for i in range(0, len(val_lbls_idx)): | ||
if(preds[i] == val_lbls_idx[i]): | ||
acc_val = acc_val + 1 | ||
val_acc = acc_val/len(val_lbls_idx) | ||
val_f1_macro = f1_score(val_lbls_idx.cpu(), preds.cpu(), average='macro') | ||
val_f1_micro = f1_score(val_lbls_idx.cpu(), preds.cpu(), average='micro') | ||
val_accs.append(val_acc) | ||
val_macro_f1s.append(val_f1_macro) | ||
val_micro_f1s.append(val_f1_micro) | ||
logits = log(test_embs) | ||
preds = tlx.argmax(logits, axis=1) | ||
acc_test = 0 | ||
for i in range(0, len(test_lbls_idx)): | ||
if(preds[i] == test_lbls_idx[i]): | ||
acc_test = acc_test + 1 | ||
test_acc = acc_test/len(test_lbls_idx) | ||
test_f1_macro = f1_score(test_lbls_idx.cpu(), preds.cpu(), average='macro') | ||
test_f1_micro = f1_score(test_lbls_idx.cpu(), preds.cpu(), average='micro') | ||
test_accs.append(test_acc) | ||
test_macro_f1s.append(test_f1_macro) | ||
test_micro_f1s.append(test_f1_micro) | ||
logits_list.append(logits) | ||
max_iter = val_accs.index(max(val_accs)) | ||
accs.append(test_accs[max_iter]) | ||
max_iter = val_macro_f1s.index(max(val_macro_f1s)) | ||
macro_f1s.append(test_macro_f1s[max_iter]) | ||
macro_f1s_val.append(val_macro_f1s[max_iter]) | ||
max_iter = val_micro_f1s.index(max(val_micro_f1s)) | ||
micro_f1s.append(test_micro_f1s[max_iter]) | ||
best_logits = logits_list[max_iter] | ||
best_proba = tlx.softmax(best_logits, axis=1) | ||
auc_score_list.append(roc_auc_score(y_true=tlx.convert_to_numpy(test_lbls_idx), | ||
y_score=tlx.convert_to_numpy(best_proba), | ||
multi_class='ovr' | ||
)) | ||
if isTest: | ||
print("\t[Classification] Macro-F1_mean: {:.4f} var: {:.4f} Micro-F1_mean: {:.4f} var: {:.4f} auc: {:.4f} " | ||
.format(np.mean(macro_f1s), | ||
np.std(macro_f1s), | ||
np.mean(micro_f1s), | ||
np.std(micro_f1s), | ||
np.mean(auc_score_list), | ||
np.std(auc_score_list) | ||
) | ||
) | ||
else: | ||
return np.mean(macro_f1s_val), np.mean(macro_f1s) | ||
|
||
def main(args): | ||
dataset = ACM4HeCo(args.LocalFilePath) | ||
graph = dataset[0] | ||
nei_index = graph['paper'].nei | ||
feats = graph['feat_p/a/s'] | ||
mps = graph['metapath'] | ||
pos = graph['pos_set_for_contrast'] | ||
label = graph['paper'].label | ||
idx_train = graph['train'] | ||
idx_val = graph['val'] | ||
idx_test = graph['test'] | ||
isTest=True | ||
datas = { | ||
"feats": feats, | ||
"mps": mps, | ||
"nei_index": nei_index, | ||
} | ||
nb_classes = tlx.get_tensor_shape(label)[1] | ||
feats_dim_list = [tlx.get_tensor_shape(i)[1] for i in feats] | ||
P = int(len(mps)) | ||
print("seed ",args.seed) | ||
print("Dataset: ", args.dataset) | ||
print("The number of meta-paths: ", P) | ||
|
||
model = HeCo(args.hidden_dim, feats_dim_list, args.feat_drop, args.attn_drop, | ||
P, args.sample_rate, args.nei_num) | ||
optimizer = tlx.optimizers.Adam(lr=0.008, weight_decay=args.l2_coef) | ||
contrast_loss = Contrast(args.hidden_dim, args.tau, args.lam) | ||
cnt_wait = 0 | ||
best = 1e9 | ||
best_t = 0 | ||
cnt = 0 | ||
loss_func = Contrast_Loss(model, contrast_loss) | ||
weights_to_train = model.trainable_weights+contrast_loss.trainable_weights | ||
train_one_step = tlx.model.TrainOneStep(loss_func, optimizer, weights_to_train) | ||
for epoch in range(args.nb_epochs): #args.nb_epochs | ||
loss = train_one_step(datas, pos) | ||
print("loss ", loss) | ||
best = loss | ||
best_t = epoch | ||
model.save_weights(model.name+".npz", format='npz_dict') | ||
print('Loading {}th epoch'.format(best_t)) | ||
model.load_weights(model.name+".npz", format='npz_dict') | ||
model.set_eval() | ||
os.remove(model.name+".npz") | ||
embeds = model.get_embeds(feats, mps) | ||
# To evaluate the HeCo model with different numbers of training labels, that is 20,40 and 60, which is indicated in the essay of HeCo | ||
for i in range(len(idx_train)): | ||
evaluate(embeds, args.ratio[i], idx_train[i], idx_val[i], idx_test[i], label, nb_classes, args.dataset, args.eva_lr, args.eva_wd) | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--turn', type=int, default=0) | ||
parser.add_argument('--dataset', type=str, default="acm") | ||
parser.add_argument('--ratio', type=int, default=[20, 40, 60]) | ||
parser.add_argument('--gpu', type=int, default=0) | ||
parser.add_argument('--seed', type=int, default=0) | ||
parser.add_argument('--hidden_dim', type=int, default=64) | ||
parser.add_argument('--nb_epochs', type=int, default=10000) | ||
|
||
# The parameters of evaluation | ||
parser.add_argument('--eva_lr', type=float, default=0.05) | ||
parser.add_argument('--eva_wd', type=float, default=0) | ||
|
||
# The parameters of learning process | ||
parser.add_argument('--patience', type=int, default=5) | ||
parser.add_argument('--lr', type=float, default=0.0001) # 0.0008 | ||
parser.add_argument('--l2_coef', type=float, default=0.0) | ||
|
||
# model-specific parameters | ||
parser.add_argument('--tau', type=float, default=0.8) | ||
parser.add_argument('--feat_drop', type=float, default=0.3) | ||
parser.add_argument('--attn_drop', type=float, default=0.5) | ||
parser.add_argument('--sample_rate', nargs='+', type=int, default=[7, 1]) | ||
parser.add_argument('--lam', type=float, default=0.5) | ||
|
||
args, _ = parser.parse_known_args() | ||
args.type_num = [4019, 7167, 60] # the number of every node type | ||
args.nei_num = 2 # the number of neighbors' types | ||
args.LocalFilePath = "path" #With your own local directory | ||
own_str = args.dataset | ||
|
||
main(args) | ||
|
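The cross-view contrastive loss implemented by `Contrast` above can be sketched in plain NumPy. This is a minimal illustration with toy embeddings and an identity positive-pair mask; in the real model `z_mp` and `z_sc` come from the meta-path and network-schema encoders, the positive set is built from the dataset, and `sim` here uses an equivalent cosine-similarity formulation.

```python
import numpy as np

def contrastive_loss(z_mp, z_sc, pos, tau=0.8, lam=0.5):
    """Cross-view InfoNCE-style loss, mirroring Contrast.forward.
    z_mp, z_sc: (N, d) embeddings from the two views.
    pos: (N, N) 0/1 mask marking positive pairs."""
    def sim(a, b):
        # temperature-scaled cosine similarity between every pair of rows
        a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
        b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
        return np.exp(a_n @ b_n.T / tau)

    m = sim(z_mp, z_sc)
    # row-normalize each view's similarity matrix, then keep only positives
    mp2sc = m / (m.sum(axis=1, keepdims=True) + 1e-8)
    sc2mp = m.T / (m.T.sum(axis=1, keepdims=True) + 1e-8)
    lori_mp = -np.log((mp2sc * pos).sum(axis=1)).mean()
    lori_sc = -np.log((sc2mp * pos).sum(axis=1)).mean()
    return lam * lori_mp + (1 - lam) * lori_sc

rng = np.random.default_rng(0)
z1 = rng.normal(size=(4, 8))
z2 = rng.normal(size=(4, 8))
loss = contrastive_loss(z1, z2, np.eye(4))
```

With `lam=0.5` and a symmetric positive mask, swapping the two views leaves the loss unchanged, which is why the trainer only computes `matrix_mp2sc` and transposes it.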
# Self-supervised Heterogeneous Graph Neural Network with Co-contrastive Learning

- Paper link: [https://arxiv.org/abs/2105.09111](https://arxiv.org/abs/2105.09111)
- Author's code repo: [https://github.com/liun-online/HeCo](https://github.com/liun-online/HeCo)

# Dataset Statistics

| Dataset | # Nodes_paper | # Nodes_author | # Nodes_subject |
|---------|---------------|----------------|-----------------|
| ACM     | 4019          | 7167           | 60              |

Refer to [ACM](https://github.com/AndyJZhao/NSHE/tree/master/data/acm).

Results for ACM
-------
- Ma-F1

| number of train_labels | Paper     | Our(tf)  | Our(pd)  | Our(torch) |
|------------------------|-----------|----------|----------|------------|
| 20                     | 88.56±0.8 | 79.1±0.4 | 81.7±0.4 | 81.6±0.3   |
| 40                     | 87.61±0.5 | 83.4±0.3 | 85.4±0.1 | 85.6±0.2   |
| 60                     | 89.04±0.5 | 81.4±0.4 | 83.4±0.4 | 83.3±0.3   |

- Mi-F1

| number of train_labels | Paper     | Our(tf)  | Our(pd)   | Our(torch) |
|------------------------|-----------|----------|-----------|------------|
| 20                     | 88.13±0.8 | 79.1±0.4 | 80.44±0.8 | 80.4±0.7   |
| 40                     | 87.45±0.5 | 83.4±0.3 | 85.43±0.1 | 85.4±0.2   |
| 60                     | 88.71±0.5 | 79.4±0.4 | 82.2±0.5  | 82.4±0.6   |

- AUC

| number of train_labels | Paper     | Our(tf)  | Our(pd)  | Our(torch) |
|------------------------|-----------|----------|----------|------------|
| 20                     | 96.49±0.3 | 89.8±0.4 | 92.8±0.4 | 92.8±0.3   |
| 40                     | 96.4±0.4  | 92.4±0.3 | 95.2±0.2 | 95.4±0.3   |
| 60                     | 96.55±0.3 | 89.4±0.4 | 93.6±0.4 | 93.7±0.3   |

Because the TensorFlow backend runs more slowly than PaddlePaddle and PyTorch, the pd and torch backends are recommended.
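The Ma-F1, Mi-F1 and AUC figures in the tables above are computed with scikit-learn, as in the `evaluate` routine of the trainer. A minimal sketch on made-up toy predictions (the label and probability arrays here are purely illustrative):

```python
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

y_true = np.array([0, 0, 1, 1, 2, 2])   # toy ground-truth classes
y_pred = np.array([0, 1, 1, 1, 2, 0])   # toy classifier predictions

# Macro-F1 averages the per-class F1 scores equally; Micro-F1 aggregates
# true/false positives over all samples (equal to accuracy for single-label tasks).
ma_f1 = f1_score(y_true, y_pred, average='macro')
mi_f1 = f1_score(y_true, y_pred, average='micro')

# AUC is computed one-vs-rest from class probabilities, as in the trainer.
proba = np.full((6, 3), 0.1)
proba[np.arange(6), y_pred] = 0.8
proba /= proba.sum(axis=1, keepdims=True)   # rows must sum to 1
auc = roc_auc_score(y_true, proba, multi_class='ovr')
```

Macro-F1 penalizes poor performance on rare classes more than Micro-F1 does, which is why both are reported.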
@gyzhou2000 Is this performance acceptable? It is about 5% lower than that reported in the original paper.