Skip to content

Commit 92d3a3c

Browse files
authored
[Model]fix model bugs (#71)
* [Model]fix init * [Model]fix model bugs
1 parent f224daf commit 92d3a3c

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

examples/simplehgn/simplehgn_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def main(args):
118118
val_micro_f1, val_macro_f1 = calculate_f1_score(val_logits, val_y)
119119
print("Epoch [{:0>3d}] ".format(epoch + 1),
120120
" train loss: {:.4f}".format(train_loss.item()),
121-
" val loss: {:.4f}".format(val_loss),
121+
" val loss: {:.4f}".format(val_loss.item()),
122122
" val micro: {:.4f}".format(val_micro_f1),
123123
" val macro: {:.4f}".format(val_macro_f1),)
124124
if val_loss < best_val_loss:
@@ -132,7 +132,7 @@ def main(args):
132132

133133
model.load_weights(args.best_model_path+model.name+".npz", format='npz_dict')
134134
if tlx.BACKEND == 'torch':
135-
model.to(data["x"].device)
135+
model.to(data["x"][0].device)
136136
model.set_eval()
137137
logits = model(data['x'], data['edge_index'], data['e_feat'])
138138
test_logits = tlx.gather(logits, data['test_idx'])

gammagl/layers/conv/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
'JumpingKnowledge',
3131
'HANConv',
3232
'ChebConv',
33-
'SimpleHGNConv'
33+
'SimpleHGNConv',
3434
'FAGCNConv',
3535
'GPRConv',
3636
]

gammagl/layers/conv/simplehgn_conv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def __init__(self,
8585
self.out_feats = out_feats
8686
self.edge_embedding = tlx.nn.Embedding(num_etypes, edge_feats)
8787

88-
self.fc_node = tlx.nn.Linear(out_feats * heads, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414), name='fc_node')
89-
self.fc_edge = tlx.nn.Linear(edge_feats * heads, in_features=edge_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414), name='fc_edge')
88+
self.fc_node = tlx.nn.Linear(out_feats * heads, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414))
89+
self.fc_edge = tlx.nn.Linear(edge_feats * heads, in_features=edge_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414))
9090

9191
self.attn_src = self._get_weights('attn_l', shape=(1, heads, out_feats), init=tlx.initializers.XavierNormal(gain=1.414), order=True)
9292
self.attn_dst = self._get_weights('attn_r', shape=(1, heads, out_feats), init=tlx.initializers.XavierNormal(gain=1.414), order=True)
@@ -96,7 +96,7 @@ def __init__(self,
9696
self.attn_drop = tlx.nn.Dropout(attn_drop)
9797
self.leaky_relu = tlx.nn.LeakyReLU(negative_slope)
9898

99-
self.fc_res = tlx.nn.Linear(heads * out_feats, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414), name='fc_res') if residual else None
99+
self.fc_res = tlx.nn.Linear(heads * out_feats, in_features=in_feats, b_init=None, W_init=tlx.initializers.XavierNormal(gain=1.414)) if residual else None
100100

101101
self.activation = activation
102102

tests/layers/conv/test_simplehgn_conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
def test_simplehgn_conv():
66
x = tlx.random_normal(shape=(4, 64))
77
edge_index = tlx.convert_to_tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
8-
edge_feat = [0, 1, 2, 3, 4, 5]
8+
edge_feat = tlx.convert_to_tensor([0, 1, 2, 3, 4, 5])
99

1010
conv = SimpleHGNConv(in_feats=64, out_feats=128, num_etypes=6, edge_feats=64, heads=8)
1111
out, _ = conv(x, edge_index, edge_feat=edge_feat)

0 commit comments

Comments (0)