Commit cf0a261

Adding Batch support for LSTM_AE

1 parent 3bcdf1c

1 file changed: sequitur/models/lstm_ae.py (+12, -9)
--- a/sequitur/models/lstm_ae.py
+++ b/sequitur/models/lstm_ae.py
@@ -27,7 +27,6 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ, out_activ):
         self.h_activ, self.out_activ = h_activ, out_activ
 
     def forward(self, x):
-        x = x.unsqueeze(0)
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)
 
@@ -36,7 +35,7 @@ def forward(self, x):
             elif self.out_activ and index == self.num_layers - 1:
                 return self.out_activ(h_n).squeeze()
 
-        return h_n.squeeze()
+        return h_n
 
 
 class Decoder(nn.Module):
@@ -56,20 +55,21 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ):
             self.layers.append(layer)
 
         self.h_activ = h_activ
-        self.dense_matrix = nn.Parameter(
-            torch.rand((layer_dims[-1], out_dim), dtype=torch.float),
-            requires_grad=True
-        )
+        self.dense_layer = nn.Linear(layer_dims[-1], out_dim)
 
     def forward(self, x, seq_len):
-        x = x.repeat(seq_len, 1).unsqueeze(0)
+        if len(x.shape) == 1:  # in case the batch dimension is not there
+            x = x.repeat(seq_len, 1)  # add the sequence dimension by repeating the embedding
+        else:
+            x = x.unsqueeze(1).repeat(1, seq_len, 1)  # add the sequence dimension by repeating the embedding
+
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)
 
             if self.h_activ and index < self.num_layers - 1:
                 x = self.h_activ(x)
 
-        return torch.mm(x.squeeze(), self.dense_matrix)
+        return self.dense_layer(x)
 
 
 ######
@@ -88,7 +88,10 @@ def __init__(self, input_dim, encoding_dim, h_dims=[], h_activ=nn.Sigmoid(),
                          h_activ)
 
     def forward(self, x):
-        seq_len = x.shape[0]
+        if len(x.shape) <= 2:  # in case the batch dimension is not there
+            seq_len = x.shape[0]
+        else:
+            seq_len = x.shape[1]
         x = self.encoder(x)
         x = self.decoder(x, seq_len)
 
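Beyond the reshaping logic, the swap from the hand-rolled dense_matrix to nn.Linear is what lets the output projection handle batches at all: torch.mm only accepts 2-D operands, while nn.Linear applies the projection over the last dimension of any N-D input (and, unlike the old raw parameter, also adds a bias term). A minimal sketch of the difference, using illustrative sizes that are not from this commit:

    import torch
    import torch.nn as nn

    dense = nn.Linear(16, 3)           # stands in for the new dense_layer
    batched = torch.randn(4, 10, 16)   # (batch, seq_len, hidden)

    out = dense(batched)               # OK: projects the last dim -> (4, 10, 3)
    # torch.mm(batched, torch.randn(16, 3))  # would raise: mm expects 2-D tensors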

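With both branches in place, the same model should accept unbatched and batched input. The following is a hypothetical smoke test, not part of the commit: it assumes the LSTM_AE signature visible in the last hunk header, an import path matching the file location sequitur/models/lstm_ae.py, and PyTorch >= 1.11 (the version where nn.LSTM began accepting unbatched 2-D inputs, which the removed unsqueeze(0) used to emulate):

    import torch
    from sequitur.models import LSTM_AE  # import path assumed from the file location

    model = LSTM_AE(input_dim=3, encoding_dim=8)

    # Unbatched input (seq_len, input_dim): seq_len is read from x.shape[0]
    x = torch.randn(10, 3)
    x_hat = model(x)     # reconstruction, expected shape (10, 3)

    # Batched input (batch, seq_len, input_dim): seq_len is read from x.shape[1]
    xb = torch.randn(4, 10, 3)
    xb_hat = model(xb)   # reconstruction, expected shape (4, 10, 3)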
0 commit comments