13
13
from six .moves import zip
14
14
15
15
from .. import backend as K
16
+ from .. import initializations
16
17
from ..utils .io_utils import ask_to_proceed_with_overwrite
17
18
from ..utils .generic_utils import func_dump , func_load
18
19
@@ -28,6 +29,11 @@ def to_list(x):
28
29
return [x ]
29
30
30
31
32
def object_list_uid(object_list):
    """Build a deterministic string key identifying a list of objects.

    Each object contributes `abs(id(obj))`, so the key is stable for the
    lifetime of the objects and can be used as a dictionary key (e.g. to
    index per-input updates/losses).
    """
    uids = (str(abs(id(obj))) for obj in to_list(object_list))
    return ', '.join(uids)
35
+
36
+
31
37
class InputSpec (object ):
32
38
'''This specifies the ndim, dtype and shape of every input to a layer.
33
39
Every layer should expose (if appropriate) an `input_spec` attribute:
@@ -239,7 +245,6 @@ class Layer(object):
239
245
non_trainable_weights: List of variables.
240
246
weights: The concatenation of the lists trainable_weights and
241
247
non_trainable_weights (in this order).
242
- regularizers: List of regularizers.
243
248
constraints: Dict mapping weights to constraints.
244
249
245
250
# Methods
@@ -294,8 +299,8 @@ def __init__(self, **kwargs):
294
299
self .trainable_weights = []
295
300
if not hasattr (self , 'non_trainable_weights' ):
296
301
self .non_trainable_weights = []
297
- if not hasattr (self , 'regularizers ' ):
298
- self .regularizers = []
302
+ if not hasattr (self , 'losses ' ):
303
+ self .losses = []
299
304
if not hasattr (self , 'constraints' ):
300
305
self .constraints = {} # dict {tensor: constraint instance}
301
306
self .built = False
@@ -354,6 +359,19 @@ def non_trainable_weights(self):
354
359
def non_trainable_weights (self , weights ):
355
360
self ._non_trainable_weights = weights
356
361
362
@property
def regularizers(self):
    """Deprecated accessor kept for backward compatibility.

    Always returns an empty list; regularization penalties now live in
    the `losses` property instead.
    """
    warnings.warn(
        'The `regularizers` property of layers/models is deprecated. '
        'Regularization losses are now managed via the `losses` '
        'layer/model property.')
    return []

@regularizers.setter
def regularizers(self, _):
    """Deprecated no-op setter: assigned values are discarded."""
    warnings.warn(
        'The `regularizers` property of layers/models is deprecated. '
        'Regularization losses are now managed via the `losses` '
        'layer/model property.')
374
+
357
375
def create_input_layer (self , batch_input_shape ,
358
376
input_dtype = None , name = None ):
359
377
if not name :
@@ -373,6 +391,32 @@ def create_input_layer(self, batch_input_shape,
373
391
# to the input layer we just created.
374
392
self (x )
375
393
394
def add_weight(self, shape, initializer, name=None,
               trainable=True,
               regularizer=None,
               constraint=None):
    '''Adds a weight variable to the layer.

    # Arguments:
        shape: The shape tuple of the weight.
        initializer: An Initializer instance (callable),
            or the string identifier of one
            (resolved through `initializations.get`).
        name: Optional name string for the weight variable.
        trainable: A boolean, whether the weight should
            be trained via backprop or not (assuming
            that the layer itself is also trainable).
        regularizer: An optional Regularizer instance; its
            penalty on the weight is registered via `add_loss`.
        constraint: An optional Constraint instance, recorded
            in `self.constraints` keyed by the weight.

    # Returns
        The created weight variable.
    '''
    initializer = initializations.get(initializer)
    weight = initializer(shape, name=name)
    if regularizer is not None:
        # Weight regularizers become unconditional layer losses.
        self.add_loss(regularizer(weight))
    if constraint is not None:
        self.constraints[weight] = constraint
    if trainable:
        self.trainable_weights.append(weight)
    else:
        self.non_trainable_weights.append(weight)
    return weight
419
+
376
420
def assert_input_compatibility (self , input ):
377
421
'''This checks that the tensor(s) `input`
378
422
verify the input assumptions of the layer
@@ -519,15 +563,21 @@ def __call__(self, x, mask=None):
519
563
self .add_inbound_node (inbound_layers , node_indices , tensor_indices )
520
564
# Outputs were already computed when calling self.add_inbound_node.
521
565
outputs = self .inbound_nodes [- 1 ].output_tensors
522
- # If single output tensor: return it,
523
- # else return a list (at least 2 elements).
524
- if len (outputs ) == 1 :
525
- return outputs [0 ]
526
- else :
527
- return outputs
528
566
else :
529
567
# This case appears if the input was not a Keras tensor.
530
- return self .call (x , mask )
568
+ outputs = to_list (self .call (x , mask ))
569
+
570
+ # Apply activity regularizer if any:
571
+ if hasattr (self , 'activity_regularizer' ) and self .activity_regularizer is not None :
572
+ regularization_losses = [self .activity_regularizer (x ) for x in outputs ]
573
+ self .add_loss (regularization_losses , input_tensors )
574
+
575
+ # If single output tensor: return it,
576
+ # else return a list (at least 2 elements).
577
+ if len (outputs ) == 1 :
578
+ return outputs [0 ]
579
+ else :
580
+ return outputs
531
581
532
582
def add_inbound_node (self , inbound_layers ,
533
583
node_indices = None , tensor_indices = None ):
@@ -806,33 +856,78 @@ def output_shape(self):
806
856
'ill-defined for the layer. ' +
807
857
'Use `get_output_shape_at(node_index)` instead.' )
808
858
809
- def add_updates (self , updates , inputs ):
859
def add_loss(self, losses, inputs=None):
    '''Registers one or more loss tensors on the layer.

    # Arguments
        losses: A loss tensor, or list of loss tensors.
            `None` makes the whole call a no-op.
        inputs: Optional input tensor(s) the losses are
            conditioned on. If `None`, the losses are
            treated as unconditional.
    '''
    if losses is None:
        return
    # Update self.losses
    losses = to_list(losses)
    if not hasattr(self, 'losses'):
        self.losses = []
    try:
        self.losses += losses
    except AttributeError:
        # In case self.losses isn't settable
        # (i.e. it's a getter method).
        # In that case the `losses` property is
        # auto-computed and shouldn't be set.
        pass
    # Update self._per_input_losses
    if not hasattr(self, '_per_input_losses'):
        self._per_input_losses = {}
    if inputs is not None:
        inputs_hash = object_list_uid(inputs)
    else:
        # Losses indexed by None are unconditional
        # rather than input-dependent.
        inputs_hash = None
    if inputs_hash not in self._per_input_losses:
        self._per_input_losses[inputs_hash] = []
    self._per_input_losses[inputs_hash] += losses
886
+
887
def add_update(self, updates, inputs=None):
    '''Registers one or more update ops on the layer.

    # Arguments
        updates: An update op, or list of update ops.
            `None` makes the whole call a no-op.
        inputs: Optional input tensor(s) the updates are
            conditioned on. If `None`, the updates are
            treated as unconditional.
    '''
    if updates is None:
        return
    updates = to_list(updates)
    # Append to self.updates when that attribute is writable;
    # some subclasses expose `updates` as a read-only
    # (auto-computed) property, in which case we skip it.
    if not hasattr(self, 'updates'):
        self.updates = []
    try:
        self.updates += updates
    except AttributeError:
        pass
    # Record the updates keyed by the identity of their inputs;
    # the key None marks unconditional updates.
    if not hasattr(self, '_per_input_updates'):
        self._per_input_updates = {}
    key = object_list_uid(inputs) if inputs is not None else None
    self._per_input_updates.setdefault(key, []).extend(updates)
826
914
827
915
def get_updates_for(self, inputs):
    '''Returns the update ops conditioned on `inputs`.

    # Arguments
        inputs: Input tensor(s), or `None` to retrieve the
            unconditional updates.

    # Returns
        A list of update ops (empty if none were recorded).
    '''
    if not hasattr(self, '_per_input_updates'):
        return []
    if inputs is not None:
        inputs_hash = object_list_uid(inputs)
    else:
        # Unconditional updates are stored under the key None
        # (see `add_update`). Hashing None through
        # `object_list_uid` would produce a string key that can
        # never match, so it must be special-cased here.
        inputs_hash = None
    if inputs_hash in self._per_input_updates:
        return self._per_input_updates[inputs_hash]
    return []
835
922
923
def get_losses_for(self, inputs):
    '''Returns the loss tensors conditioned on `inputs`.

    # Arguments
        inputs: Input tensor(s), or `None` to retrieve the
            unconditional losses.

    # Returns
        A list of loss tensors (empty if none were recorded).
    '''
    if not hasattr(self, '_per_input_losses'):
        return []
    if inputs is not None:
        inputs_hash = object_list_uid(inputs)
    else:
        # Unconditional losses are stored under the key None
        # (see `add_loss`). Hashing None through
        # `object_list_uid` would produce a string key that can
        # never match, so it must be special-cased here.
        inputs_hash = None
    if inputs_hash in self._per_input_losses:
        return self._per_input_losses[inputs_hash]
    return []
930
+
836
931
@property
837
932
def weights (self ):
838
933
return self .trainable_weights + self .non_trainable_weights
@@ -950,7 +1045,6 @@ def __init__(self, input_shape=None, batch_input_shape=None,
950
1045
951
1046
self .trainable_weights = []
952
1047
self .non_trainable_weights = []
953
- self .regularizers = []
954
1048
self .constraints = {}
955
1049
956
1050
self .sparse = sparse
@@ -1151,7 +1245,6 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
1151
1245
self .inbound_nodes = []
1152
1246
self .outbound_nodes = []
1153
1247
self .constraints = {}
1154
- self .regularizers = []
1155
1248
self .trainable_weights = []
1156
1249
self .non_trainable_weights = []
1157
1250
self .supports_masking = True
@@ -1587,7 +1680,6 @@ class Container(Layer):
1587
1680
supports_masking (boolean)
1588
1681
trainable_weights (list of variables)
1589
1682
non_trainable_weights (list of variables)
1590
- regularizers (list of regularizers)
1591
1683
constraints (list of tuples (weight, constraint))
1592
1684
1593
1685
# Methods
@@ -1901,7 +1993,6 @@ def build_map_of_graph(tensor, seen_nodes=set(), depth=0,
1901
1993
self .supports_masking = False
1902
1994
# The following are implemented as property functions:
1903
1995
# self.constraints
1904
- # self.regularizers
1905
1996
# self.trainable_weights
1906
1997
# self.non_trainable_weights
1907
1998
# self.input_spec
@@ -1946,14 +2037,38 @@ def updates(self):
1946
2037
if len (layer .inbound_nodes ) == 1 :
1947
2038
updates += layer .updates
1948
2039
else :
2040
+ # Collect updates that are dependent on inputs
2041
+ # that are part of the model.
1949
2042
for node_index , node in enumerate (layer .inbound_nodes ):
1950
2043
node_key = layer .name + '_ib-' + str (node_index )
1951
2044
if node_key in self .container_nodes :
1952
2045
# The model owns this layer node.
1953
2046
inputs = node .input_tensors
1954
2047
updates += layer .get_updates_for (inputs )
2048
+ # Collect unconditional updates.
2049
+ updates += layer .get_updates_for (None )
1955
2050
return updates
1956
2051
2052
@property
def losses(self):
    '''Aggregates the losses of every layer in the container.

    Layers used exactly once contribute all of their losses. For shared
    layers (multiple inbound nodes), only losses conditioned on inputs
    belonging to this model are kept, plus the unconditional ones.
    '''
    losses = []
    for layer in self.layers:
        if not hasattr(layer, 'losses'):
            continue
        nodes = layer.inbound_nodes
        if len(nodes) == 1:
            losses += layer.losses
            continue
        # Shared layer: collect losses that are dependent on inputs
        # that are part of the model.
        for node_index, node in enumerate(nodes):
            node_key = layer.name + '_ib-' + str(node_index)
            if node_key in self.container_nodes:
                # The model owns this layer node.
                losses += layer.get_losses_for(node.input_tensors)
        # Collect unconditional losses (e.g. weight regularizers).
        losses += layer.get_losses_for(None)
    return losses
2071
+
1957
2072
@property
1958
2073
def stateful (self ):
1959
2074
return any ([(hasattr (layer , 'stateful' ) and layer .stateful ) for layer in self .layers ])
@@ -1990,10 +2105,13 @@ def constraints(self):
1990
2105
1991
2106
@property
def regularizers(self):
    '''Deprecated container accessor kept for backward compatibility.

    Always returns an empty list; use the `losses` property instead.
    '''
    warnings.warn(
        'The `regularizers` attribute of layers/models '
        'is deprecated. '
        'Regularization losses are now managed via the `losses` '
        'layer/model property.\n'
        'The `regularizers` attribute will be removed '
        'after 06/2017.')
    return []
1997
2115
1998
2116
@property
1999
2117
def trainable_weights (self ):
@@ -2061,8 +2179,7 @@ def uses_learning_phase(self):
2061
2179
'''True if any layer in the graph uses it.
2062
2180
'''
2063
2181
layers_learning_phase = any ([layer .uses_learning_phase for layer in self .layers ])
2064
- regs_learning_phase = any ([reg .uses_learning_phase for reg in self .regularizers ])
2065
- return layers_learning_phase or regs_learning_phase
2182
+ return layers_learning_phase
2066
2183
2067
2184
def call (self , input , mask = None ):
2068
2185
'''`call` just reapplies all ops in the graph to the new inputs
@@ -2239,9 +2356,16 @@ def run_internal_graph(self, inputs, masks=None):
2239
2356
output_tensors = to_list (layer .call (computed_tensors , computed_masks ))
2240
2357
output_masks = to_list (layer .compute_mask (computed_tensors , computed_masks ))
2241
2358
2242
- # update model updates
2359
+ # Update model updates and losses:
2243
2360
layer_inputs = [x [0 ] for x in computed_data ]
2244
- self .add_updates (layer .get_updates_for (layer_inputs ), inputs )
2361
+ # Keep track of updates that depend on the inputs (e.g. BN updates).
2362
+ self .add_update (layer .get_updates_for (layer_inputs ), inputs )
2363
+ # Keep track of unconditional updates (e.g. a counter).
2364
+ self .add_update (layer .get_updates_for (None ), None )
2365
+ # Keep track of losses that depend on the inputs (e.g. activity regularizers).
2366
+ self .add_loss (layer .get_losses_for (layer_inputs ), inputs )
2367
+ # Keep track of unconditional losses (e.g. weight regularizers).
2368
+ self .add_loss (layer .get_losses_for (None ), None )
2245
2369
2246
2370
# Update _keras_shape.
2247
2371
if all ([hasattr (x , '_keras_shape' ) for x in computed_tensors ]):
0 commit comments