diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py index 136fd3c895f6..8b2d882713ad 100644 --- a/keras/layers/recurrent.py +++ b/keras/layers/recurrent.py @@ -731,6 +731,8 @@ def from_config(cls, config, custom_objects=None): @property def trainable_weights(self): + if not self.trainable: + return [] if isinstance(self.cell, Layer): return self.cell.trainable_weights return [] @@ -738,6 +740,8 @@ def trainable_weights(self): @property def non_trainable_weights(self): if isinstance(self.cell, Layer): + if not self.trainable: + return self.cell.weights return self.cell.non_trainable_weights return [] diff --git a/tests/keras/layers/recurrent_test.py b/tests/keras/layers/recurrent_test.py index 20f0065f376c..f70dca3c878f 100644 --- a/tests/keras/layers/recurrent_test.py +++ b/tests/keras/layers/recurrent_test.py @@ -212,6 +212,23 @@ def test_regularizer(layer_class): assert len(layer.get_losses_for(x)) == 1 +@rnn_test +def test_trainability(layer_class): + layer = layer_class(units) + layer.build((None, None, embedding_dim)) + assert len(layer.weights) == 3 + assert len(layer.trainable_weights) == 3 + assert len(layer.non_trainable_weights) == 0 + layer.trainable = False + assert len(layer.weights) == 3 + assert len(layer.trainable_weights) == 0 + assert len(layer.non_trainable_weights) == 3 + layer.trainable = True + assert len(layer.weights) == 3 + assert len(layer.trainable_weights) == 3 + assert len(layer.non_trainable_weights) == 0 + + @keras_test def test_masking_layer(): ''' This test based on a previously failing issue here: