add gradient accumulator #2525

Closed
wants to merge 24 commits
Changes from 8 commits
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -32,6 +32,7 @@
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
from tensorflow_addons.optimizers.gradient_accumulator import GradientAccumulator
from tensorflow_addons.optimizers.moving_average import MovingAverage
from tensorflow_addons.optimizers.novograd import NovoGrad
from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad
174 changes: 174 additions & 0 deletions tensorflow_addons/optimizers/gradient_accumulator.py
@@ -0,0 +1,174 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from tensorflow_addons.utils import types
from typeguard import typechecked


@tf.keras.utils.register_keras_serializable(package="Addons")
class GradientAccumulator(tf.keras.optimizers.Optimizer):
"""Optimizer wrapper for gradient accumulation."""

@typechecked
def __init__(
self,
optimizer: types.Optimizer,
accum_steps: types.TensorLike = 4,
name: str = "GradientAccumulator",
**kwargs,
):
r"""Construct a new GradientAccumulator optimizer.

Args:
optimizer: str or `tf.keras.optimizers.Optimizer` that will be
used to compute and apply gradients.
accum_steps: int > 0. Number of steps over which gradients are
accumulated before applying an update.
name: Optional name for the operations created when applying
gradients. Defaults to "GradientAccumulator".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}. `clipnorm` clips gradients by
norm; `clipvalue` clips gradients by value; `decay` is
included for backward compatibility to allow time-inverse
decay of the learning rate. `lr` is included for backward
compatibility; it is recommended to use `learning_rate` instead.
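
Usage example (a minimal sketch; `tfa` is assumed to be the usual
`import tensorflow_addons as tfa` alias):

    opt = tf.keras.optimizers.SGD(learning_rate=0.01)
    opt = tfa.optimizers.GradientAccumulator(opt, accum_steps=4)
    # `opt` can then be passed to `model.compile(...)` as usual; the
    # wrapped optimizer only applies an update every `accum_steps`
    # calls to `apply_gradients`.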
"""
super().__init__(name, **kwargs)
self._optimizer = tf.keras.optimizers.get(optimizer)
self._gradients = []
self._accum_steps = accum_steps

def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list)
for var in var_list:
self.add_slot(var, "ga")

self._gradients = [self.get_slot(var, "ga") for var in var_list]

@property
def gradients(self):
"""The accumulated gradients on the current replica."""
if not self._gradients:
raise ValueError(
"The accumulator should be called first to initialize the gradients"
)
return list(
gradient.read_value() if gradient is not None else gradient
for gradient in self._gradients
)

def apply_gradients(self, grads_and_vars, name=None, **kwargs):
self._optimizer._iterations = self.iterations
return super().apply_gradients(grads_and_vars, name, **kwargs)

def _resource_apply_dense(self, grad, var, apply_state=None):
accum_gradient = self.get_slot(var, "ga")
if accum_gradient is not None and grad is not None:
accum_gradient.assign_add(
grad, use_locking=self._use_locking, read_value=False
)

def _apply():
if "apply_state" in self._optimizer._dense_apply_args:
train_op = self._optimizer._resource_apply_dense(
accum_gradient, var, apply_state=apply_state
)
else:
train_op = self._optimizer._resource_apply_dense(accum_gradient, var)
reset_op = accum_gradient.assign(
tf.zeros_like(accum_gradient),
use_locking=self._use_locking,
read_value=False,
)
return tf.group(train_op, reset_op)

apply_op = tf.cond(
self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
)
return apply_op

def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state):
accum_gradient = self.get_slot(var, "ga")
if accum_gradient is not None and grad is not None:
self._resource_scatter_add(accum_gradient, indices, grad)

Review comment: The MNIST example still works, but in my main project I am getting:

  File "/home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  shape of indices ([187]) is not compatible with the shape of updates ([1004,256])
	 [[{{node cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_298/update_0/ResourceScatterAdd}}]]
	 [[cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_133/update_0/ReadVariableOp/_514]]
  (1) Invalid argument:  shape of indices ([187]) is not compatible with the shape of updates ([1004,256])
	 [[{{node cond_1/then/_631/cond_1/GradientAccumulator/GradientAccumulator/update_298/update_0/ResourceScatterAdd}}]]

I think it might be the call to self._resource_scatter_add(accum_gradient, indices, grad)?
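
For reference, a minimal standalone sketch (the shapes are hypothetical, copied only from the traceback above) of the shape rule that `ResourceScatterAdd` enforces on the accumulation slot, which would explain the failure:

    import tensorflow as tf

    # Hypothetical shapes from the error message: the "ga" slot holds the
    # full dense variable, while the sparse gradient covers only 187 rows.
    accum = tf.Variable(tf.zeros([1004, 256]))
    indices = tf.range(187)                     # 187 row ids
    row_updates = tf.random.normal([187, 256])  # one update row per index

    # scatter_add requires updates.shape == indices.shape + accum.shape[1:],
    # i.e. [187, 256] here.
    accum.scatter_add(tf.IndexedSlices(row_updates, indices))

    # Passing a full-size [1004, 256] tensor together with only 187 indices
    # is what produces "shape of indices ([187]) is not compatible with the
    # shape of updates ([1004, 256])".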


def _apply():
if "apply_state" in self._optimizer._sparse_apply_args:
train_op = self._optimizer._resource_apply_dense(
accum_gradient,
var,
apply_state=apply_state,
)
else:
train_op = self._optimizer._resource_apply_dense(accum_gradient, var)
reset_op = accum_gradient.assign(
tf.zeros_like(accum_gradient),
use_locking=self._use_locking,
read_value=False,
)
return tf.group(train_op, reset_op)

apply_op = tf.cond(
self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op()
)
return apply_op

def reset(self):
"""Resets the accumulated gradients on the current replica."""
assign_ops = []
if not self._gradients:
return assign_ops

for gradient in self._gradients:
if gradient is not None:
assign_ops.append(
gradient.assign(
tf.zeros_like(gradient),
use_locking=self._use_locking,
read_value=False,
)
)

return tf.group(assign_ops)

@property
def lr(self):
return self._optimizer._get_hyper("learning_rate")

@lr.setter
def lr(self, lr):
self._optimizer._set_hyper("learning_rate", lr)

@property
def learning_rate(self):
return self._optimizer._get_hyper("learning_rate")

@learning_rate.setter
def learning_rate(self, learning_rate):
self._optimizer._set_hyper("learning_rate", learning_rate)

def get_config(self):
config = {
"accum_steps": self._accum_steps,
"optimizer": tf.keras.optimizers.serialize(self._optimizer),
}
base_config = super().get_config()
return {**base_config, **config}

@classmethod
def from_config(cls, config, custom_objects=None):
optimizer = tf.keras.optimizers.deserialize(
config.pop("optimizer"), custom_objects=custom_objects
)
return cls(optimizer, **config)
153 changes: 153 additions & 0 deletions tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
@@ -0,0 +1,153 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GradientAccumulator optimizers."""

import numpy as np
import pytest
import tensorflow as tf
from tensorflow_addons.utils import test_utils

from tensorflow_addons.optimizers import GradientAccumulator


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_run():
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
accum_steps = 4

grads0 = tf.constant([0.1, 0.1])
grads1 = tf.constant([0.01, 0.01])

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))

opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)

for _ in range(accum_steps + 1):
opt.apply_gradients(grads_and_vars)

np.testing.assert_allclose(var0.read_value(), [0.5, 1.5])
np.testing.assert_allclose(var1.read_value(), [2.95, 3.95])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_sparse():
var0 = tf.Variable([[1.0, 2.0, 0.0]])
var1 = tf.Variable([[3.0, 4.0, 0.0]])

grads0 = tf.IndexedSlices(
tf.constant([[0.1, 0.1, 0.0]]),
tf.constant([0]),
tf.constant([1, 3]),
)
grads1 = tf.IndexedSlices(
tf.constant([[0.01, 0.01, 0.0]]),
tf.constant([0]),
tf.constant([1, 3]),
)

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0, momentum=0.1))
opt.apply_gradients(grads_and_vars)
np.testing.assert_allclose(var0.read_value(), [[0.9, 1.9, 0.0]])
np.testing.assert_allclose(var1.read_value(), [[2.99, 3.99, 0.0]])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_dense():
grad = tf.Variable([[0.1]])
model = tf.keras.Sequential(
[
tf.keras.layers.Dense(
1,
kernel_initializer=tf.keras.initializers.Constant([[1.0]]),
use_bias=False,
)
]
)
model.build(input_shape=[1, 1])

opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=2.0), accum_steps=2)
_ = opt.apply_gradients(list(zip([grad], model.variables)))
np.testing.assert_allclose(model.variables[0].read_value(), [[0.8]])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_optimizer_string():
_ = GradientAccumulator("adam")


def test_config():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
accum_steps = 4
opt = GradientAccumulator(sgd_opt, accum_steps=accum_steps)
config = opt.get_config()

assert config["accum_steps"] == accum_steps

new_opt = GradientAccumulator.from_config(config)
old_sgd_config = opt._optimizer.get_config()
new_sgd_config = new_opt._optimizer.get_config()

for k1, k2 in zip(old_sgd_config, new_sgd_config):
assert old_sgd_config[k1] == new_sgd_config[k2]


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.needs_gpu
def test_fit_simple_linear_model():
seed = 0x2019
np.random.seed(seed)
tf.random.set_seed(seed)
num_examples = 5000
x = np.random.standard_normal((num_examples, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
with strategy.scope():
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))

opt = GradientAccumulator("sgd")
model.compile(opt, loss="mse")

model.fit(x, y, epochs=5)

x = np.random.standard_normal((100, 3))
y = np.dot(x, w)

predicted = model.predict(x)

max_abs_diff = np.max(np.abs(predicted - y))
assert max_abs_diff < 5e-3


def test_serialization():
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1)
optimizer = GradientAccumulator(sgd_opt)
config = tf.keras.optimizers.serialize(optimizer)
new_optimizer = tf.keras.optimizers.deserialize(config)
assert new_optimizer.get_config() == optimizer.get_config()


@pytest.mark.usefixtures("run_with_mixed_precision_policy")
def test_model_mixed_precision():
x = np.random.standard_normal((10000, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
model.compile(GradientAccumulator("sgd"), loss="mse")
model.fit(x, y, epochs=3)
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/tests/standard_test.py
@@ -29,6 +29,7 @@
"ConditionalGradient", # is wrapper
"Lookahead", # is wrapper
"MovingAverage", # is wrapper
"GradientAccumulator", # is wrapper
]


1 change: 1 addition & 0 deletions tools/testing/source_code_test.py
@@ -124,6 +124,7 @@ def test_no_tf_cond():
"tensorflow_addons/metrics/cohens_kappa.py",
"tensorflow_addons/seq2seq/sampler.py",
"tensorflow_addons/seq2seq/beam_search_decoder.py",
"tensorflow_addons/optimizers/gradient_accumulator.py",
]
for file_path, line_idx, line in get_lines_of_source_code(allowlist):
