Skip to content

Commit 7addbfe

Browse files
authored
Merge pull request #11 from Oisin-M/feat/vary_conv_layers
Feat/vary conv layers
2 parents 81813db + 2a55a5c commit 7addbfe

File tree

3 files changed

+45
-23
lines changed

3 files changed

+45
-23
lines changed

gca_rom/gca.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch import nn
33
import torch.nn.functional as F
4-
from torch_geometric.nn import GMMConv
4+
import torch_geometric.nn as gnn
55

66

77
class Encoder(torch.nn.Module):
@@ -27,7 +27,7 @@ class Encoder(torch.nn.Module):
2727
forward(data): A convenience function that calls the encoder method.
2828
"""
2929

30-
def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu):
30+
def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu, conv='GMMConv'):
3131
super().__init__()
3232
self.hidden_channels = hidden_channels
3333
self.depth = len(self.hidden_channels)
@@ -36,22 +36,33 @@ def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu
3636
self.skip = skip
3737
self.bottleneck = bottleneck
3838
self.input_size = input_size
39+
self.conv = conv
3940

4041
self.down_convs = torch.nn.ModuleList()
4142
for i in range(self.depth-1):
42-
self.down_convs.append(GMMConv(self.hidden_channels[i], self.hidden_channels[i+1], dim=1, kernel_size=5))
43+
if self.conv=='GMMConv':
44+
self.down_convs.append(gnn.GMMConv(self.hidden_channels[i], self.hidden_channels[i+1], dim=1, kernel_size=5))
45+
elif self.conv=='ChebConv':
46+
self.down_convs.append(gnn.ChebConv(self.hidden_channels[i], self.hidden_channels[i+1], K=5))
47+
elif self.conv=='GCNConv':
48+
self.down_convs.append(gnn.GCNConv(self.hidden_channels[i], self.hidden_channels[i+1]))
49+
elif self.conv=='GATConv':
50+
self.down_convs.append(gnn.GATConv(self.hidden_channels[i], self.hidden_channels[i+1]))
51+
else:
52+
raise NotImplementedError('Invalid convolution selected. Please select one of [GMMConv, ChebConv, GCNConv, GATConv]')
4353

4454
self.fc_in1 = nn.Linear(self.input_size*self.hidden_channels[-1], self.ffn)
4555
self.fc_in2 = nn.Linear(self.ffn, self.bottleneck)
4656
self.reset_parameters()
4757

4858
def encoder(self, data):
49-
edge_weight = data.edge_attr
50-
edge_index = data.edge_index
5159
x = data.x
5260
idx = 0
5361
for layer in self.down_convs:
54-
x = self.act(layer(x, edge_index, edge_weight.unsqueeze(1)))
62+
if self.conv in ['GMMConv', 'ChebConv', 'GCNConv']:
63+
x = self.act(layer(x, data.edge_index, data.edge_weight))
64+
elif self.conv in ['GATConv']:
65+
x = self.act(layer(x, data.edge_index, data.edge_attr))
5566
if self.skip:
5667
x = x + data.x
5768
idx += 1
@@ -103,7 +114,7 @@ class Decoder(torch.nn.Module):
103114
Performs a forward pass on the input data x and returns the output.
104115
"""
105116

106-
def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu):
117+
def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu, conv='GMMConv'):
107118
super().__init__()
108119
self.hidden_channels = hidden_channels
109120
self.depth = len(self.hidden_channels)
@@ -112,31 +123,41 @@ def __init__(self, hidden_channels, bottleneck, input_size, ffn, skip, act=F.elu
112123
self.skip = skip
113124
self.bottleneck = bottleneck
114125
self.input_size = input_size
126+
self.conv = conv
115127

116128
self.fc_out1 = nn.Linear(self.bottleneck, self.ffn)
117129
self.fc_out2 = nn.Linear(self.ffn, self.input_size * self.hidden_channels[-1])
118130

119131
self.up_convs = torch.nn.ModuleList()
120132
for i in range(self.depth-1):
121-
self.up_convs.append(GMMConv(self.hidden_channels[self.depth-1-i], self.hidden_channels[self.depth-i-2], dim=1, kernel_size=5))
133+
if self.conv=='GMMConv':
134+
self.up_convs.append(gnn.GMMConv(self.hidden_channels[self.depth-i-1], self.hidden_channels[self.depth-i-2], dim=1, kernel_size=5))
135+
elif self.conv=='ChebConv':
136+
self.up_convs.append(gnn.ChebConv(self.hidden_channels[self.depth-i-1], self.hidden_channels[self.depth-i-2], K=5))
137+
elif self.conv=='GCNConv':
138+
self.up_convs.append(gnn.GCNConv(self.hidden_channels[self.depth-i-1], self.hidden_channels[self.depth-i-2]))
139+
elif self.conv=='GATConv':
140+
self.up_convs.append(gnn.GATConv(self.hidden_channels[self.depth-i-1], self.hidden_channels[self.depth-i-2]))
141+
else:
142+
raise NotImplementedError('Invalid convolution selected. Please select one of [GMMConv, ChebConv, GCNConv, GATConv]')
143+
122144

123145
self.reset_parameters()
124146

125147

126148
def decoder(self, x, data):
127-
edge_weight = data.edge_attr
128-
edge_index = data.edge_index
129-
130149
x = self.act(self.fc_out1(x))
131150
x = self.act(self.fc_out2(x))
132151
h = x.reshape(data.num_graphs*self.input_size, self.hidden_channels[-1])
133152
x = h
134153
idx = 0
135154
for layer in self.up_convs:
136-
if (idx == self.depth - 2):
137-
x = layer(x, edge_index, edge_weight.unsqueeze(1))
138-
else:
139-
x = self.act(layer(x, edge_index, edge_weight.unsqueeze(1)))
155+
if self.conv in ['GMMConv', 'ChebConv', 'GCNConv']:
156+
x = layer(x, data.edge_index, data.edge_weight)
157+
elif self.conv in ['GATConv']:
158+
x = layer(x, data.edge_index, data.edge_attr)
159+
if (idx != self.depth - 2):
160+
x = self.act(x)
140161
if self.skip:
141162
x = x + h
142163
idx += 1

gca_rom/network.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class HyperParams:
3333
cross_validation (bool): Whether to perform cross-validation.
3434
"""
3535

36-
def __init__(self, argv):
36+
def __init__(self, argv, **kwargs):
3737
self.net_name = argv[0]
3838
self.variable = argv[1]
3939
self.scaling_type = int(argv[2])
@@ -60,9 +60,10 @@ def __init__(self, argv):
6060
self.miles = []
6161
self.gamma = 0.0001
6262
self.num_nodes = 0
63+
self.conv = 'GMMConv'
6364
self.net_dir = './' + self.net_name + '/' + self.net_run + '/' + self.variable + '_' + self.net_name + '_lmap' + str(self.lambda_map) + '_btt' + str(self.bottleneck_dim) \
6465
+ '_seed' + str(self.seed) + '_lv' + str(len(self.layer_vec)-2) + '_hc' + str(len(self.hidden_channels)) + '_nd' + str(self.nodes) \
65-
+ '_ffn' + str(self.ffn) + '_skip' + str(self.skip) + '_lr' + str(self.learning_rate) + '_sc' + str(self.scaling_type) + '_rate' + str(self.rate) + '/'
66+
+ '_ffn' + str(self.ffn) + '_skip' + str(self.skip) + '_lr' + str(self.learning_rate) + '_sc' + str(self.scaling_type) + '_rate' + str(self.rate) + '_conv' + self.conv + '/'
6667
self.cross_validation = True
6768

6869

@@ -107,8 +108,8 @@ class Net(torch.nn.Module):
107108

108109
def __init__(self, HyperParams):
109110
super().__init__()
110-
self.encoder = gca.Encoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip)
111-
self.decoder = gca.Decoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip)
111+
self.encoder = gca.Encoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, conv=HyperParams.conv)
112+
self.decoder = gca.Decoder(HyperParams.hidden_channels, HyperParams.bottleneck_dim, HyperParams.num_nodes, ffn=HyperParams.ffn, skip=HyperParams.skip, conv=HyperParams.conv)
112113

113114
self.act_map = HyperParams.act
114115
self.layer_vec = HyperParams.layer_vec

gca_rom/preprocessing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,16 @@ def graphs_dataset(dataset, HyperParams, param_sample=None):
105105
pos = torch.cat((xx[:, graph].unsqueeze(1), yy[:, graph].unsqueeze(1), zz[:, graph].unsqueeze(1)), 1)
106106
ei = torch.index_select(pos, 0, edge_index[0, :])
107107
ej = torch.index_select(pos, 0, edge_index[1, :])
108-
edge_diff = ej - ei
108+
edge_attr = torch.abs(ej - ei)
109109
if dataset.dim == 2:
110-
edge_attr = torch.sqrt(torch.pow(edge_diff[:, 0], 2) + torch.pow(edge_diff[:, 1], 2))
110+
edge_weight = torch.sqrt(torch.pow(edge_attr[:, 0], 2) + torch.pow(edge_attr[:, 1], 2)).unsqueeze(1)
111111
elif dataset.dim == 3:
112-
edge_attr = torch.sqrt(torch.pow(edge_diff[:, 0], 2) + torch.pow(edge_diff[:, 1], 2) + torch.pow(edge_diff[:, 2], 2))
112+
edge_weight = torch.sqrt(torch.pow(edge_attr[:, 0], 2) + torch.pow(edge_attr[:, 1], 2) + torch.pow(edge_attr[:, 2], 2)).unsqueeze(1)
113113
if HyperParams.comp == 1:
114114
node_features = VAR_all[graph, :]
115115
else:
116116
node_features = VAR_all[graph, :, :]
117-
dataset_graph = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
117+
dataset_graph = Data(x=node_features, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr, pos=pos)
118118
graphs.append(dataset_graph)
119119

120120
HyperParams.num_nodes = dataset_graph.num_nodes

0 commit comments

Comments
 (0)