Skip to content

Commit 59e00df

Browse files
lostella authored and lanking520 committed
fixed symbols naming in RNNCell, LSTMCell, GRUCell (apache#12794)
* fixed symbols naming in RNNCell and LSTMCell
* fixed GRUCell as well
* added test
* fixed tests?
1 parent c105738 commit 59e00df

File tree

2 files changed

+65
-8
lines changed

2 files changed

+65
-8
lines changed

python/mxnet/gluon/rnn/rnn_cell.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
398398
h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
399399
num_hidden=self._hidden_size,
400400
name=prefix+'h2h')
401-
output = self._get_activation(F, i2h + h2h, self._activation,
401+
i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
402+
output = self._get_activation(F, i2h_plus_h2h, self._activation,
402403
name=prefix+'out')
403404

404405
return output, [output]
@@ -511,7 +512,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
511512
num_hidden=self._hidden_size*4, name=prefix+'i2h')
512513
h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
513514
num_hidden=self._hidden_size*4, name=prefix+'h2h')
514-
gates = i2h + h2h
515+
gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
515516
slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
516517
in_gate = self._get_activation(
517518
F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
@@ -521,9 +522,10 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
521522
F, slice_gates[2], self._activation, name=prefix+'c')
522523
out_gate = self._get_activation(
523524
F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
524-
next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
525+
next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
526+
F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
525527
name=prefix+'state')
526-
next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
528+
next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
527529
name=prefix+'out')
528530

529531
return next_h, [next_h, next_c]
@@ -635,15 +637,22 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
635637
h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
636638
name=prefix+'h2h_slice')
637639

638-
reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
640+
reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
639641
name=prefix+'r_act')
640-
update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
642+
update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
641643
name=prefix+'z_act')
642644

643-
next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
645+
next_h_tmp = F.Activation(F.elemwise_add(i2h,
646+
F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
647+
name=prefix+'plus2'),
648+
act_type="tanh",
644649
name=prefix+'h_act')
645650

646-
next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
651+
ones = F.ones_like(update_gate, name=prefix+"ones_like0")
652+
next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
653+
next_h_tmp,
654+
name=prefix+'mul1'),
655+
F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
647656
name=prefix+'out')
648657

649658
return next_h, [next_h]

tests/python/unittest/test_gluon_rnn.py

+48
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,54 @@ def test_rnn_cells():
379379
net.add(gluon.rnn.GRUCell(100, input_size=100))
380380
check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
381381

382+
383+
def test_rnn_cells_export_import():
    """Check that unrolled RNN, LSTM and GRU cells survive an export/import round trip.

    For each cell type, a hybridized block that unrolls the cell for two
    steps is run on a constant input, exported, re-imported through
    ``SymbolBlock.imports`` and run again on the same input. Both outputs
    must match.

    The exported files go into a private temporary directory that is removed
    afterwards, so the test leaves nothing behind in the working directory.
    """
    # Local imports keep this block self-contained; all three are stdlib.
    import os
    import shutil
    import tempfile

    class UnrolledCell(gluon.HybridBlock):
        """Hybrid block that unrolls a single recurrent cell over two steps."""

        def __init__(self, cell_cls):
            super(UnrolledCell, self).__init__()
            with self.name_scope():
                self.cell = cell_cls(hidden_size=1)

        def hybrid_forward(self, F, seq):
            # Only the merged outputs are compared; the final state is unused.
            outputs, _ = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
            return outputs

    cell_classes = (gluon.rnn.RNNCell, gluon.rnn.LSTMCell, gluon.rnn.GRUCell)

    export_dir = tempfile.mkdtemp()
    try:
        prefix = os.path.join(export_dir, "model")
        for cell_cls in cell_classes:
            hybrid = UnrolledCell(cell_cls)
            hybrid.initialize()
            hybrid.hybridize()
            data = mx.nd.ones(shape=(1, 2, 1))
            expected = hybrid(data)

            # The import below reads "<prefix>-symbol.json" and
            # "<prefix>-0000.params", the files produced for epoch 0.
            hybrid.export(path=prefix, epoch=0)
            imported = mx.gluon.SymbolBlock.imports(
                symbol_file=prefix + "-symbol.json",
                input_names=["data"],
                param_file=prefix + "-0000.params",
                ctx=mx.Context.default_ctx
            )
            actual = imported(data)
            assert_almost_equal(expected.asnumpy(), actual.asnumpy())
    finally:
        # Always clean up, even when an assertion or export fails.
        shutil.rmtree(export_dir)
428+
429+
382430
def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
383431
layer.collect_params().initialize()
384432
inputs.attach_grad()

0 commit comments

Comments
 (0)