Skip to content

Commit 2403822

Browse files
author
gyzhou2000
committed
solve the problem of repeated definition of dropout in forward of MLP
1 parent 6ee25a3 commit 2403822

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

gammagl/models/gin.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,15 @@ def __init__(self, in_channels,
4949

5050
self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
5151
norm=None, dropout=0.5)
52+
self.relu = tlx.ReLU()
5253

5354
def forward(self, x, edge_index, batch):
5455
if x is None:
5556
# x = tlx.ones((batch.shape[0], 1), dtype=tlx.float32)
5657
x = tlx.random_normal((batch.shape[0], 1), dtype=tlx.float32)
5758

5859
for conv in self.convs:
59-
x = tlx.relu(conv(x, edge_index))
60+
x = self.relu(conv(x, edge_index))
6061

6162
x = global_sum_pool(x, batch)
6263
return self.mlp(x)

gammagl/models/mlp.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def __init__(self,
4646
dropout[-1] = 0.
4747
assert len(dropout) == len(channel_list) - 1
4848
self.dropout = dropout
49+
self.dropouts = tlx.nn.ModuleList()
50+
for i in range(len(dropout)):
51+
self.dropouts.append(tlx.nn.Dropout(p=dropout[i]))
4952

5053
if isinstance(bias, bool):
5154
bias = [bias] * (len(channel_list) - 1)
@@ -89,12 +92,14 @@ def forward(self, x, return_emb=None):
8992
if self.act is not None and not self.act_first:
9093
x = self.act(x)
9194

92-
x = tlx.nn.Dropout(p=self.dropout[i])(x)
95+
# x = tlx.nn.Dropout(p=self.dropout[i])(x)
96+
x = self.dropouts[i](x)
9397
emb = x
9498

9599
if self.plain_last:
96100
x = self.lins[-1](x)
97-
x = tlx.nn.Dropout(p=self.dropout[-1])(x)
101+
# x = tlx.nn.Dropout(p=self.dropout[-1])(x)
102+
x = self.dropouts[-1](x)
98103

99104
return (x, emb) if isinstance(return_emb, bool) else x
100105

gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(
3434
auto index_data = index.data_ptr<int64_t>();
3535
auto arg_out_data = arg_out.data_ptr<int64_t>();
3636

37-
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
37+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() {
3838
out.fill_(std::numeric_limits<scalar_t>::lowest());
3939
auto x_data = x.data_ptr<scalar_t>();
4040
auto out_data = out.data_ptr<scalar_t>();

0 commit comments

Comments
 (0)