Skip to content

Commit c3c06d8

Browse files
committed
[Bug] Fix Reddit
1 parent d54dca8 commit c3c06d8

File tree

3 files changed

+12
-17
lines changed

3 files changed

+12
-17
lines changed

examples/graphsage/reddit_sage.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ def main(args):
4848
test_idx = mask_to_index(graph.test_mask)
4949
val_idx = mask_to_index(graph.val_mask)
5050

51-
train_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
52-
node_idx=tlx.convert_to_numpy(train_idx),
51+
train_loader = NeighborSampler(edge_index=graph.edge_index,
52+
node_idx=train_idx,
5353
sample_lists=[25, 10], batch_size=2048, shuffle=True, num_workers=0)
5454

55-
val_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
56-
node_idx=tlx.convert_to_numpy(val_idx),
55+
val_loader = NeighborSampler(edge_index=graph.edge_index,
56+
node_idx=val_idx,
5757
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)
58-
test_loader = NeighborSampler(edge_index=graph.edge_index.numpy(),
59-
node_idx=tlx.convert_to_numpy(test_idx),
58+
test_loader = NeighborSampler(edge_index=graph.edge_index,
59+
node_idx=test_idx,
6060
sample_lists=[-1], batch_size=2048 * 2, shuffle=False, num_workers=0)
6161

6262
x = tlx.convert_to_tensor(graph.x)
@@ -78,6 +78,9 @@ def main(args):
7878
pbar = tqdm(total=int(len(train_loader.dataset)))
7979
pbar.set_description(f'Epoch {epoch:02d}')
8080
for dst_node, n_id, adjs in train_loader:
81+
print("---------------")
82+
print(adjs)
83+
print(type(adjs))
8184
net.set_train()
8285
# input : sampled subgraphs, sampled node's feat
8386
data = {"x": tlx.gather(x, n_id),

gammagl/datasets/reddit.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ def download(self):
5252

5353
def process(self):
5454
data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))
55-
x = np.array(data['feature'], dtype=np.float32)
56-
y = np.array(data['label'], np.int32)
55+
x = tlx.convert_to_tensor(data['feature'], dtype=tlx.float32)
56+
y = tlx.convert_to_tensor(data['label'], dtype=tlx.int64)
5757
split = np.array(data['node_types'])
5858

5959
adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz'))
6060

61-
edge = np.array([adj.row, adj.col], dtype=np.int64)
61+
edge = tlx.convert_to_tensor([adj.row, adj.col], dtype=tlx.int64)
6262

6363
edge, _ = coalesce(edge, None, x.shape[0], x.shape[0])
6464

gammagl/models/graphsage.py

-8
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,20 @@ def __init__(self, in_feat, hid_feat, out_feat, drop_rate, num_layers, name=None
7575

7676
def forward(self, x, edgeIndices):
7777
for l, (layer, edgeIndex) in enumerate(zip(self.convs, edgeIndices)):
78-
if tlx.BACKEND == 'torch':
79-
edgeIndex.to(x.device)
8078
target_x = tlx.gather(x, tlx.arange(0, edgeIndex.size[1])) # Target nodes are always placed first.
8179
x = layer((x, target_x), edgeIndex.edge_index)
8280
if l != len(self.convs) - 1:
8381
x = self.dropout(x)
8482
return x
8583

8684
def inference(self, feat, dataloader, cur_x):
87-
if tlx.BACKEND == 'torch':
88-
feat = feat.to(cur_x.device)
8985
for l, layer in enumerate(self.convs):
9086
y = tlx.zeros((feat.shape[0], self.num_class if l == len(self.convs) - 1 else self.hid_feat))
91-
if tlx.BACKEND == 'torch':
92-
y = y.to(feat.device)
9387
for dst_node, n_id, adjs in dataloader:
9488
if isinstance(adjs, (List, Tuple)):
9589
sg = adjs[0]
9690
else:
9791
sg = adjs
98-
if tlx.BACKEND == 'torch':
99-
sg.to(y.device)
10092
h = tlx.gather(feat, n_id)
10193
target_feat = tlx.gather(h, tlx.arange(0, sg.size[1]))
10294
h = layer((h, target_feat), sg.edge_index)

0 commit comments

Comments
 (0)