Skip to content

Commit ff62eb2

Browse files
authored
Refactor regularizers and add add_weight method. (#4703)
* Refactor regularizers, introduce layer.add_weight
* Fix BN add_update syntax
* Fix eigenvalue regularizer
* Style fixes.
1 parent 2b33675 commit ff62eb2

15 files changed

+524
-539
lines changed

keras/backend/theano_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def to_dense(tensor):
5757

5858

5959
def variable(value, dtype=_FLOATX, name=None):
60-
'''Instantiate a tensor variable.
60+
'''Instantiates a variable.
6161
'''
6262
if hasattr(value, 'tocoo'):
6363
_assert_sparse_module()

keras/engine/topology.py

+152-28
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from six.moves import zip
1414

1515
from .. import backend as K
16+
from .. import initializations
1617
from ..utils.io_utils import ask_to_proceed_with_overwrite
1718
from ..utils.generic_utils import func_dump, func_load
1819

@@ -28,6 +29,11 @@ def to_list(x):
2829
return [x]
2930

3031

32+
def object_list_uid(object_list):
33+
object_list = to_list(object_list)
34+
return ', '.join([str(abs(id(x))) for x in object_list])
35+
36+
3137
class InputSpec(object):
3238
'''This specifies the ndim, dtype and shape of every input to a layer.
3339
Every layer should expose (if appropriate) an `input_spec` attribute:
@@ -239,7 +245,6 @@ class Layer(object):
239245
non_trainable_weights: List of variables.
240246
weights: The concatenation of the lists trainable_weights and
241247
non_trainable_weights (in this order).
242-
regularizers: List of regularizers.
243248
constraints: Dict mapping weights to constraints.
244249
245250
# Methods
@@ -294,8 +299,8 @@ def __init__(self, **kwargs):
294299
self.trainable_weights = []
295300
if not hasattr(self, 'non_trainable_weights'):
296301
self.non_trainable_weights = []
297-
if not hasattr(self, 'regularizers'):
298-
self.regularizers = []
302+
if not hasattr(self, 'losses'):
303+
self.losses = []
299304
if not hasattr(self, 'constraints'):
300305
self.constraints = {} # dict {tensor: constraint instance}
301306
self.built = False
@@ -354,6 +359,19 @@ def non_trainable_weights(self):
354359
def non_trainable_weights(self, weights):
355360
self._non_trainable_weights = weights
356361

362+
@property
363+
def regularizers(self):
364+
warnings.warn('The `regularizers` property of layers/models is deprecated. '
365+
'Regularization losses are now managed via the `losses` '
366+
'layer/model property.')
367+
return []
368+
369+
@regularizers.setter
370+
def regularizers(self, _):
371+
warnings.warn('The `regularizers` property of layers/models is deprecated. '
372+
'Regularization losses are now managed via the `losses` '
373+
'layer/model property.')
374+
357375
def create_input_layer(self, batch_input_shape,
358376
input_dtype=None, name=None):
359377
if not name:
@@ -373,6 +391,32 @@ def create_input_layer(self, batch_input_shape,
373391
# to the input layer we just created.
374392
self(x)
375393

394+
def add_weight(self, shape, initializer, name=None,
395+
trainable=True,
396+
regularizer=None,
397+
constraint=None):
398+
'''Adds a weight variable to the layer.
399+
400+
# Arguments:
401+
shape: The shape tuple of the weight.
402+
initializer: An Initializer instance (callable).
403+
trainable: A boolean, whether the weight should
404+
be trained via backprop or not (assuming
405+
that the layer itself is also trainable).
406+
regularizer: An optional Regularizer instance.
            constraint: An optional Constraint instance.
            name: An optional name string for the weight variable.
407+
'''
408+
initializer = initializations.get(initializer)
409+
weight = initializer(shape, name=name)
410+
if regularizer is not None:
411+
self.add_loss(regularizer(weight))
412+
if constraint is not None:
413+
self.constraints[weight] = constraint
414+
if trainable:
415+
self.trainable_weights.append(weight)
416+
else:
417+
self.non_trainable_weights.append(weight)
418+
return weight
419+
376420
def assert_input_compatibility(self, input):
377421
'''This checks that the tensor(s) `input`
378422
verify the input assumptions of the layer
@@ -519,15 +563,21 @@ def __call__(self, x, mask=None):
519563
self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
520564
# Outputs were already computed when calling self.add_inbound_node.
521565
outputs = self.inbound_nodes[-1].output_tensors
522-
# If single output tensor: return it,
523-
# else return a list (at least 2 elements).
524-
if len(outputs) == 1:
525-
return outputs[0]
526-
else:
527-
return outputs
528566
else:
529567
# This case appears if the input was not a Keras tensor.
530-
return self.call(x, mask)
568+
outputs = to_list(self.call(x, mask))
569+
570+
# Apply activity regularizer if any:
571+
if hasattr(self, 'activity_regularizer') and self.activity_regularizer is not None:
572+
regularization_losses = [self.activity_regularizer(x) for x in outputs]
573+
self.add_loss(regularization_losses, input_tensors)
574+
575+
# If single output tensor: return it,
576+
# else return a list (at least 2 elements).
577+
if len(outputs) == 1:
578+
return outputs[0]
579+
else:
580+
return outputs
531581

532582
def add_inbound_node(self, inbound_layers,
533583
node_indices=None, tensor_indices=None):
@@ -806,33 +856,78 @@ def output_shape(self):
806856
'ill-defined for the layer. ' +
807857
'Use `get_output_shape_at(node_index)` instead.')
808858

809-
def add_updates(self, updates, inputs):
859+
def add_loss(self, losses, inputs=None):
860+
if losses is None:
861+
return
862+
# Update self.losses
863+
losses = to_list(losses)
864+
if not hasattr(self, 'losses'):
865+
self.losses = []
866+
try:
867+
self.losses += losses
868+
except AttributeError:
869+
# In case self.losses isn't settable
870+
# (i.e. it's a getter method).
871+
# In that case the `losses` property is
872+
# auto-computed and shouldn't be set.
873+
pass
874+
# Update self._per_input_losses
875+
if not hasattr(self, '_per_input_losses'):
876+
self._per_input_losses = {}
877+
if inputs is not None:
878+
inputs_hash = object_list_uid(inputs)
879+
else:
880+
# Losses indexed by None are unconditional
881+
# rather than input-dependent
882+
inputs_hash = None
883+
if inputs_hash not in self._per_input_losses:
884+
self._per_input_losses[inputs_hash] = []
885+
self._per_input_losses[inputs_hash] += losses
886+
887+
def add_update(self, updates, inputs=None):
888+
if updates is None:
889+
return
810890
# Update self.updates
891+
updates = to_list(updates)
811892
if not hasattr(self, 'updates'):
812893
self.updates = []
813894
try:
814895
self.updates += updates
815896
except AttributeError:
897+
# In case self.updates isn't settable
898+
# (i.e. it's a getter method).
899+
# In that case the `updates` property is
900+
# auto-computed and shouldn't be set.
816901
pass
817902
# Update self._per_input_updates
818903
if not hasattr(self, '_per_input_updates'):
819904
self._per_input_updates = {}
820-
inputs = to_list(inputs)
821-
updates = to_list(updates)
822-
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
905+
if inputs is not None:
906+
inputs_hash = object_list_uid(inputs)
907+
else:
908+
# Updates indexed by None are unconditional
909+
# rather than input-dependent
910+
inputs_hash = None
823911
if inputs_hash not in self._per_input_updates:
824912
self._per_input_updates[inputs_hash] = []
825913
self._per_input_updates[inputs_hash] += updates
826914

827915
def get_updates_for(self, inputs):
828916
if not hasattr(self, '_per_input_updates'):
829917
return []
830-
inputs = to_list(inputs)
831-
inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
918+
inputs_hash = object_list_uid(inputs)
832919
if inputs_hash in self._per_input_updates:
833920
return self._per_input_updates[inputs_hash]
834921
return []
835922

923+
def get_losses_for(self, inputs):
924+
if not hasattr(self, '_per_input_losses'):
925+
return []
926+
inputs_hash = object_list_uid(inputs)
927+
if inputs_hash in self._per_input_losses:
928+
return self._per_input_losses[inputs_hash]
929+
return []
930+
836931
@property
837932
def weights(self):
838933
return self.trainable_weights + self.non_trainable_weights
@@ -950,7 +1045,6 @@ def __init__(self, input_shape=None, batch_input_shape=None,
9501045

9511046
self.trainable_weights = []
9521047
self.non_trainable_weights = []
953-
self.regularizers = []
9541048
self.constraints = {}
9551049

9561050
self.sparse = sparse
@@ -1151,7 +1245,6 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
11511245
self.inbound_nodes = []
11521246
self.outbound_nodes = []
11531247
self.constraints = {}
1154-
self.regularizers = []
11551248
self.trainable_weights = []
11561249
self.non_trainable_weights = []
11571250
self.supports_masking = True
@@ -1587,7 +1680,6 @@ class Container(Layer):
15871680
supports_masking (boolean)
15881681
trainable_weights (list of variables)
15891682
non_trainable_weights (list of variables)
1590-
regularizers (list of regularizers)
15911683
constraints (list of tuples (weight, constraint))
15921684
15931685
# Methods
@@ -1901,7 +1993,6 @@ def build_map_of_graph(tensor, seen_nodes=set(), depth=0,
19011993
self.supports_masking = False
19021994
# The following are implemented as property functions:
19031995
# self.constraints
1904-
# self.regularizers
19051996
# self.trainable_weights
19061997
# self.non_trainable_weights
19071998
# self.input_spec
@@ -1946,14 +2037,38 @@ def updates(self):
19462037
if len(layer.inbound_nodes) == 1:
19472038
updates += layer.updates
19482039
else:
2040+
# Collect updates that are dependent on inputs
2041+
# that are part of the model.
19492042
for node_index, node in enumerate(layer.inbound_nodes):
19502043
node_key = layer.name + '_ib-' + str(node_index)
19512044
if node_key in self.container_nodes:
19522045
# The model owns this layer node.
19532046
inputs = node.input_tensors
19542047
updates += layer.get_updates_for(inputs)
2048+
# Collect unconditional updates.
2049+
updates += layer.get_updates_for(None)
19552050
return updates
19562051

2052+
@property
2053+
def losses(self):
2054+
losses = []
2055+
for layer in self.layers:
2056+
if hasattr(layer, 'losses'):
2057+
if len(layer.inbound_nodes) == 1:
2058+
losses += layer.losses
2059+
else:
2060+
# Collect losses that are dependent on inputs
2061+
# that are part of the model.
2062+
for node_index, node in enumerate(layer.inbound_nodes):
2063+
node_key = layer.name + '_ib-' + str(node_index)
2064+
if node_key in self.container_nodes:
2065+
# The model owns this layer node.
2066+
inputs = node.input_tensors
2067+
losses += layer.get_losses_for(inputs)
2068+
# Collect unconditional losses.
2069+
losses += layer.get_losses_for(None)
2070+
return losses
2071+
19572072
@property
19582073
def stateful(self):
19592074
return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
@@ -1990,10 +2105,13 @@ def constraints(self):
19902105

19912106
@property
19922107
def regularizers(self):
1993-
regs = []
1994-
for layer in self.layers:
1995-
regs += layer.regularizers
1996-
return regs
2108+
warnings.warn('The `regularizers` attribute of layers/models '
2109+
'is deprecated. '
2110+
'Regularization losses are now managed via the `losses` '
2111+
'layer/model property.\n'
2112+
'The `regularizers` attribute will be removed '
2113+
'after 06/2017.')
2114+
return []
19972115

19982116
@property
19992117
def trainable_weights(self):
@@ -2061,8 +2179,7 @@ def uses_learning_phase(self):
20612179
'''True if any layer in the graph uses it.
20622180
'''
20632181
layers_learning_phase = any([layer.uses_learning_phase for layer in self.layers])
2064-
regs_learning_phase = any([reg.uses_learning_phase for reg in self.regularizers])
2065-
return layers_learning_phase or regs_learning_phase
2182+
return layers_learning_phase
20662183

20672184
def call(self, input, mask=None):
20682185
'''`call` just reapplies all ops in the graph to the new inputs
@@ -2239,9 +2356,16 @@ def run_internal_graph(self, inputs, masks=None):
22392356
output_tensors = to_list(layer.call(computed_tensors, computed_masks))
22402357
output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
22412358

2242-
# update model updates
2359+
# Update model updates and losses:
22432360
layer_inputs = [x[0] for x in computed_data]
2244-
self.add_updates(layer.get_updates_for(layer_inputs), inputs)
2361+
# Keep track of updates that depend on the inputs (e.g. BN updates).
2362+
self.add_update(layer.get_updates_for(layer_inputs), inputs)
2363+
# Keep track of unconditional updates (e.g. a counter).
2364+
self.add_update(layer.get_updates_for(None), None)
2365+
# Keep track of losses that depend on the inputs (e.g. activity regularizers).
2366+
self.add_loss(layer.get_losses_for(layer_inputs), inputs)
2367+
# Keep track of unconditional losses (e.g. weight regularizers).
2368+
self.add_loss(layer.get_losses_for(None), None)
22452369

22462370
# Update _keras_shape.
22472371
if all([hasattr(x, '_keras_shape') for x in computed_tensors]):

keras/engine/training.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,10 @@ def compile(self, optimizer, loss, metrics=[], loss_weights=None,
611611
else:
612612
total_loss += loss_weight * output_loss
613613

614-
# add regularization penalties to the loss
615-
for r in self.regularizers:
616-
total_loss = r(total_loss)
614+
# add regularization penalties
615+
# and other layer-specific losses
616+
for loss_tensor in self.losses:
617+
total_loss += loss_tensor
617618

618619
# list of same size as output_names.
619620
# contains tuples (metrics for output, names of metrics)

0 commit comments

Comments
 (0)