Skip to content

Commit ef4924e

Browse files
committed
update amazon
1 parent 67c6ced commit ef4924e

File tree

4 files changed

+30
-31
lines changed

4 files changed

+30
-31
lines changed

examples/glnn/train_student.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorlayerx as tlx
1313
from gammagl.datasets import Planetoid, Amazon
1414
from gammagl.models import MLP
15-
from gammagl.utils import mask_to_index, get_train_val_test_split
15+
from gammagl.utils import mask_to_index
1616
from tensorlayerx.model import TrainOneStep, WithLoss
1717

1818

@@ -73,25 +73,18 @@ def train_student(args):
7373
if args.dataset in ['cora', 'pubmed', 'citeseer']:
7474
dataset = Planetoid(args.dataset_path, args.dataset)
7575
elif args.dataset in ['computers', 'photo']:
76-
dataset = Amazon(args.dataset_path, args.dataset)
76+
dataset = Amazon(args.dataset_path, args.dataset, train_per_class=20, val_per_class=30)
7777
graph = dataset[0]
7878

7979
# load teacher_logits from .npy file
8080
teacher_logits = tlx.files.load_npy_to_any(path = r'./', name = f'{args.dataset}_{args.teacher}_logits.npy')
8181
teacher_logits = tlx.ops.convert_to_tensor(teacher_logits)
8282

8383
# for mindspore, node indices (rather than masks) should be passed in
84-
if args.dataset in ['cora', 'pubmed', 'citeseer']:
85-
train_idx = mask_to_index(graph.train_mask)
86-
test_idx = mask_to_index(graph.test_mask)
87-
val_idx = mask_to_index(graph.val_mask)
88-
t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)
89-
elif args.dataset in ['computers', 'photo']:
90-
train_mask, val_mask, test_mask = get_train_val_test_split(dataset, train_per_class=20, val_per_class=30)
91-
train_idx = mask_to_index(train_mask)
92-
val_idx = mask_to_index(val_mask)
93-
test_idx = mask_to_index(test_mask)
94-
t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)
84+
train_idx = mask_to_index(graph.train_mask)
85+
test_idx = mask_to_index(graph.test_mask)
86+
val_idx = mask_to_index(graph.val_mask)
87+
t_idx = tlx.concat([train_idx, test_idx, val_idx], axis=0)
9588

9689
net = MLP(in_channels=dataset.num_node_features,
9790
hidden_channels=conf["hidden_dim"],

examples/glnn/train_teacher.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorlayerx as tlx
1313
from gammagl.datasets import Planetoid, Amazon
1414
from gammagl.models import GCNModel, GraphSAGE_Full_Model, GATModel, APPNPModel, MLP
15-
from gammagl.utils import mask_to_index, calc_gcn_norm, get_train_val_test_split
15+
from gammagl.utils import mask_to_index, calc_gcn_norm
1616
from tensorlayerx.model import TrainOneStep, WithLoss
1717

1818

@@ -66,21 +66,15 @@ def train_teacher(args):
6666
if args.dataset in ['cora', 'pubmed', 'citeseer']:
6767
dataset = Planetoid(args.dataset_path, args.dataset)
6868
elif args.dataset in ['computers', 'photo']:
69-
dataset = Amazon(args.dataset_path, args.dataset)
69+
dataset = Amazon(args.dataset_path, args.dataset, train_per_class=20, val_per_class=30)
7070
graph = dataset[0]
7171
edge_index = graph.edge_index
7272
edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes))
7373

7474
# for mindspore, node indices (rather than masks) should be passed in
75-
if args.dataset in ['cora', 'pubmed', 'citeseer']:
76-
train_idx = mask_to_index(graph.train_mask)
77-
test_idx = mask_to_index(graph.test_mask)
78-
val_idx = mask_to_index(graph.val_mask)
79-
elif args.dataset in ['computers', 'photo']:
80-
train_mask, val_mask, test_mask = get_train_val_test_split(dataset, train_per_class=20, val_per_class=30)
81-
train_idx = mask_to_index(train_mask)
82-
val_idx = mask_to_index(val_mask)
83-
test_idx = mask_to_index(test_mask)
75+
train_idx = mask_to_index(graph.train_mask)
76+
test_idx = mask_to_index(graph.test_mask)
77+
val_idx = mask_to_index(graph.val_mask)
8478

8579
if args.teacher == "GCN":
8680
net = GCNModel(feature_dim=dataset.num_node_features,

gammagl/datasets/amazon.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tensorlayerx as tlx
44
from gammagl.data import InMemoryDataset, download_url
55
from gammagl.io.npz import read_npz
6+
from gammagl.utils import get_train_val_test_split
67

78

89
class Amazon(InMemoryDataset):
@@ -33,6 +34,10 @@ class Amazon(InMemoryDataset):
3334
being saved to disk. (default: :obj:`None`)
3435
force_reload (bool, optional): Whether to re-process the dataset.
3536
(default: :obj:`False`)
37+
train_per_class (int, optional): Number of training samples per class.
38+
(default: :obj:`20`)
39+
val_per_class (int, optional): Number of validation samples per class.
40+
(default: :obj:`20`)
3641
3742
Stats:
3843
.. list-table::
@@ -61,12 +66,18 @@ class Amazon(InMemoryDataset):
6166
def __init__(self, root: str = None, name: str = 'computers',
6267
transform: Optional[Callable] = None,
6368
pre_transform: Optional[Callable] = None,
64-
force_reload: bool = False):
69+
force_reload: bool = False,
70+
train_per_class: int = 20,
71+
val_per_class: int = 20):
6572
self.name = name.lower()
6673
assert self.name in ['computers', 'photo']
6774
super().__init__(root, transform, pre_transform, force_reload = force_reload)
6875
self.data, self.slices = self.load_data(self.processed_paths[0])
6976

77+
data = self.get(0)
78+
data.train_mask, data.val_mask, data.test_mask = get_train_val_test_split(self.data, train_per_class, val_per_class, self.num_classes)
79+
self.data, self.slices = self.collate([data])
80+
7081
@property
7182
def raw_dir(self) -> str:
7283
return osp.join(self.root, self.name.capitalize(), 'raw')

gammagl/utils/get_split.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,28 @@
22
import numpy as np
33

44

5-
def get_train_val_test_split(dataset, train_per_class, val_per_class):
5+
def get_train_val_test_split(graph, train_per_class, val_per_class, num_classes):
66
"""Split the nodes of a graph into train, validation, and test sets.
77
88
Parameters
99
----------
10-
dataset :
11-
The dataset to split.
10+
graph :
11+
The graph to split.
1212
train_per_class : int
1313
The number of training examples per class.
1414
val_per_class : int
1515
The number of validation examples per class.
16+
num_classes : int
17+
The number of classes in the dataset.
1618
1719
Returns
1820
-------
1921
:class:`tuple` of :class:`tensor`
2022
2123
"""
22-
graph = dataset[0]
2324
random_state = np.random.RandomState(0)
24-
labels = tlx.nn.OneHot(depth=dataset.num_classes)(graph.y).numpy()
25-
num_samples, num_classes = graph.num_nodes, dataset.num_classes
25+
labels = tlx.nn.OneHot(depth=num_classes)(graph.y).numpy()
26+
num_samples, num_classes = graph.num_nodes, num_classes
2627
remaining_indices = set(range(num_samples))
2728
forbidden_indices = set()
2829

0 commit comments

Comments
 (0)