Skip to content

Commit

Permalink
Fix RNN layers dynamic trainable attr
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 4, 2017
1 parent cc08f0f commit 3292aa5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,13 +731,17 @@ def from_config(cls, config, custom_objects=None):

@property
def trainable_weights(self):
if not self.trainable:
return []
if isinstance(self.cell, Layer):
return self.cell.trainable_weights
return []

@property
def non_trainable_weights(self):
if isinstance(self.cell, Layer):
if not self.trainable:
return self.cell.weights
return self.cell.non_trainable_weights
return []

Expand Down
17 changes: 17 additions & 0 deletions tests/keras/layers/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,23 @@ def test_regularizer(layer_class):
assert len(layer.get_losses_for(x)) == 1


@rnn_test
def test_trainability(layer_class):
layer = layer_class(units)
layer.build((None, None, embedding_dim))
assert len(layer.weights) == 3
assert len(layer.trainable_weights) == 3
assert len(layer.non_trainable_weights) == 0
layer.trainable = False
assert len(layer.weights) == 3
assert len(layer.trainable_weights) == 0
assert len(layer.non_trainable_weights) == 3
layer.trainable = True
assert len(layer.weights) == 3
assert len(layer.trainable_weights) == 3
assert len(layer.non_trainable_weights) == 0


@keras_test
def test_masking_layer():
''' This test based on a previously failing issue here:
Expand Down

0 comments on commit 3292aa5

Please sign in to comment.