Skip to content

Commit ec6d99d

Browse files
committed
update get_split.py
1 parent dc0511c commit ec6d99d

File tree

5 files changed

+52
-55
lines changed

5 files changed

+52
-55
lines changed

examples/glnn/readme.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TL_BACKEND="mindspore" python train_student.py --dataset cora --teacher SAGE
3535
| Cora | 80.54±1.35 | 80.94±0.31 | 80.84±0.30 | 80.90±0.21 | 81.04±0.30 |
3636
| Citeseer | 71.77±2.01 | 70.74±0.87 | 71.34±0.55 | 71.18±1.20 | 70.58±1.14 |
3737
| Pubmed | 75.42±2.31 | 77.90±0.07 | 77.88±0.23 | 77.78±0.19 | 77.78±0.13 |
38-
| Computers | 83.03±1.87 | 81.51±0.60 | 81.73±0.48 | 81.46±0.72 | 81.24±1.27 |
39-
| Photo | 92.11±1.08 | 92.05±0.56 | 91.92±0.53 | 92.00±0.55 | 91.77±0.91 |
38+
| Computers | 83.03±1.87 | 83.45±0.61 | 82.78±0.47 | 83.03±0.14 | 83.40±0.45 |
39+
| Photo | 92.11±1.08 | 91.93±0.16 | 91.91±0.24 | 91.89±0.27 | 91.88±0.21 |
4040

4141
- The reported model performance is the average over 5 runs.

examples/glnn/train_student.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def train_student(args):
7272
raise ValueError('Unknown dataset: {}'.format(args.dataset))
7373
if args.dataset in ['cora', 'pubmed', 'citeseer']:
7474
dataset = Planetoid(args.dataset_path, args.dataset)
75-
elif args.dataset in ['computers', 'photo']:
76-
dataset = Amazon(args.dataset_path, args.dataset, train_per_class=20, val_per_class=30)
75+
elif args.dataset == 'computers':
76+
dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200/13752, val_ratio=(200/13752)*1.5)
77+
elif args.dataset == 'photo':
78+
dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160/7650, val_ratio=(160/7650)*1.5)
7779
graph = dataset[0]
7880

7981
# load teacher_logits from .npy file

examples/glnn/train_teacher.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@ def train_teacher(args):
6565
raise ValueError('Unknown dataset: {}'.format(args.dataset))
6666
if args.dataset in ['cora', 'pubmed', 'citeseer']:
6767
dataset = Planetoid(args.dataset_path, args.dataset)
68-
elif args.dataset in ['computers', 'photo']:
69-
dataset = Amazon(args.dataset_path, args.dataset, train_per_class=20, val_per_class=30)
68+
elif args.dataset == 'computers':
69+
dataset = Amazon(args.dataset_path, args.dataset, train_ratio=200/13752, val_ratio=(200/13752)*1.5)
70+
elif args.dataset == 'photo':
71+
dataset = Amazon(args.dataset_path, args.dataset, train_ratio=160/7650, val_ratio=(160/7650)*1.5)
7072
graph = dataset[0]
7173
edge_index = graph.edge_index
7274
edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes))

gammagl/datasets/amazon.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ class Amazon(InMemoryDataset):
3535
force_reload : bool, optional
3636
Whether to re-process the dataset.
3737
(default: :obj:`False`)
38-
train_per_class : int, optional
39-
Number of training samples per class.
40-
(default: :obj:`20`)
41-
val_per_class : int, optional
42-
Number of validation samples per class.
43-
(default: :obj:`20`)
38+
train_ratio : float, optional
39+
Fraction of all nodes assigned to the training set.
40+
(default: :obj:`0.1`)
41+
val_ratio : float, optional
42+
Fraction of all nodes assigned to the validation set.
43+
(default: :obj:`0.15`)
4444
4545
Stats:
4646
.. list-table::
@@ -70,15 +70,15 @@ def __init__(self, root: str = None, name: str = 'computers',
7070
transform: Optional[Callable] = None,
7171
pre_transform: Optional[Callable] = None,
7272
force_reload: bool = False,
73-
train_per_class: int = 20,
74-
val_per_class: int = 20):
73+
train_ratio: float = 0.1,
74+
val_ratio: float = 0.15):
7575
self.name = name.lower()
7676
assert self.name in ['computers', 'photo']
7777
super().__init__(root, transform, pre_transform, force_reload = force_reload)
7878
self.data, self.slices = self.load_data(self.processed_paths[0])
7979

8080
data = self.get(0)
81-
data.train_mask, data.val_mask, data.test_mask = get_train_val_test_split(self.data, train_per_class, val_per_class, self.num_classes)
81+
data.train_mask, data.val_mask, data.test_mask = get_train_val_test_split(self.data, train_ratio, val_ratio)
8282
self.data, self.slices = self.collate([data])
8383

8484
@property

gammagl/utils/get_split.py

+33-40
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,57 @@
11
import tensorlayerx as tlx
22
import numpy as np
3+
from sklearn.model_selection import train_test_split
34

45

5-
def get_train_val_test_split(graph, train_per_class, val_per_class, num_classes):
6-
"""Split the dataset into train, validation, and test sets.
6+
def get_train_val_test_split(graph, train_ratio, val_ratio):
7+
"""
8+
Split the dataset into train, validation, and test sets.
79
810
Parameters
911
----------
1012
graph :
1113
The graph to split.
12-
train_per_class : int
13-
The number of training examples per class.
14-
val_per_class : int
15-
The number of validation examples per class.
16-
num_classes : int
17-
The number of classes in the dataset.
14+
train_ratio : float
15+
The proportion of the dataset to include in the train split.
16+
val_ratio : float
17+
The proportion of the dataset to include in the validation split.
1818
1919
Returns
2020
-------
2121
:class:`tuple` of three boolean :class:`tensor` masks ``(train_mask, val_mask, test_mask)``
22-
2322
"""
24-
random_state = np.random.RandomState(0)
25-
labels = tlx.nn.OneHot(depth=num_classes)(graph.y).numpy()
26-
num_samples, num_classes = graph.num_nodes, num_classes
27-
remaining_indices = set(range(num_samples))
28-
forbidden_indices = set()
2923

30-
train_indices = sample_per_class(random_state, num_samples, num_classes, labels, train_per_class, forbidden_indices=forbidden_indices)
31-
forbidden_indices.update(train_indices)
32-
val_indices = sample_per_class(random_state, num_samples, num_classes, labels, val_per_class, forbidden_indices=forbidden_indices)
33-
forbidden_indices.update(val_indices)
34-
test_indices = np.array(list(remaining_indices - forbidden_indices))
24+
random_state = np.random.RandomState(0)
25+
num_samples = graph.num_nodes
26+
all_indices = np.arange(num_samples)
3527

36-
return generate_masks(graph.num_nodes, train_indices, val_indices, test_indices)
28+
# split into train and (val + test)
29+
train_indices, val_test_indices = train_test_split(
30+
all_indices, train_size=train_ratio, random_state=random_state
31+
)
3732

33+
# calculate the ratio of validation and test splits in the remaining data
34+
test_ratio = 1.0 - train_ratio - val_ratio
35+
val_size_ratio = val_ratio / (val_ratio + test_ratio)
3836

39-
def sample_per_class(random_state, num_samples, num_classes, labels, num_examples_per_class, forbidden_indices=None):
40-
sample_indices_per_class = {index: [] for index in range(num_classes)}
41-
forbidden_set = set(forbidden_indices) if forbidden_indices is not None else set()
37+
# split val + test into validation and test sets
38+
val_indices, test_indices = train_test_split(
39+
val_test_indices, train_size=val_size_ratio, random_state=random_state
40+
)
4241

43-
for class_index in range(num_classes):
44-
for sample_index in range(num_samples):
45-
if labels[sample_index, class_index] > 0.0 and sample_index not in forbidden_set:
46-
sample_indices_per_class[class_index].append(sample_index)
42+
return generate_masks(num_samples, train_indices, val_indices, test_indices)
4743

48-
return np.concatenate(
49-
[random_state.choice(sample_indices_per_class[class_index], num_examples_per_class, replace=False)
50-
for class_index in range(num_classes)
51-
])
5244

45+
def generate_masks(num_nodes, train_indices, val_indices, test_indices):
46+
np_train_mask = np.zeros(num_nodes, dtype=bool)
47+
np_train_mask[train_indices] = 1
48+
np_val_mask = np.zeros(num_nodes, dtype=bool)
49+
np_val_mask[val_indices] = 1
50+
np_test_mask = np.zeros(num_nodes, dtype=bool)
51+
np_test_mask[test_indices] = 1
5352

54-
def generate_masks(num_nodes, train_indices, val_indices, test_indices):
55-
np_train_mask = np.zeros(num_nodes)
56-
np_train_mask[train_indices] = 1
57-
np_val_mask = np.zeros(num_nodes)
58-
np_val_mask[val_indices] = 1
59-
np_test_mask = np.zeros(num_nodes)
60-
np_test_mask[test_indices] = 1
61-
train_mask = tlx.ops.convert_to_tensor(np_train_mask, dtype=tlx.bool)
53+
train_mask = tlx.ops.convert_to_tensor(np_train_mask, dtype=tlx.bool)
6254
val_mask = tlx.ops.convert_to_tensor(np_val_mask, dtype=tlx.bool)
63-
test_mask = tlx.ops.convert_to_tensor(np_test_mask, dtype=tlx.bool)
55+
test_mask = tlx.ops.convert_to_tensor(np_test_mask, dtype=tlx.bool)
56+
6457
return train_mask, val_mask, test_mask

0 commit comments

Comments
 (0)