Skip to content

Commit 18e5bba

Browse files
authored
Merge pull request #3 from ChujieChen/develop
finished LSTM for polarity
2 parents 4493380 + 241098b commit 18e5bba

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

src/yews/models/polarity.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def __init__(self, **kwargs):
150150
self.contains_unkown = kwargs["contains_unkown"]
151151
self.start = kwargs['start']
152152
self.end = kwargs['end']
153-
self.lstm = nn.LSTM(input_size, self.hidden_size,bidirectional=self.bidirectional)
153+
self.num_layers = kwargs['num_layers']
154+
self.lstm = nn.LSTM(input_size, self.hidden_size, self.num_layers,
155+
bidirectional=self.bidirectional)
154156
self.fc = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), 3 if self.contains_unkown else 2)
155157

156158
def forward(self, x):
@@ -163,11 +165,16 @@ def forward(self, x):
163165
return out
164166
def polarity_lstm(**kwargs):
165167
r"""A LSTM based model.
166-
Args:
167-
pretrained (bool): If True, returns a model pre-trained on Wenchuan)
168-
progress (bool): If True, displays a progress bar of the download to stderr
168+
Kwargs (form like a dict and should be pass like **kwargs):
169+
hidden_size (default 64): recommended to be similar as the length of trimmed subsequence
170+
num_layers (default 2): layers are stacked and results are from the final layer
171+
start (default 250): start index of the subsequence
172+
end (default 350): end index of the subsequence
173+
bidirectional (default False): run lstm from left to right and from right to left
174+
contains_unkown (default False): True if targets have 0,1,2
169175
"""
170176
default_kwargs = {"hidden_size":64,
177+
"num_layers":2,
171178
"start": 250,
172179
"end": 350,
173180
"bidirectional":False,
@@ -179,6 +186,6 @@ def polarity_lstm(**kwargs):
179186
print(default_kwargs)
180187
print("\n##########################")
181188
if(default_kwargs['end'] < default_kwargs['start']):
182-
raise ValueError('<-- end must be largger than start -->')
189+
raise ValueError('<-- end cannot be smaller than start -->')
183190
model = PolarityLSTM(**default_kwargs)
184191
return model

0 commit comments

Comments
 (0)