Skip to content

Commit 080b4d5

Browse files
committed
add test_get_split.py
1 parent ec6d99d commit 080b4d5

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/utils/test_get_split.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import tensorlayerx as tlx
2+
from gammagl.utils import get_train_val_test_split
3+
import numpy as np
4+
5+
6+
class Graph:
7+
def __init__(self, num_nodes):
8+
self.num_nodes = num_nodes
9+
10+
def test_get_split():
11+
num_nodes = 1000
12+
graph = Graph(num_nodes)
13+
14+
train_ratio = 0.6
15+
val_ratio = 0.2
16+
17+
train_mask, val_mask, test_mask = get_train_val_test_split(graph, train_ratio, val_ratio)
18+
19+
assert tlx.ops.is_tensor(train_mask)
20+
assert tlx.ops.is_tensor(val_mask)
21+
assert tlx.ops.is_tensor(test_mask)
22+
23+
train_mask = tlx.convert_to_numpy(train_mask)
24+
val_mask = tlx.convert_to_numpy(val_mask)
25+
test_mask = tlx.convert_to_numpy(test_mask)
26+
27+
assert np.sum(train_mask) == int(num_nodes * train_ratio)
28+
assert np.sum(val_mask) == int(num_nodes * val_ratio)
29+
assert np.sum(test_mask) == num_nodes - int(num_nodes * train_ratio) - int(num_nodes * val_ratio)
30+
31+
assert np.all(train_mask + val_mask + test_mask == 1)
32+

0 commit comments

Comments
 (0)