1
1
import torch
2
2
from torch import nn
3
3
import torch .nn .functional as F
4
- from torch_geometric .nn import GMMConv
4
+ import torch_geometric .nn as gnn
5
5
6
6
7
7
class Encoder (torch .nn .Module ):
@@ -27,7 +27,7 @@ class Encoder(torch.nn.Module):
27
27
forward(data): A convenience function that calls the encoder method.
28
28
"""
29
29
30
- def __init__ (self , hidden_channels , bottleneck , input_size , ffn , skip , act = F .elu ):
30
+ def __init__ (self , hidden_channels , bottleneck , input_size , ffn , skip , act = F .elu , conv = 'GMMConv' ):
31
31
super ().__init__ ()
32
32
self .hidden_channels = hidden_channels
33
33
self .depth = len (self .hidden_channels )
@@ -36,22 +36,33 @@ def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu
36
36
self .skip = skip
37
37
self .bottleneck = bottleneck
38
38
self .input_size = input_size
39
+ self .conv = conv
39
40
40
41
self .down_convs = torch .nn .ModuleList ()
41
42
for i in range (self .depth - 1 ):
42
- self .down_convs .append (GMMConv (self .hidden_channels [i ], self .hidden_channels [i + 1 ], dim = 1 , kernel_size = 5 ))
43
+ if self .conv == 'GMMConv' :
44
+ self .down_convs .append (gnn .GMMConv (self .hidden_channels [i ], self .hidden_channels [i + 1 ], dim = 1 , kernel_size = 5 ))
45
+ elif self .conv == 'ChebConv' :
46
+ self .down_convs .append (gnn .ChebConv (self .hidden_channels [i ], self .hidden_channels [i + 1 ], K = 5 ))
47
+ elif self .conv == 'GCNConv' :
48
+ self .down_convs .append (gnn .GCNConv (self .hidden_channels [i ], self .hidden_channels [i + 1 ]))
49
+ elif self .conv == 'GATConv' :
50
+ self .down_convs .append (gnn .GATConv (self .hidden_channels [i ], self .hidden_channels [i + 1 ]))
51
+ else :
52
+ raise NotImplementedError ('Invalid convolution selected. Please select one of [GMMConv, ChebConv, GCNConv, GATConv]' )
43
53
44
54
self .fc_in1 = nn .Linear (self .input_size * self .hidden_channels [- 1 ], self .ffn )
45
55
self .fc_in2 = nn .Linear (self .ffn , self .bottleneck )
46
56
self .reset_parameters ()
47
57
48
58
def encoder (self , data ):
49
- edge_weight = data .edge_attr
50
- edge_index = data .edge_index
51
59
x = data .x
52
60
idx = 0
53
61
for layer in self .down_convs :
54
- x = self .act (layer (x , edge_index , edge_weight .unsqueeze (1 )))
62
+ if self .conv in ['GMMConv' , 'ChebConv' , 'GCNConv' ]:
63
+ x = self .act (layer (x , data .edge_index , data .edge_weight ))
64
+ elif self .conv in ['GATConv' ]:
65
+ x = self .act (layer (x , data .edge_index , data .edge_attr ))
55
66
if self .skip :
56
67
x = x + data .x
57
68
idx += 1
@@ -103,7 +114,7 @@ class Decoder(torch.nn.Module):
103
114
Performs a forward pass on the input data x and returns the output.
104
115
"""
105
116
106
- def __init__ (self , hidden_channels , bottleneck , input_size , ffn , skip , act = F .elu ):
117
+ def __init__ (self , hidden_channels , bottleneck , input_size , ffn , skip , act = F .elu , conv = 'GMMConv' ):
107
118
super ().__init__ ()
108
119
self .hidden_channels = hidden_channels
109
120
self .depth = len (self .hidden_channels )
@@ -112,31 +123,41 @@ def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu
112
123
self .skip = skip
113
124
self .bottleneck = bottleneck
114
125
self .input_size = input_size
126
+ self .conv = conv
115
127
116
128
self .fc_out1 = nn .Linear (self .bottleneck , self .ffn )
117
129
self .fc_out2 = nn .Linear (self .ffn , self .input_size * self .hidden_channels [- 1 ])
118
130
119
131
self .up_convs = torch .nn .ModuleList ()
120
132
for i in range (self .depth - 1 ):
121
- self .up_convs .append (GMMConv (self .hidden_channels [self .depth - 1 - i ], self .hidden_channels [self .depth - i - 2 ], dim = 1 , kernel_size = 5 ))
133
+ if self .conv == 'GMMConv' :
134
+ self .up_convs .append (gnn .GMMConv (self .hidden_channels [self .depth - i - 1 ], self .hidden_channels [self .depth - i - 2 ], dim = 1 , kernel_size = 5 ))
135
+ elif self .conv == 'ChebConv' :
136
+ self .up_convs .append (gnn .ChebConv (self .hidden_channels [self .depth - i - 1 ], self .hidden_channels [self .depth - i - 2 ], K = 5 ))
137
+ elif self .conv == 'GCNConv' :
138
+ self .up_convs .append (gnn .GCNConv (self .hidden_channels [self .depth - i - 1 ], self .hidden_channels [self .depth - i - 2 ]))
139
+ elif self .conv == 'GATConv' :
140
+ self .up_convs .append (gnn .GATConv (self .hidden_channels [self .depth - i - 1 ], self .hidden_channels [self .depth - i - 2 ]))
141
+ else :
142
+ raise NotImplementedError ('Invalid convolution selected. Please select one of [GMMConv, ChebConv, GCNConv, GATConv]' )
143
+
122
144
123
145
self .reset_parameters ()
124
146
125
147
126
148
def decoder (self , x , data ):
127
- edge_weight = data .edge_attr
128
- edge_index = data .edge_index
129
-
130
149
x = self .act (self .fc_out1 (x ))
131
150
x = self .act (self .fc_out2 (x ))
132
151
h = x .reshape (data .num_graphs * self .input_size , self .hidden_channels [- 1 ])
133
152
x = h
134
153
idx = 0
135
154
for layer in self .up_convs :
136
- if (idx == self .depth - 2 ):
137
- x = layer (x , edge_index , edge_weight .unsqueeze (1 ))
138
- else :
139
- x = self .act (layer (x , edge_index , edge_weight .unsqueeze (1 )))
155
+ if self .conv in ['GMMConv' , 'ChebConv' , 'GCNConv' ]:
156
+ x = layer (x , data .edge_index , data .edge_weight )
157
+ elif self .conv in ['GATConv' ]:
158
+ x = layer (x , data .edge_index , data .edge_attr )
159
+ if (idx != self .depth - 2 ):
160
+ x = self .act (x )
140
161
if self .skip :
141
162
x = x + h
142
163
idx += 1
0 commit comments