[Hackathon 6th No.24] Functionality enhancement for paddle.nn.LSTM/RNNBase (#63284)
* ✨ Enhance LSTM and RNNBase

* fix ci coverage

* adjust position of proj_size

* fix proj_size to number

* reshape weight_ho

* update

* update docstring

* update docstring

* try to fix docstring

* try to fix docstring
Asthestarsfalll authored Apr 19, 2024
1 parent 4227ea5 commit 67d3fd0
Showing 5 changed files with 294 additions and 60 deletions.
99 changes: 76 additions & 23 deletions python/paddle/nn/layer/rnn.py
@@ -893,6 +893,12 @@ class LSTMCell(RNNCellBase):
y_{t} & = h_{t}
If `proj_size` is specified, the hidden state :math:`h_{t}` will be projected to dimension `proj_size`:
.. math::
h_{t} = h_{t}W_{proj\_size}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
@@ -910,12 +916,16 @@ class LSTMCell(RNNCellBase):
`bias_ih`. Default: None.
bias_hh_attr (ParamAttr, optional): The parameter attribute for the
`bias_hh`. Default: None.
proj_size (int, optional): If specified, the output hidden state
will be projected to `proj_size`. `proj_size` must be smaller than
`hidden_size`. Default: 0.
name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`.
Variables:
- **weight_ih** (Parameter): shape (4 * hidden_size, input_size), input to hidden weight, which corresponds to the concatenation of :math:`W_{ii}, W_{if}, W_{ig}, W_{io}` in the formula.
- **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula.
- **weight_hh** (Parameter): shape (4 * hidden_size, hidden_size), hidden to hidden weight, which corresponds to the concatenation of :math:`W_{hi}, W_{hf}, W_{hg}, W_{ho}` in the formula. If `proj_size` is specified, the shape will be (4 * hidden_size, proj_size).
- **weight_ho** (Parameter, optional): shape (hidden_size, proj_size), projects the hidden state to `proj_size`.
- **bias_ih** (Parameter): shape (4 * hidden_size, ), input to hidden bias, which corresponds to the concatenation of :math:`b_{ii}, b_{if}, b_{ig}, b_{io}` in the formula.
- **bias_hh** (Parameter): shape (4 * hidden_size, ), hidden to hidden bias, which corresponds to the concatenation of :math:`b_{hi}, b_{hf}, b_{hg}, b_{ho}` in the formula.
@@ -924,8 +934,9 @@ class LSTMCell(RNNCellBase):
- **states** (list|tuple, optional): a list/tuple of two tensors, each of shape `[batch_size, hidden_size]`, the previous hidden state, corresponding to :math:`h_{t-1}, c_{t-1}` in the formula. When states is None, zero state is used. Defaults to None.
Returns:
- **outputs** (Tensor): shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula.
- **states** (tuple): a tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula.
- **outputs** (Tensor). Shape `[batch_size, hidden_size]`, the output, corresponding to :math:`h_{t}` in the formula. If `proj_size` is specified, the output shape will be `[batch_size, proj_size]`.
- **states** (tuple). A tuple of two tensors, each of shape `[batch_size, hidden_size]`, the new hidden states, corresponding to :math:`h_{t}, c_{t}` in the formula.
If `proj_size` is specified, the shape of :math:`h_{t}` will be `[batch_size, proj_size]`.
Notes:
All the weights and bias are initialized with `Uniform(-std, std)` by
@@ -962,13 +973,22 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
name=None,
):
super().__init__()
if hidden_size <= 0:
raise ValueError(
f"hidden_size of {self.__class__.__name__} must be greater than 0, but now equals to {hidden_size}"
)
if proj_size < 0:
raise ValueError(
f"proj_size of {self.__class__.__name__} must be greater than or equal to 0, but now equals to {proj_size}"
)

if proj_size >= hidden_size:
raise ValueError("proj_size must be smaller than hidden_size")

std = 1.0 / math.sqrt(hidden_size)
if weight_ih_attr is not False:
self.weight_ih = self.create_parameter(
@@ -985,13 +1005,13 @@ def __init__(
self.weight_ih.stop_gradient = True
if weight_hh_attr is not False:
self.weight_hh = self.create_parameter(
(4 * hidden_size, hidden_size),
(4 * hidden_size, proj_size or hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std),
)
else:
self.weight_hh = self.create_parameter(
(4 * hidden_size, hidden_size),
(4 * hidden_size, proj_size or hidden_size),
None,
default_initializer=I.Constant(1.0),
)
@@ -1027,6 +1047,14 @@ def __init__(
)
self.bias_hh.stop_gradient = True

self.proj_size = proj_size
if proj_size > 0:
self.weight_ho = self.create_parameter(
(hidden_size, proj_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std),
)

self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
@@ -1050,6 +1078,8 @@ def forward(self, inputs, states=None):
o = self._gate_activation(chunked_gates[3])
c = f * pre_cell + i * self._activation(chunked_gates[2])
h = o * self._activation(c)
if self.proj_size > 0:
h = paddle.matmul(h, self.weight_ho)

return h, (h, c)

@@ -1061,7 +1091,7 @@ def state_shape(self):
automatically inserted into shape). These two shapes correspond
to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
"""
return ((self.hidden_size,), (self.hidden_size,))
return ((self.proj_size or self.hidden_size,), (self.hidden_size,))

def extra_repr(self):
return '{input_size}, {hidden_size}'.format(**self.__dict__)
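
For reference, a minimal usage sketch of the enhanced cell, assuming a Paddle build that includes this change. Note that only :math:`h` is projected; the cell state :math:`c` keeps `hidden_size`, so the two states in the tuple have different widths:

```python
import paddle

# Minimal sketch, assuming a Paddle build that includes this change.
cell = paddle.nn.LSTMCell(input_size=16, hidden_size=32, proj_size=8)
x = paddle.randn([4, 16])       # [batch_size, input_size]
prev_h = paddle.zeros([4, 8])   # previous hidden state has the projected width
prev_c = paddle.zeros([4, 32])  # cell state keeps hidden_size
y, (h, c) = cell(x, (prev_h, prev_c))
print(y.shape)  # [4, 8]  -- output equals the projected hidden state
print(c.shape)  # [4, 32]
```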
@@ -1436,6 +1466,7 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
):
super().__init__()
bidirectional_list = ["bidirectional", "bidirect"]
@@ -1455,28 +1486,40 @@ def __init__(
"bias_hh_attr": bias_hh_attr,
}

self.proj_size = proj_size
if proj_size > 0:
assert mode == 'LSTM'

if mode == "LSTM":
rnn_cls = LSTMCell
kwargs["proj_size"] = proj_size
elif mode == "GRU":
rnn_cls = GRUCell
elif mode == "RNN_RELU":
rnn_cls = SimpleRNNCell
kwargs["activation"] = 'relu'
elif mode == "RNN_TANH":
rnn_cls = SimpleRNNCell
kwargs["activation"] = 'tanh'
else:
rnn_cls = SimpleRNNCell
kwargs["activation"] = self.activation

in_size = proj_size or hidden_size
if direction in ["forward"]:
is_reverse = False
cell = rnn_cls(input_size, hidden_size, **kwargs)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
cell = rnn_cls(hidden_size, hidden_size, **kwargs)
for _ in range(1, num_layers):
cell = rnn_cls(in_size, hidden_size, **kwargs)
self.append(RNN(cell, is_reverse, time_major))
elif direction in bidirectional_list:
cell_fw = rnn_cls(input_size, hidden_size, **kwargs)
cell_bw = rnn_cls(input_size, hidden_size, **kwargs)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs)
for _ in range(1, num_layers):
cell_fw = rnn_cls(2 * in_size, hidden_size, **kwargs)
cell_bw = rnn_cls(2 * in_size, hidden_size, **kwargs)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
raise ValueError(
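
The stacking logic above makes every layer after the first consume the (possibly projected) hidden state, doubled for bidirectional networks. A small illustration of the resulting per-layer input sizes (`layer_input_sizes` is a hypothetical helper for this note, not part of Paddle's API):

```python
# Illustration only -- a hypothetical helper, not part of Paddle's API.
def layer_input_sizes(input_size, hidden_size, num_layers,
                      proj_size=0, bidirectional=False):
    in_size = proj_size or hidden_size  # later layers see the (projected) state
    scale = 2 if bidirectional else 1   # bidirectional outputs are concatenated
    return [input_size] + [scale * in_size] * (num_layers - 1)

assert layer_input_sizes(16, 32, 3) == [16, 32, 32]
assert layer_input_sizes(16, 32, 3, proj_size=8) == [16, 8, 8]
assert layer_input_sizes(16, 32, 3, proj_size=8, bidirectional=True) == [16, 16, 16]
```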
@@ -1652,21 +1695,18 @@ def forward(self, inputs, initial_states=None, sequence_length=None):
batch_index = 1 if self.time_major else 0
dtype = inputs.dtype
if initial_states is None:
state_shape = (
self.num_layers * self.num_directions,
-1,
self.hidden_size,
)

fill_shape = list(state_shape)
dims = ([self.proj_size or self.hidden_size], [self.hidden_size])
fill_shape = [self.num_layers * self.num_directions, -1]
if inputs.shape[batch_index] > 0:
fill_shape[1] = inputs.shape[batch_index]
else:
fill_shape[1] = paddle.shape(inputs)[batch_index].item()
initial_states = tuple(
[
paddle.full(shape=fill_shape, fill_value=0, dtype=dtype)
for _ in range(self.state_components)
paddle.full(
shape=fill_shape + dims[i], fill_value=0, dtype=dtype
)
for i in range(self.state_components)
]
)
else:
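
With `proj_size` set, the zero initial states built here are no longer uniform: h gets the projected width while c keeps `hidden_size`. A small sketch of the resulting shapes (the concrete numbers are assumptions for illustration):

```python
import paddle

# Sketch of the zero-state shapes built above; the concrete numbers
# (2 layers, unidirectional, batch 4, hidden 32, proj 8) are assumptions.
num_layers, num_directions, batch_size = 2, 1, 4
hidden_size, proj_size = 32, 8
fill_shape = [num_layers * num_directions, batch_size]
h0 = paddle.zeros(fill_shape + [proj_size or hidden_size])  # shape [2, 4, 8]
c0 = paddle.zeros(fill_shape + [hidden_size])               # shape [2, 4, 32]
```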
@@ -1834,6 +1874,7 @@ def __init__(
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
0, # proj_size
)


@@ -1864,6 +1905,12 @@ class LSTM(RNNBase):
y_{t} & = h_{t}
If `proj_size` is specified, the hidden state :math:`h_{t}` will be projected to dimension `proj_size`:
.. math::
h_{t} = h_{t}W_{proj\_size}
where :math:`\sigma` is the sigmoid function, and * is the elementwise
multiplication operator.
@@ -1891,6 +1938,9 @@ class LSTM(RNNBase):
`bias_ih` of each cells. Default: None.
bias_hh_attr (ParamAttr, optional): The parameter attribute for the
`bias_hh` of each cells. Default: None.
proj_size (int, optional): If specified, the output hidden state of each layer
will be projected to `proj_size`. `proj_size` must be smaller than `hidden_size`.
Default: 0.
name (str, optional): Name for the operation (optional, default is
None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1901,9 +1951,9 @@ class LSTM(RNNBase):
Returns:
- **outputs** (Tensor): the output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`, If `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.
- **final_states** (tuple): the final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. Note that `num_directions` is 2 if direction is "bidirectional" (the index of forward states are 0, 2, 4, 6... and the index of backward states are 1, 3, 5, 7...), else 1.
- **outputs** (Tensor). The output sequence. If `time_major` is True, the shape is `[time_steps, batch_size, num_directions * hidden_size]`; if `time_major` is False, the shape is `[batch_size, time_steps, num_directions * hidden_size]`. If `proj_size` is specified, the last dimension is `num_directions * proj_size` instead. Note that `num_directions` is 2 if direction is "bidirectional" else 1. `time_steps` means the length of the output sequence.
- **final_states** (tuple). The final state, a tuple of two tensors, h and c. The shape of each is `[num_layers * num_directions, batch_size, hidden_size]`. If `proj_size` is specified, the last dimension of h will be `proj_size`.
Note that `num_directions` is 2 if direction is "bidirectional" (the indices of forward states are 0, 2, 4, 6... and the indices of backward states are 1, 3, 5, 7...), else 1.
Variables:
- **weight_ih_l[k]**: the learnable input-hidden weights of the k-th layer. If `k = 0`, the shape is `[hidden_size, input_size]`. Otherwise, the shape is `[hidden_size, num_directions * hidden_size]`.
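
As a usage reference for the new parameter, a minimal end-to-end sketch, assuming a Paddle build that includes this change; the shapes follow the Returns section above:

```python
import paddle

# Minimal sketch, assuming a Paddle build that includes this change.
lstm = paddle.nn.LSTM(input_size=16, hidden_size=32, num_layers=2,
                      direction="forward", proj_size=8)
x = paddle.randn([4, 23, 16])   # [batch_size, time_steps, input_size]
y, (h, c) = lstm(x)
print(y.shape)  # [4, 23, 8] -- last axis is num_directions * proj_size
print(h.shape)  # [2, 4, 8]  -- final hidden states are projected
print(c.shape)  # [2, 4, 32] -- cell states keep hidden_size
```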
@@ -1946,6 +1996,7 @@ def __init__(
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
proj_size=0,
name=None,
):
super().__init__(
@@ -1960,6 +2011,7 @@ def __init__(
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
proj_size,
)


@@ -2079,4 +2131,5 @@ def __init__(
weight_hh_attr,
bias_ih_attr,
bias_hh_attr,
0, # proj_size
)
65 changes: 60 additions & 5 deletions test/deprecated/rnn/test_rnn_nets.py
@@ -30,23 +30,36 @@


class TestSimpleRNN(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
def __init__(
self, time_major=True, direction="forward", place="cpu", mode='RNN_TANH'
):
super().__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction in bidirectional_list else 1
self.place = place
self.mode = mode

def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using a wrong device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction
16,
32,
2,
time_major=self.time_major,
direction=self.direction,
nonlinearity=self.mode,
)
rnn2 = paddle.nn.SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction
16,
32,
2,
time_major=self.time_major,
direction=self.direction,
activation=self.mode[4:].lower(),
)
convert_params_for_net(rnn1, rnn2)

@@ -230,7 +243,9 @@ def test_with_initial_state(self):
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(2 * self.num_directions, 4, 32)
prev_h = np.random.randn(
2 * self.num_directions, 4, getattr(self, "proj_size", 32)
)
prev_c = np.random.randn(2 * self.num_directions, 4, 32)

y1, (h1, c1) = rnn1(x, (prev_h, prev_c))
@@ -292,6 +307,35 @@ def runTest(self):
self.test_predict()


class TestLSTMWithProjSize(TestLSTM):
def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using a wrong device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = LSTM(
16,
32,
2,
time_major=self.time_major,
direction=self.direction,
proj_size=8,
)
rnn2 = paddle.nn.LSTM(
16,
32,
2,
time_major=self.time_major,
direction=self.direction,
proj_size=8,
)
convert_params_for_net(rnn1, rnn2)

self.rnn1 = rnn1
self.rnn2 = rnn2
self.proj_size = 8


def predict_test_util(place, mode, stop_gradient=True):
place = paddle.set_device(place)
paddle.seed(123)
@@ -366,8 +410,19 @@ def load_tests(loader, tests, pattern):
for direction in ["forward", "bidirectional", "bidirect"]:
for time_major in [True, False]:
for device in devices:
for test_class in [TestSimpleRNN, TestLSTM, TestGRU]:
for test_class in [
TestSimpleRNN,
TestLSTM,
TestGRU,
TestLSTMWithProjSize,
]:
suite.addTest(test_class(time_major, direction, device))
if test_class == TestSimpleRNN:
suite.addTest(
test_class(
time_major, direction, device, mode="RNN_RELU"
)
)
return suite

