-
Notifications
You must be signed in to change notification settings - Fork 613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add gradient accumulator #2525
Closed
Closed
add gradient accumulator #2525
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
e49c805
add gradient accumulator
fsx950223 2c0fbae
add exceptions
fsx950223 11e536d
fix multi gpus bug
fsx950223 1a4c0d4
fix test bugs
fsx950223 eabed95
fix sparse optimizer
fsx950223 a6ff7c0
remove read_value
fsx950223 24ae8a9
fix sparse test
fsx950223 2760fad
fix sparse bug
fsx950223 4ba7a55
refactor
fsx950223 dc50184
add sparse multi gpu test
fsx950223 8cd65ad
fix rnn bug
fsx950223 7d40946
fix step bugs
fsx950223 6949bd3
fix _iterations
fsx950223 9e423e5
use gradient transformer
fsx950223 7f3b2e9
fix bug
fsx950223 99dcde5
fix step bug
fsx950223 a184581
simpify code
fsx950223 d0718f8
optimize
fsx950223 2af5475
fix bug
fsx950223 42fccea
fix bug
fsx950223 93794ec
simpify code
fsx950223 e62cc95
add mean reduction
fsx950223 64b70b4
decrease memory usage
fsx950223 67c1e8e
fix iterations
fsx950223 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
import tensorflow as tf | ||
from tensorflow_addons.utils import types | ||
from typeguard import typechecked | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable(package="Addons") | ||
class GradientAccumulator(tf.keras.optimizers.Optimizer): | ||
"""Optimizer wrapper for gradient accumulation.""" | ||
|
||
@typechecked | ||
def __init__( | ||
self, | ||
optimizer: types.Optimizer, | ||
accum_steps: types.TensorLike = 4, | ||
name: str = "GradientAccumulator", | ||
**kwargs, | ||
): | ||
r"""Construct a new GradientAccumulator optimizer. | ||
|
||
Args: | ||
optimizer: str or `tf.keras.optimizers.Optimizer` that will be | ||
used to compute and apply gradients. | ||
accum_steps: int > 0. Update gradient in every accumulation steps. | ||
name: Optional name for the operations created when applying | ||
gradients. Defaults to "GradientAccumulator". | ||
**kwargs: keyword arguments. Allowed to be {`clipnorm`, | ||
`clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by | ||
norm; `clipvalue` is clip gradients by value, `decay` is | ||
included for backward compatibility to allow time inverse | ||
decay of learning rate. `lr` is included for backward | ||
compatibility, recommended to use `learning_rate` instead. | ||
""" | ||
super().__init__(name, **kwargs) | ||
self._optimizer = tf.keras.optimizers.get(optimizer) | ||
self._gradients = [] | ||
self._accum_steps = accum_steps | ||
|
||
def _create_slots(self, var_list): | ||
self._optimizer._create_slots(var_list=var_list) | ||
for var in var_list: | ||
self.add_slot(var, "ga") | ||
|
||
self._gradients = [self.get_slot(var, "ga") for var in var_list] | ||
|
||
@property | ||
def gradients(self): | ||
"""The accumulated gradients on the current replica.""" | ||
if not self._gradients: | ||
raise ValueError( | ||
"The accumulator should be called first to initialize the gradients" | ||
) | ||
return list( | ||
gradient.read_value() if gradient is not None else gradient | ||
for gradient in self._gradients | ||
) | ||
|
||
def apply_gradients(self, grads_and_vars, name=None, **kwargs): | ||
self._optimizer._iterations = self.iterations | ||
return super().apply_gradients(grads_and_vars, name, **kwargs) | ||
|
||
def _resource_apply_dense(self, grad, var, apply_state=None): | ||
accum_gradient = self.get_slot(var, "ga") | ||
if accum_gradient is not None and grad is not None: | ||
accum_gradient.assign_add( | ||
grad, use_locking=self._use_locking, read_value=False | ||
) | ||
|
||
def _apply(): | ||
if "apply_state" in self._optimizer._dense_apply_args: | ||
train_op = self._optimizer._resource_apply_dense( | ||
accum_gradient, var, apply_state=apply_state | ||
) | ||
else: | ||
train_op = self._optimizer._resource_apply_dense(accum_gradient, var) | ||
reset_op = accum_gradient.assign( | ||
tf.zeros_like(accum_gradient), | ||
use_locking=self._use_locking, | ||
read_value=False, | ||
) | ||
return tf.group(train_op, reset_op) | ||
|
||
apply_op = tf.cond( | ||
self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op() | ||
) | ||
return apply_op | ||
|
||
def _resource_apply_sparse(self, grad: types.TensorLike, var, indices, apply_state): | ||
accum_gradient = self.get_slot(var, "ga") | ||
if accum_gradient is not None and grad is not None: | ||
self._resource_scatter_add(accum_gradient, indices, grad) | ||
|
||
def _apply(): | ||
if "apply_state" in self._optimizer._sparse_apply_args: | ||
train_op = self._optimizer._resource_apply_dense( | ||
accum_gradient, | ||
var, | ||
apply_state=apply_state, | ||
) | ||
else: | ||
train_op = self._optimizer._resource_apply_dense(accum_gradient, var) | ||
reset_op = accum_gradient.assign( | ||
tf.zeros_like(accum_gradient), | ||
use_locking=self._use_locking, | ||
read_value=False, | ||
) | ||
return tf.group(train_op, reset_op) | ||
|
||
apply_op = tf.cond( | ||
self.iterations % self._accum_steps == 0, _apply, lambda: tf.no_op() | ||
) | ||
return apply_op | ||
|
||
def reset(self): | ||
"""Resets the accumulated gradients on the current replica.""" | ||
assign_ops = [] | ||
if not self._gradients: | ||
return assign_ops | ||
|
||
for gradient in self._gradients: | ||
if gradient is not None: | ||
assign_ops.append( | ||
gradient.assign( | ||
tf.zeros_like(gradient), | ||
use_locking=self._use_locking, | ||
read_value=False, | ||
) | ||
) | ||
|
||
return tf.group(assign_ops) | ||
|
||
@property | ||
def lr(self): | ||
return self._optimizer._get_hyper("learning_rate") | ||
|
||
@lr.setter | ||
def lr(self, lr): | ||
self._optimizer._set_hyper("learning_rate", lr) # | ||
|
||
@property | ||
def learning_rate(self): | ||
return self._optimizer._get_hyper("learning_rate") | ||
|
||
@learning_rate.setter | ||
def learning_rate(self, learning_rate): | ||
self._optimizer._set_hyper("learning_rate", learning_rate) | ||
|
||
def get_config(self): | ||
config = { | ||
"accum_steps": self._accum_steps, | ||
"optimizer": tf.keras.optimizers.serialize(self._optimizer), | ||
} | ||
base_config = super().get_config() | ||
return {**base_config, **config} | ||
|
||
@classmethod | ||
def from_config(cls, config, custom_objects=None): | ||
optimizer = tf.keras.optimizers.deserialize( | ||
config.pop("optimizer"), custom_objects=custom_objects | ||
) | ||
return cls(optimizer, **config) |
153 changes: 153 additions & 0 deletions
153
tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for GradientAccumulator optimizers.""" | ||
|
||
import numpy as np | ||
import pytest | ||
import tensorflow as tf | ||
from tensorflow_addons.utils import test_utils | ||
|
||
from tensorflow_addons.optimizers import GradientAccumulator | ||
|
||
|
||
@pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
def test_run(): | ||
var0 = tf.Variable([1.0, 2.0]) | ||
var1 = tf.Variable([3.0, 4.0]) | ||
accum_steps = 4 | ||
|
||
grads0 = tf.constant([0.1, 0.1]) | ||
grads1 = tf.constant([0.01, 0.01]) | ||
|
||
grads_and_vars = list(zip([grads0, grads1], [var0, var1])) | ||
|
||
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps) | ||
|
||
for _ in range(accum_steps + 1): | ||
opt.apply_gradients(grads_and_vars) | ||
|
||
np.testing.assert_allclose(var0.read_value(), [0.5, 1.5]) | ||
np.testing.assert_allclose(var1.read_value(), [2.95, 3.95]) | ||
|
||
|
||
@pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
def test_sparse(): | ||
var0 = tf.Variable([[1.0, 2.0, 0.0]]) | ||
var1 = tf.Variable([[3.0, 4.0, 0.0]]) | ||
|
||
grads0 = tf.IndexedSlices( | ||
tf.constant([[0.1, 0.1, 0.0]]), | ||
tf.constant([0]), | ||
tf.constant([1, 3]), | ||
) | ||
grads1 = tf.IndexedSlices( | ||
tf.constant([[0.01, 0.01, 0.0]]), | ||
tf.constant([0]), | ||
tf.constant([1, 3]), | ||
) | ||
|
||
grads_and_vars = list(zip([grads0, grads1], [var0, var1])) | ||
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0, momentum=0.1)) | ||
opt.apply_gradients(grads_and_vars) | ||
np.testing.assert_allclose(var0.read_value(), [[0.9, 1.9, 0.0]]) | ||
np.testing.assert_allclose(var1.read_value(), [[2.99, 3.99, 0.0]]) | ||
|
||
|
||
@pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
def test_dense(): | ||
grad = tf.Variable([[0.1]]) | ||
model = tf.keras.Sequential( | ||
[ | ||
tf.keras.layers.Dense( | ||
1, | ||
kernel_initializer=tf.keras.initializers.Constant([[1.0]]), | ||
use_bias=False, | ||
) | ||
] | ||
) | ||
model.build(input_shape=[1, 1]) | ||
|
||
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=2.0), accum_steps=2) | ||
_ = opt.apply_gradients(list(zip([grad], model.variables))) | ||
np.testing.assert_allclose(model.variables[0].read_value(), [[0.8]]) | ||
|
||
|
||
@pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
def test_optimizer_string(): | ||
_ = GradientAccumulator("adam") | ||
|
||
|
||
def test_config(): | ||
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1) | ||
accum_steps = 4 | ||
opt = GradientAccumulator(sgd_opt, accum_steps=accum_steps) | ||
config = opt.get_config() | ||
|
||
assert config["accum_steps"] == accum_steps | ||
|
||
new_opt = GradientAccumulator.from_config(config) | ||
old_sgd_config = opt._optimizer.get_config() | ||
new_sgd_config = new_opt._optimizer.get_config() | ||
|
||
for k1, k2 in zip(old_sgd_config, new_sgd_config): | ||
assert old_sgd_config[k1] == new_sgd_config[k2] | ||
|
||
|
||
@pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
@pytest.mark.needs_gpu | ||
def test_fit_simple_linear_model(): | ||
seed = 0x2019 | ||
np.random.seed(seed) | ||
tf.random.set_seed(seed) | ||
num_examples = 5000 | ||
x = np.random.standard_normal((num_examples, 3)) | ||
w = np.random.standard_normal((3, 1)) | ||
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4 | ||
strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing()) | ||
with strategy.scope(): | ||
model = tf.keras.models.Sequential() | ||
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) | ||
|
||
opt = GradientAccumulator("sgd") | ||
model.compile(opt, loss="mse") | ||
|
||
model.fit(x, y, epochs=5) | ||
|
||
x = np.random.standard_normal((100, 3)) | ||
y = np.dot(x, w) | ||
|
||
predicted = model.predict(x) | ||
|
||
max_abs_diff = np.max(np.abs(predicted - y)) | ||
assert max_abs_diff < 5e-3 | ||
|
||
|
||
def test_serialization(): | ||
sgd_opt = tf.keras.optimizers.SGD(lr=2.0, nesterov=True, momentum=0.3, decay=0.1) | ||
optimizer = GradientAccumulator(sgd_opt) | ||
config = tf.keras.optimizers.serialize(optimizer) | ||
new_optimizer = tf.keras.optimizers.deserialize(config) | ||
assert new_optimizer.get_config() == optimizer.get_config() | ||
|
||
|
||
@pytest.mark.usefixtures("run_with_mixed_precision_policy") | ||
def test_model_mixed_precision(): | ||
x = np.random.standard_normal((10000, 3)) | ||
w = np.random.standard_normal((3, 1)) | ||
y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4 | ||
model = tf.keras.Sequential() | ||
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) | ||
model.compile(GradientAccumulator("sgd"), loss="mse") | ||
model.fit(x, y, epochs=3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The MNIST example is still working but in my main project I am getting:
I think it might be the call to
self._resource_scatter_add(accum_gradient, indices, grad)
?