Add glnn #205 (Merged)

Changes shown below are from 2 of the 8 commits.

Commits (8):
- 71c4a99 Add GLNN to examples
- 67c6ced add glnn
- ef4924e update amazon
- dc0511c update amazon
- ec6d99d update get_split.py
- 080b4d5 add test_get_split.py
- 0d3d7d6 Merge remote-tracking branch 'upstream/main' into add-glnn
- 9c9f644 update readme
File: README (new file, +41 lines)
# Graph-less Neural Networks (GLNN)

- Paper link: [https://arxiv.org/pdf/2110.08727](https://arxiv.org/pdf/2110.08727)
- Author's code repo: [https://github.com/snap-research/graphless-neural-networks](https://github.com/snap-research/graphless-neural-networks)

# Dataset Statistics

| Dataset   | # Nodes | # Edges | # Classes |
| --------- | ------- | ------- | --------- |
| Cora      | 2,708   | 10,556  | 7         |
| Citeseer  | 3,327   | 9,228   | 6         |
| Pubmed    | 19,717  | 88,651  | 3         |
| Computers | 13,752  | 491,722 | 10        |
| Photo     | 7,650   | 238,162 | 8         |

Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid) and [Amazon](https://gammagl.readthedocs.io/en/latest/generated/gammagl.datasets.Amazon.html#gammagl.datasets.Amazon) for details.

# Results

- Available datasets: "cora", "citeseer", "pubmed", "computers", "photo"
- Available teachers: "SAGE", "GCN", "GAT", "APPNP", "MLP"

Train the teacher first, then distill it into the student MLP:

```bash
TL_BACKEND="tensorflow" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="tensorflow" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="torch" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="paddle" python train_student.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_teacher.py --dataset cora --teacher SAGE
TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
```

| Dataset   | Paper      | Our(tf)    | Our(th)    | Our(pd)    | Our(ms)    |
| --------- | ---------- | ---------- | ---------- | ---------- | ---------- |
| Cora      | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
| Citeseer  | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
| Pubmed    | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
| Computers | 83.03±1.87 | 81.51±0.60 | 81.73±0.48 | 81.46±0.72 | 81.24±1.27 |
| Photo     | 92.11±1.08 | 92.05±0.56 | 91.92±0.53 | 92.00±0.55 | 91.77±0.91 |

- Reported accuracy is the average of 5 runs.
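For reference, the student objective implemented by `cal_mlp_loss` in `train_student.py` below combines cross-entropy on the hard labels with a KL distillation term against the teacher's soft predictions, weighted by the `--lamb` flag (default 0, i.e. pure distillation):

$$
\mathcal{L} = \lambda\,\mathrm{CE}\big(z^{\mathrm{MLP}}, y\big) + (1-\lambda)\,\mathrm{KL}\big(\mathrm{softmax}(z^{\mathrm{teacher}})\,\big\|\,\mathrm{softmax}(z^{\mathrm{MLP}})\big)
$$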
File: train.conf.yaml (new file, +140 lines)
```yaml
global:
  num_layers: 2
  hidden_dim: 128
  learning_rate: 0.01

cora:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.005
    dropout_ratio: 0.6

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

citeseer:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.01
    weight_decay: 0.001
    dropout_ratio: 0.1

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

pubmed:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0
    dropout_ratio: 0.4

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

computers:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.001
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01

photo:
  SAGE:
    fan_out: 5,5
    learning_rate: 0.01
    dropout_ratio: 0
    weight_decay: 0.0005

  GCN:
    hidden_dim: 64
    dropout_ratio: 0.8
    weight_decay: 0.001

  MLP:
    learning_rate: 0.005
    weight_decay: 0.002
    dropout_ratio: 0.3

  GAT:
    dropout_ratio: 0.6
    weight_decay: 0.01
    num_heads: 8
    attn_dropout_ratio: 0.3

  APPNP:
    dropout_ratio: 0.5
    weight_decay: 0.01
```
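The `global` section supplies defaults that each dataset/model section overrides; `get_training_config` in `train_student.py` below performs the merge. A minimal illustration of that precedence, assuming this file is saved as `./train.conf.yaml`:

```python
import yaml

# Model-specific keys win over the "global" defaults.
with open("./train.conf.yaml") as f:
    full_config = yaml.safe_load(f)

merged = dict(full_config["global"], **full_config["cora"]["GCN"])
print(merged)
# -> {'num_layers': 2, 'hidden_dim': 64, 'learning_rate': 0.01,
#     'dropout_ratio': 0.8, 'weight_decay': 0.001}
```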
File: train_student.py (new file, +173 lines)
```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0: output all; 1: filter out INFO; 2: filter out INFO and WARNING; 3: filter out INFO, WARNING, and ERROR

import yaml
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid, Amazon
from gammagl.models import MLP
from gammagl.utils import mask_to_index, get_train_val_test_split
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, teacher_logits):
        student_logits = self.backbone_network(data['x'])
        train_y = tlx.gather(data['y'], data['t_idx'])
        train_teacher_logits = tlx.gather(teacher_logits, data['t_idx'])
        train_student_logits = tlx.gather(student_logits, data['t_idx'])
        # args is the module-level namespace parsed in __main__.
        loss = self._loss_fn(train_y, train_student_logits, train_teacher_logits, args.lamb)
        return loss


def get_training_config(config_path, model_name, dataset):
    with open(config_path, "r") as conf:
        full_config = yaml.load(conf, Loader=yaml.FullLoader)
    dataset_specific_config = full_config["global"]
    model_specific_config = full_config[dataset][model_name]

    # Model-specific settings override the global defaults.
    if model_specific_config is not None:
        specific_config = dict(dataset_specific_config, **model_specific_config)
    else:
        specific_config = dataset_specific_config

    specific_config["model_name"] = model_name
    return specific_config


def calculate_acc(logits, y, metrics):
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def kl_divergence(teacher_logits, student_logits):
    # Convert logits to probabilities.
    teacher_probs = tlx.softmax(teacher_logits)
    student_probs = tlx.softmax(student_logits)
    # Compute KL(teacher || student); the epsilon guards against log(0).
    kl_div = tlx.reduce_sum(
        teacher_probs * (tlx.log(teacher_probs + 1e-10) - tlx.log(student_probs + 1e-10)),
        axis=-1)
    return tlx.reduce_mean(kl_div)


def cal_mlp_loss(labels, student_logits, teacher_logits, lamb):
    # lamb balances the hard-label loss against the distillation loss.
    loss_l = tlx.losses.softmax_cross_entropy_with_logits(student_logits, labels)
    loss_t = kl_divergence(teacher_logits, student_logits)
    return lamb * loss_l + (1 - lamb) * loss_t


def train_student(args):
    # Load the dataset. Normalize the name once so that mixed-case input
    # (e.g. "Cora") cannot pass the check below yet match neither loader
    # branch, which would leave `dataset` unbound.
    args.dataset = args.dataset.lower()
    if args.dataset not in ['cora', 'pubmed', 'citeseer', 'computers', 'photo']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    if args.dataset in ['cora', 'pubmed', 'citeseer']:
        dataset = Planetoid(args.dataset_path, args.dataset)
    else:
        dataset = Amazon(args.dataset_path, args.dataset)
    graph = dataset[0]

    # Load the teacher logits saved by train_teacher.py.
    teacher_logits = tlx.files.load_npy_to_any(path=r'./', name=f'{args.dataset}_{args.teacher}_logits.npy')
    teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)

    # For MindSpore, node indices (rather than boolean masks) must be passed in.
    if args.dataset in ['cora', 'pubmed', 'citeseer']:
        train_idx = mask_to_index(graph.train_mask)
        test_idx = mask_to_index(graph.test_mask)
        val_idx = mask_to_index(graph.val_mask)
    else:
        train_mask, val_mask, test_mask = get_train_val_test_split(dataset, train_per_class=20, val_per_class=30)
        train_idx = mask_to_index(train_mask)
        val_idx = mask_to_index(val_mask)
        test_idx = mask_to_index(test_mask)
    # The student is distilled on all nodes (train + test + val).
    t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)

    net = MLP(in_channels=dataset.num_node_features,
              hidden_channels=conf["hidden_dim"],
              out_channels=dataset.num_classes,
              num_layers=conf["num_layers"],
              act=tlx.nn.ReLU(),
              norm=None,
              dropout=float(conf["dropout_ratio"]))

    optimizer = tlx.optimizers.Adam(lr=conf["learning_rate"], weight_decay=conf["weight_decay"])
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = SemiSpvzLoss(net, cal_mlp_loss)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "t_idx": t_idx
    }

    best_val_acc = 0
    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, teacher_logits)
        net.set_eval()
        logits = net(data['x'])
        val_logits = tlx.gather(logits, data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}]".format(epoch + 1)
              + "  train loss: {:.4f}".format(train_loss.item())
              + "  val acc: {:.4f}".format(val_acc))

        # Save the best model on the validation set.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')

    # Evaluate the best checkpoint on the test set.
    net.load_weights(args.best_model_path + args.dataset + "_" + args.teacher + "_MLP.npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x'])
    test_logits = tlx.gather(logits, data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # Parameter settings.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config_path", type=str, default="./train.conf.yaml", help="path to model configuration")
    parser.add_argument("--teacher", type=str, default="SAGE", help="teacher model")
    parser.add_argument("--lamb", type=float, default=0, help="balances the loss between hard labels and teacher outputs")
    parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs")
    parser.add_argument('--dataset', type=str, default="cora", help="dataset")
    parser.add_argument("--dataset_path", type=str, default=r'./data', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()

    conf = {}
    if args.model_config_path is not None:
        conf = get_training_config(args.model_config_path, args.teacher, args.dataset)
    conf = dict(args.__dict__, **conf)

    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    train_student(args)
```
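`train_teacher.py` itself is outside this two-commit view, but the student script above expects it to leave the teacher's logits in `./{dataset}_{teacher}_logits.npy`. A hypothetical sketch of that export step, assuming `tlx.files.save_any_to_npy` is the symmetric counterpart of the `load_npy_to_any` call used above (random numbers stand in for a trained teacher's logits):

```python
import numpy as np
import tensorlayerx as tlx

# Stand-in for the trained teacher's output: [num_nodes, num_classes].
num_nodes, num_classes = 2708, 7  # Cora sizes, for illustration
teacher_logits = np.random.randn(num_nodes, num_classes).astype('float32')

# Save next to the scripts so train_student.py can find it.
tlx.files.save_any_to_npy(save_dict=teacher_logits, name='cora_SAGE_logits.npy')
```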
Review comment:

> Normally, you may get the mask through `dataset.train_mask`. So, it would be better to modify `gammagl.datasets.Amazon`.
Reply:

> The requested modifications have been made to `train_teacher.py`, `train_student.py`, `gammagl.datasets.Amazon`, and `gammagl.utils.get_train_val_test_split`. Thank you for the reminder!
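The split utility referenced above, `gammagl.utils.get_train_val_test_split(dataset, train_per_class=20, val_per_class=30)`, samples a fixed number of nodes per class. A minimal, hypothetical NumPy sketch of that idea (not the library implementation):

```python
import numpy as np

def per_class_split(labels, train_per_class=20, val_per_class=30, seed=0):
    """Sample train_per_class / val_per_class nodes from each class;
    everything left over becomes the test set."""
    rng = np.random.default_rng(seed)
    train_mask = np.zeros(labels.shape[0], dtype=bool)
    val_mask = np.zeros(labels.shape[0], dtype=bool)
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        train_mask[idx[:train_per_class]] = True
        val_mask[idx[train_per_class:train_per_class + val_per_class]] = True
    test_mask = ~(train_mask | val_mask)
    return train_mask, val_mask, test_mask

# Example: 10-class labels, sized like the Computers dataset.
labels = np.random.randint(0, 10, size=13752)
train_mask, val_mask, test_mask = per_class_split(labels)
print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 200 300 13252
```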