@@ -398,7 +398,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
398
398
h2h = F .FullyConnected (data = states [0 ], weight = h2h_weight , bias = h2h_bias ,
399
399
num_hidden = self ._hidden_size ,
400
400
name = prefix + 'h2h' )
401
- output = self ._get_activation (F , i2h + h2h , self ._activation ,
401
+ i2h_plus_h2h = F .elemwise_add (i2h , h2h , name = prefix + 'plus0' )
402
+ output = self ._get_activation (F , i2h_plus_h2h , self ._activation ,
402
403
name = prefix + 'out' )
403
404
404
405
return output , [output ]
@@ -511,7 +512,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
511
512
num_hidden = self ._hidden_size * 4 , name = prefix + 'i2h' )
512
513
h2h = F .FullyConnected (data = states [0 ], weight = h2h_weight , bias = h2h_bias ,
513
514
num_hidden = self ._hidden_size * 4 , name = prefix + 'h2h' )
514
- gates = i2h + h2h
515
+ gates = F . elemwise_add ( i2h , h2h , name = prefix + 'plus0' )
515
516
slice_gates = F .SliceChannel (gates , num_outputs = 4 , name = prefix + 'slice' )
516
517
in_gate = self ._get_activation (
517
518
F , slice_gates [0 ], self ._recurrent_activation , name = prefix + 'i' )
@@ -521,9 +522,10 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
521
522
F , slice_gates [2 ], self ._activation , name = prefix + 'c' )
522
523
out_gate = self ._get_activation (
523
524
F , slice_gates [3 ], self ._recurrent_activation , name = prefix + 'o' )
524
- next_c = F ._internal ._plus (forget_gate * states [1 ], in_gate * in_transform ,
525
+ next_c = F ._internal ._plus (F .elemwise_mul (forget_gate , states [1 ], name = prefix + 'mul0' ),
526
+ F .elemwise_mul (in_gate , in_transform , name = prefix + 'mul1' ),
525
527
name = prefix + 'state' )
526
- next_h = F ._internal ._mul (out_gate , F .Activation (next_c , act_type = self ._activation ),
528
+ next_h = F ._internal ._mul (out_gate , F .Activation (next_c , act_type = self ._activation , name = prefix + 'activation0' ),
527
529
name = prefix + 'out' )
528
530
529
531
return next_h , [next_h , next_c ]
@@ -635,15 +637,22 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
635
637
h2h_r , h2h_z , h2h = F .SliceChannel (h2h , num_outputs = 3 ,
636
638
name = prefix + 'h2h_slice' )
637
639
638
- reset_gate = F .Activation (i2h_r + h2h_r , act_type = "sigmoid" ,
640
+ reset_gate = F .Activation (F . elemwise_add ( i2h_r , h2h_r , name = prefix + 'plus0' ) , act_type = "sigmoid" ,
639
641
name = prefix + 'r_act' )
640
- update_gate = F .Activation (i2h_z + h2h_z , act_type = "sigmoid" ,
642
+ update_gate = F .Activation (F . elemwise_add ( i2h_z , h2h_z , name = prefix + 'plus1' ) , act_type = "sigmoid" ,
641
643
name = prefix + 'z_act' )
642
644
643
- next_h_tmp = F .Activation (i2h + reset_gate * h2h , act_type = "tanh" ,
645
+ next_h_tmp = F .Activation (F .elemwise_add (i2h ,
646
+ F .elemwise_mul (reset_gate , h2h , name = prefix + 'mul0' ),
647
+ name = prefix + 'plus2' ),
648
+ act_type = "tanh" ,
644
649
name = prefix + 'h_act' )
645
650
646
- next_h = F ._internal ._plus ((1. - update_gate ) * next_h_tmp , update_gate * prev_state_h ,
651
+ ones = F .ones_like (update_gate , name = prefix + "ones_like0" )
652
+ next_h = F ._internal ._plus (F .elemwise_mul (F .elemwise_sub (ones , update_gate , name = prefix + 'minus0' ),
653
+ next_h_tmp ,
654
+ name = prefix + 'mul1' ),
655
+ F .elemwise_mul (update_gate , prev_state_h , name = prefix + 'mul20' ),
647
656
name = prefix + 'out' )
648
657
649
658
return next_h , [next_h ]
0 commit comments