1
1
import tensorlayerx as tlx
2
2
import numpy as np
3
+ from sklearn .model_selection import train_test_split
3
4
4
5
5
- def get_train_val_test_split (graph , train_per_class , val_per_class , num_classes ):
6
- """Split the dataset into train, validation, and test sets.
6
+ def get_train_val_test_split (graph , train_ratio , val_ratio ):
7
+ """
8
+ Split the dataset into train, validation, and test sets.
7
9
8
10
Parameters
9
11
----------
10
12
graph :
11
13
The graph to split.
12
- train_per_class : int
13
- The number of training examples per class.
14
- val_per_class : int
15
- The number of validation examples per class.
16
- num_classes : int
17
- The number of classes in the dataset.
14
+ train_ratio : float
15
+ The proportion of the dataset to include in the train split.
16
+ val_ratio : float
17
+ The proportion of the dataset to include in the validation split.
18
18
19
19
Returns
20
20
-------
21
21
:class:`tuple` of :class:`tensor`
22
-
23
22
"""
24
- random_state = np .random .RandomState (0 )
25
- labels = tlx .nn .OneHot (depth = num_classes )(graph .y ).numpy ()
26
- num_samples , num_classes = graph .num_nodes , num_classes
27
- remaining_indices = set (range (num_samples ))
28
- forbidden_indices = set ()
29
23
30
- train_indices = sample_per_class (random_state , num_samples , num_classes , labels , train_per_class , forbidden_indices = forbidden_indices )
31
- forbidden_indices .update (train_indices )
32
- val_indices = sample_per_class (random_state , num_samples , num_classes , labels , val_per_class , forbidden_indices = forbidden_indices )
33
- forbidden_indices .update (val_indices )
34
- test_indices = np .array (list (remaining_indices - forbidden_indices ))
24
+ random_state = np .random .RandomState (0 )
25
+ num_samples = graph .num_nodes
26
+ all_indices = np .arange (num_samples )
35
27
36
- return generate_masks (graph .num_nodes , train_indices , val_indices , test_indices )
28
+ # split into train and (val + test)
29
+ train_indices , val_test_indices = train_test_split (
30
+ all_indices , train_size = train_ratio , random_state = random_state
31
+ )
37
32
33
+ # calculate the ratio of validation and test splits in the remaining data
34
+ test_ratio = 1.0 - train_ratio - val_ratio
35
+ val_size_ratio = val_ratio / (val_ratio + test_ratio )
38
36
39
- def sample_per_class (random_state , num_samples , num_classes , labels , num_examples_per_class , forbidden_indices = None ):
40
- sample_indices_per_class = {index : [] for index in range (num_classes )}
41
- forbidden_set = set (forbidden_indices ) if forbidden_indices is not None else set ()
37
+ # split val + test into validation and test sets
38
+ val_indices , test_indices = train_test_split (
39
+ val_test_indices , train_size = val_size_ratio , random_state = random_state
40
+ )
42
41
43
- for class_index in range (num_classes ):
44
- for sample_index in range (num_samples ):
45
- if labels [sample_index , class_index ] > 0.0 and sample_index not in forbidden_set :
46
- sample_indices_per_class [class_index ].append (sample_index )
42
+ return generate_masks (num_samples , train_indices , val_indices , test_indices )
47
43
48
- return np .concatenate (
49
- [random_state .choice (sample_indices_per_class [class_index ], num_examples_per_class , replace = False )
50
- for class_index in range (num_classes )
51
- ])
52
44
45
+ def generate_masks (num_nodes , train_indices , val_indices , test_indices ):
46
+ np_train_mask = np .zeros (num_nodes , dtype = bool )
47
+ np_train_mask [train_indices ] = 1
48
+ np_val_mask = np .zeros (num_nodes , dtype = bool )
49
+ np_val_mask [val_indices ] = 1
50
+ np_test_mask = np .zeros (num_nodes , dtype = bool )
51
+ np_test_mask [test_indices ] = 1
53
52
54
- def generate_masks (num_nodes , train_indices , val_indices , test_indices ):
55
- np_train_mask = np .zeros (num_nodes )
56
- np_train_mask [train_indices ] = 1
57
- np_val_mask = np .zeros (num_nodes )
58
- np_val_mask [val_indices ] = 1
59
- np_test_mask = np .zeros (num_nodes )
60
- np_test_mask [test_indices ] = 1
61
- train_mask = tlx .ops .convert_to_tensor (np_train_mask , dtype = tlx .bool )
53
+ train_mask = tlx .ops .convert_to_tensor (np_train_mask , dtype = tlx .bool )
62
54
val_mask = tlx .ops .convert_to_tensor (np_val_mask , dtype = tlx .bool )
63
- test_mask = tlx .ops .convert_to_tensor (np_test_mask , dtype = tlx .bool )
55
+ test_mask = tlx .ops .convert_to_tensor (np_test_mask , dtype = tlx .bool )
56
+
64
57
return train_mask , val_mask , test_mask
0 commit comments