diff --git a/tensorflow_addons/optimizers/__init__.py b/tensorflow_addons/optimizers/__init__.py
index b8bc0109da..3cf79856c5 100644
--- a/tensorflow_addons/optimizers/__init__.py
+++ b/tensorflow_addons/optimizers/__init__.py
@@ -32,6 +32,7 @@
 from tensorflow_addons.optimizers.lamb import LAMB
 from tensorflow_addons.optimizers.lazy_adam import LazyAdam
 from tensorflow_addons.optimizers.lookahead import Lookahead
+from tensorflow_addons.optimizers.gradient_accumulator import GradientAccumulator
 from tensorflow_addons.optimizers.moving_average import MovingAverage
 from tensorflow_addons.optimizers.novograd import NovoGrad
 from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad
diff --git a/tensorflow_addons/optimizers/gradient_accumulator.py b/tensorflow_addons/optimizers/gradient_accumulator.py
new file mode 100644
index 0000000000..57051f8e9e
--- /dev/null
+++ b/tensorflow_addons/optimizers/gradient_accumulator.py
@@ -0,0 +1,231 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from tensorflow_addons.utils import types
+from typeguard import typechecked
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class GradientAccumulator(tf.keras.optimizers.Optimizer):
+    """Optimizer wrapper for gradient accumulation."""
+
+    @typechecked
+    def __init__(
+        self,
+        inner_optimizer: types.Optimizer,
+        accum_steps: types.TensorLike = 4,
+        reduction: str = "SUM",
+        name: str = "GradientAccumulator",
+        **kwargs,
+    ):
+        r"""Construct a new GradientAccumulator optimizer.
+
+        Args:
+            inner_optimizer: str or `tf.keras.optimizers.Optimizer` that will be
+                used to compute and apply gradients.
+            accum_steps: int > 0. Number of steps over which gradients are
+                accumulated before they are applied.
+            reduction: str. Reduction method, one of ['SUM', 'MEAN'].
+            name: Optional name for the operations created when applying
+                gradients. Defaults to "GradientAccumulator".
+            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
+                norm; `clipvalue` is clip gradients by value; `decay` is
+                included for backward compatibility to allow time inverse
+                decay of learning rate. `lr` is included for backward
+                compatibility; it is recommended to use `learning_rate` instead.
+ """ + super().__init__(name, **kwargs) + self._optimizer = tf.keras.optimizers.get(inner_optimizer) + self._step = None + self._accum_steps = accum_steps + self._reduction = reduction + + def _accum_grad(grads_and_vars): + new_grads_and_vars = [] + for grad, var in grads_and_vars: + handle = self.get_slot(var, "ga") + + if isinstance(grad, tf.IndexedSlices): + handle.scatter_add(grad) + + def _get_grad(): + new_grad = handle.read_value() + if self._reduction == "MEAN": + new_grad /= tf.cast(self._accum_steps, new_grad.dtype) + indices = tf.squeeze( + tf.where( + tf.reduce_sum( + new_grad, axis=list(range(len(new_grad.shape))[1:]) + ) + != 0 + ), + axis=-1, + ) + + values = tf.gather(new_grad, indices) + dense_shape = tf.constant(new_grad.shape.as_list()) + handle.assign( + tf.zeros_like(handle), + use_locking=self._use_locking, + read_value=False, + ) + return values, tf.cast(indices, grad.indices.dtype), dense_shape + + values, indices, dense_shape = tf.cond( + self.step % self._accum_steps == 0, + _get_grad, + lambda: ( + tf.zeros_like(grad.values), + grad.indices, + grad.dense_shape, + ), + ) + new_grad = tf.IndexedSlices(values, indices, dense_shape) + new_grads_and_vars.append((new_grad, var)) + else: + handle.assign_add( + grad, use_locking=self._use_locking, read_value=False + ) + + def _get_grad(): + new_grad = handle.read_value() + if self._reduction == "MEAN": + new_grad /= tf.cast(self._accum_steps, new_grad.dtype) + handle.assign( + tf.zeros_like(handle), + use_locking=self._use_locking, + read_value=False, + ) + return new_grad + + new_grad = tf.cond( + self.step % self._accum_steps == 0, + _get_grad, + lambda: tf.zeros_like(grad), + ) + new_grads_and_vars.append((new_grad, var)) + return new_grads_and_vars + + self.gradient_transformers.append(_accum_grad) + self._iterations = self._optimizer.iterations + + def _create_slots(self, var_list): + self._optimizer._create_slots(var_list=var_list) + for var in var_list: + self.add_slot(var, "ga") + + def _resource_apply_dense(self, grad, handle, apply_state): + if "apply_state" in self._optimizer._dense_apply_args: + return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state) + else: + return self.inner_optimizer._resource_apply_dense(grad, handle) + + def _resource_apply_sparse(self, grad, handle, indices, apply_state): + if "apply_state" in self._optimizer._sparse_apply_args: + return self.inner_optimizer._resource_apply_sparse( + grad, handle, indices, apply_state=apply_state + ) + else: + return self.inner_optimizer._resource_apply_sparse(grad, handle, indices) + + def _resource_apply_sparse_duplicate_indices( + self, grad, handle, indices, apply_state=None + ): + if "apply_state" in self._optimizer._sparse_apply_args: + return self.inner_optimizer._resource_apply_sparse_duplicate_indices( + grad, handle, indices, apply_state=apply_state + ) + else: + return self.inner_optimizer._resource_apply_sparse_duplicate_indices( + grad, handle, indices + ) + + @property + def step(self): + """Variable. 
+        """Variable. The number of training steps this Optimizer has run."""
+        if self._step is None:
+            with self._distribution_strategy_scope():
+                self._step = self.add_weight(
+                    "iter",
+                    shape=[],
+                    dtype=tf.int64,
+                    trainable=False,
+                    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+                )
+            self._weights.append(self._step)
+        return self._step
+
+    @step.setter
+    def step(self, variable):
+        if self._step is not None:
+            raise RuntimeError(
+                "Cannot set `step` to a new Variable after "
+                "the Optimizer weights have been created"
+            )
+        self._step = variable
+        self._weights.append(self._step)
+
+    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+        # `step` counts every call; `iterations` is rolled back on the calls
+        # where the accumulated gradient is not applied, so the inner
+        # optimizer only sees the effective updates.
+        with tf.control_dependencies([self.step.assign_add(1, read_value=False)]):
+            train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
+        with tf.control_dependencies([train_op]):
+            return self.iterations.assign_sub(
+                tf.cast(self.step % self._accum_steps != 0, tf.int64),
+                read_value=False,
+            )
+
+    @property
+    def inner_optimizer(self):
+        """The optimizer that this GradientAccumulator is wrapping."""
+        return self._optimizer
+
+    @property
+    def iterations(self):
+        return self._optimizer.iterations
+
+    @iterations.setter
+    def iterations(self, variable):
+        self._optimizer.iterations = variable
+
+    @property
+    def lr(self):
+        return self._optimizer._get_hyper("learning_rate")
+
+    @lr.setter
+    def lr(self, lr):
+        self._optimizer._set_hyper("learning_rate", lr)
+
+    @property
+    def learning_rate(self):
+        return self._optimizer._get_hyper("learning_rate")
+
+    @learning_rate.setter
+    def learning_rate(self, learning_rate):
+        self._optimizer._set_hyper("learning_rate", learning_rate)
+
+    def get_config(self):
+        config = {
+            "accum_steps": self._accum_steps,
+            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
+        }
+        base_config = super().get_config()
+        return {**base_config, **config}
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        optimizer = tf.keras.optimizers.deserialize(
+            config.pop("optimizer"), custom_objects=custom_objects
+        )
+        return cls(optimizer, **config)
diff --git a/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py b/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
new file mode 100644
index 0000000000..18d8d890f1
--- /dev/null
+++ b/tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
@@ -0,0 +1,159 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GradientAccumulator optimizers."""
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from tensorflow_addons.optimizers import GradientAccumulator
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
+def test_run():
+    var0 = tf.Variable([1.0, 2.0])
+    var1 = tf.Variable([3.0, 4.0])
+    accum_steps = 4
+
+    grads0 = tf.constant([0.1, 0.1])
+    grads1 = tf.constant([0.01, 0.01])
+
+    grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+
+    opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)
+
+    strategy = tf.distribute.get_strategy()
+    for _ in range(accum_steps + 1):
+        strategy.run(opt.apply_gradients, [grads_and_vars])
+
+    np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])
+    np.testing.assert_allclose(var1.read_value(), [2.96, 3.96])
+    np.testing.assert_allclose(opt.iterations.read_value(), 1)
+    np.testing.assert_allclose(opt.step.read_value(), accum_steps + 1)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
+def test_sparse():
+    var0 = tf.Variable([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
+    var1 = tf.Variable([[3.0, 4.0, 0.0]])
+
+    grads0 = tf.IndexedSlices(
+        tf.constant([[0.1, 0.1, 0.0]]),
+        tf.constant([1]),
+        tf.constant([1, 3]),
+    )
+    grads1 = tf.IndexedSlices(
+        tf.constant([[0.01, 0.01, 0.0]]),
+        tf.constant([0]),
+        tf.constant([1, 3]),
+    )
+
+    grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+    opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
+    strategy = tf.distribute.get_strategy()
+    for _ in range(8):
+        strategy.run(opt.apply_gradients, [grads_and_vars])
+    np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0], [0.2, 1.2, 0.0]])
+    np.testing.assert_allclose(var1.read_value(), [[2.92, 3.92, 0.0]])
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_dense():
+    grad = tf.Variable([[0.1]])
+    model = tf.keras.Sequential(
+        [
+            tf.keras.layers.Dense(
+                1,
+                kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
+                use_bias=False,
+            )
+        ]
+    )
+    model.build(input_shape=[1, 1])
+
+    opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=2.0), accum_steps=2)
+    _ = opt.apply_gradients(list(zip([grad], model.variables)))
+    np.testing.assert_allclose(model.variables[0].read_value(), [[1.0]])
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_optimizer_string():
+    _ = GradientAccumulator("adam")
+
+
+def test_config():
+    sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+    accum_steps = 4
+    opt = GradientAccumulator(sgd_opt, accum_steps=accum_steps)
+    config = opt.get_config()
+
+    assert config["accum_steps"] == accum_steps
+
+    new_opt = GradientAccumulator.from_config(config)
+    old_sgd_config = opt._optimizer.get_config()
+    new_sgd_config = new_opt._optimizer.get_config()
+
+    for k1, k2 in zip(old_sgd_config, new_sgd_config):
+        assert old_sgd_config[k1] == new_sgd_config[k2]
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.with_device([tf.distribute.MirroredStrategy])
+def test_fit_simple_linear_model():
+    seed = 0x2019
+    np.random.seed(seed)
+    tf.random.set_seed(seed)
+    num_examples = 5000
+    x = np.random.standard_normal((num_examples, 3))
+    w = np.random.standard_normal((3, 1))
+    y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
+
+    model = tf.keras.models.Sequential()
+    model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
+
+    opt = GradientAccumulator("sgd")
+    model.compile(opt, loss="mse")
+
+    model.fit(x, y, epochs=5)
+
+    x = np.random.standard_normal((100, 3))
+    y = np.dot(x, w)
+
+    predicted = model.predict(x)
+
+    max_abs_diff = np.max(np.abs(predicted - y))
+    assert max_abs_diff < 5e-3
+
+
+def test_serialization():
+    sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
+    optimizer = GradientAccumulator(sgd_opt)
+    config = tf.keras.optimizers.serialize(optimizer)
+    new_optimizer = tf.keras.optimizers.deserialize(config)
+    assert new_optimizer.get_config() == optimizer.get_config()
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+@pytest.mark.usefixtures("run_with_mixed_precision_policy")
+def test_model_mixed_precision():
+    x = np.random.standard_normal((10000, 3))
+    w = np.random.standard_normal((3, 1))
+    y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4
+    model = tf.keras.Sequential()
+    model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
+    model.compile(GradientAccumulator("sgd"), loss="mse")
+    model.fit(x, y, epochs=3)
diff --git a/tensorflow_addons/optimizers/tests/standard_test.py b/tensorflow_addons/optimizers/tests/standard_test.py
index f1d284ad68..3366c4f9a4 100644
--- a/tensorflow_addons/optimizers/tests/standard_test.py
+++ b/tensorflow_addons/optimizers/tests/standard_test.py
@@ -29,6 +29,7 @@
     "ConditionalGradient",  # is wrapper
     "Lookahead",  # is wrapper
     "MovingAverage",  # is wrapper
+    "GradientAccumulator",  # is wrapper
 ]
diff --git a/tools/testing/source_code_test.py b/tools/testing/source_code_test.py
index c54bf73ea2..299e612078 100644
--- a/tools/testing/source_code_test.py
+++ b/tools/testing/source_code_test.py
@@ -124,6 +124,7 @@ def test_no_tf_cond():
         "tensorflow_addons/metrics/cohens_kappa.py",
         "tensorflow_addons/seq2seq/sampler.py",
         "tensorflow_addons/seq2seq/beam_search_decoder.py",
+        "tensorflow_addons/optimizers/gradient_accumulator.py",
     ]
     for file_path, line_idx, line in get_lines_of_source_code(allowlist):
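
For reviewers, a minimal usage sketch of the wrapper added in this patch. It relies only on the API defined above (GradientAccumulator, accum_steps, reduction); the model, data, and hyperparameter values are illustrative placeholders, not part of the PR:

import numpy as np
import tensorflow as tf
from tensorflow_addons.optimizers import GradientAccumulator

# Wrap any Keras optimizer; gradients are summed (or averaged with
# reduction="MEAN") in the "ga" slots and applied once every `accum_steps`
# calls to apply_gradients.
inner = tf.keras.optimizers.SGD(learning_rate=0.1)
opt = GradientAccumulator(inner, accum_steps=4, reduction="MEAN")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(opt, loss="mse")

x = np.random.standard_normal((256, 3))
y = x @ np.random.standard_normal((3, 1))
model.fit(x, y, batch_size=16, epochs=1)

# opt.step counts every apply_gradients call; opt.iterations only advances
# on the steps where the accumulated gradient is actually applied.

Note on the design: `apply_gradients` increments `step` on every call but subtracts one from the inner optimizer's `iterations` whenever `step % accum_steps != 0`, so learning-rate schedules driven by `iterations` only see the effective (applied) updates.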