@@ -27,7 +27,6 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ, out_activ):
         self.h_activ, self.out_activ = h_activ, out_activ

     def forward(self, x):
-        x = x.unsqueeze(0)
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)

@@ -36,7 +35,7 @@ def forward(self, x):
             elif self.out_activ and index == self.num_layers - 1:
                 return self.out_activ(h_n).squeeze()

-        return h_n.squeeze()
+        return h_n


 class Decoder(nn.Module):
@@ -56,20 +55,21 @@ def __init__(self, input_dim, out_dim, h_dims, h_activ):
             self.layers.append(layer)

         self.h_activ = h_activ
-        self.dense_matrix = nn.Parameter(
-            torch.rand((layer_dims[-1], out_dim), dtype=torch.float),
-            requires_grad=True
-        )
+        self.dense_layer = nn.Linear(layer_dims[-1], out_dim)

     def forward(self, x, seq_len):
-        x = x.repeat(seq_len, 1).unsqueeze(0)
+        if len(x.shape) == 1:  # In case the batch dimension is not there
+            x = x.repeat(seq_len, 1)  # Add the sequence dimension by repeating the embedding
+        else:
+            x = x.unsqueeze(1).repeat(1, seq_len, 1)  # Add the sequence dimension by repeating the embedding
+
         for index, layer in enumerate(self.layers):
             x, (h_n, c_n) = layer(x)

             if self.h_activ and index < self.num_layers - 1:
                 x = self.h_activ(x)

-        return torch.mm(x.squeeze(), self.dense_matrix)
+        return self.dense_layer(x)


 ######
@@ -88,7 +88,10 @@ def __init__(self, input_dim, encoding_dim, h_dims=[], h_activ=nn.Sigmoid(),
                                h_activ)

     def forward(self, x):
-        seq_len = x.shape[0]
+        if len(x.shape) <= 2:  # In case the batch dimension is not there
+            seq_len = x.shape[0]
+        else:
+            seq_len = x.shape[1]
         x = self.encoder(x)
         x = self.decoder(x, seq_len)

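A minimal, self-contained sketch (not part of the commit) of the shape handling this change enables; the dimension names below are illustrative and the only dependency is torch:

import torch
from torch import nn

seq_len, input_dim, encoding_dim, batch = 10, 3, 4, 5

# Decoder input: repeat the encoding along a new sequence dimension.
z = torch.randn(encoding_dim)                       # unbatched encoding
assert z.repeat(seq_len, 1).shape == (seq_len, encoding_dim)

z = torch.randn(batch, encoding_dim)                # batched encodings
assert z.unsqueeze(1).repeat(1, seq_len, 1).shape == (batch, seq_len, encoding_dim)

# nn.Linear broadcasts over leading dimensions, unlike torch.mm,
# so the same dense layer handles both 2-D and 3-D decoder outputs.
dense = nn.Linear(encoding_dim, input_dim)
assert dense(torch.randn(seq_len, encoding_dim)).shape == (seq_len, input_dim)
assert dense(torch.randn(batch, seq_len, encoding_dim)).shape == (batch, seq_len, input_dim)

# seq_len inference in the autoencoder's forward:
x = torch.randn(seq_len, input_dim)                 # no batch dimension -> x.shape[0]
xb = torch.randn(batch, seq_len, input_dim)         # batched            -> x.shape[1]
assert x.shape[0] == xb.shape[1] == seq_len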