
fix bugs
fsx950223 committed Jul 21, 2021
1 parent 64b70b4 commit 7d2d553
Showing 2 changed files with 54 additions and 81 deletions.
87 changes: 42 additions & 45 deletions tensorflow_addons/optimizers/gradient_accumulator.py
@@ -49,20 +49,13 @@ def __init__(
super().__init__(name, **kwargs)
self._optimizer = tf.keras.optimizers.get(inner_optimizer)
self._step = None
- self._gradients = {}
self._accum_steps = accum_steps
self._reduction = reduction

def _accum_grad(grads_and_vars):
- with tf.init_scope():
- if not self._gradients:
- for grad, var in grads_and_vars:
- self._gradients[var.ref()] = tf.Variable(
- tf.zeros_like(var), trainable=False
- )
new_grads_and_vars = []
for grad, var in grads_and_vars:
- handle = self._gradients[var.ref()]
+ handle = self.get_slot(var, "ga")

if isinstance(grad, tf.IndexedSlices):
handle.scatter_add(grad)
@@ -84,9 +77,11 @@ def _get_grad():
values = tf.gather(new_grad, indices)
dense_shape = tf.constant(new_grad.shape.as_list())
handle.assign(
- tf.zeros_like(handle), use_locking=self._use_locking
+ tf.zeros_like(handle),
+ use_locking=self._use_locking,
+ read_value=False,
)
- return values, tf.cast(indices, tf.int32), dense_shape
+ return values, tf.cast(indices, grad.indices.dtype), dense_shape

values, indices, dense_shape = tf.cond(
self.step % self._accum_steps == 0,
@@ -100,14 +95,18 @@ def _get_grad():
new_grad = tf.IndexedSlices(values, indices, dense_shape)
new_grads_and_vars.append((new_grad, var))
else:
- handle.assign_add(grad)
+ handle.assign_add(
+ grad, use_locking=self._use_locking, read_value=False
+ )

def _get_grad():
new_grad = handle.read_value()
if self._reduction == "MEAN":
new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
handle.assign(
- tf.zeros_like(handle), use_locking=self._use_locking
+ tf.zeros_like(handle),
+ use_locking=self._use_locking,
+ read_value=False,
)
return new_grad

@@ -119,11 +118,39 @@ def _get_grad():
new_grads_and_vars.append((new_grad, var))
return new_grads_and_vars

- self._optimizer.gradient_transformers.append(_accum_grad)
+ self.gradient_transformers.append(_accum_grad)
self._iterations = self._optimizer.iterations

def _create_slots(self, var_list):
self._optimizer._create_slots(var_list=var_list)
+ for var in var_list:
+ self.add_slot(var, "ga")

+ def _resource_apply_dense(self, grad, handle, apply_state):
+ if "apply_state" in self._optimizer._dense_apply_args:
+ return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state)
+ else:
+ return self.inner_optimizer._resource_apply_dense(grad, handle)
+
+ def _resource_apply_sparse(self, grad, handle, indices, apply_state):
+ if "apply_state" in self._optimizer._sparse_apply_args:
+ return self.inner_optimizer._resource_apply_sparse(
+ grad, handle, indices, apply_state=apply_state
+ )
+ else:
+ return self.inner_optimizer._resource_apply_sparse(grad, handle, indices)
+
+ def _resource_apply_sparse_duplicate_indices(
+ self, grad, handle, indices, apply_state=None
+ ):
+ if "apply_state" in self._optimizer._sparse_apply_args:
+ return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+ grad, handle, indices, apply_state=apply_state
+ )
+ else:
+ return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+ grad, handle, indices
+ )

@property
def step(self):
@@ -151,49 +178,19 @@ def step(self, variable):
self._step = variable
self._weights.append(self._step)

- @property
- def gradients(self):
- """The accumulated gradients on the current replica."""
- if not self._gradients:
- raise ValueError(
- "The accumulator should be called first to initialize the gradients"
- )
- return list(
- gradient.read_value() if gradient is not None else gradient
- for _, gradient in self._gradients
- )

def apply_gradients(self, grads_and_vars, name=None, **kwargs):
- train_op = self._optimizer.apply_gradients(grads_and_vars, name, **kwargs)
+ train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
with tf.control_dependencies([train_op]):
with tf.control_dependencies(
[
- self._optimizer.iterations.assign_add(
+ self.iterations.assign_add(
tf.cast(self.step % self._accum_steps == 0, tf.int64),
read_value=False,
)
]
):
return self.step.assign_add(1, read_value=False)

- def reset(self):
- """Resets the accumulated gradients on the current replica."""
- assign_ops = []
- if not self._gradients:
- return assign_ops
-
- for _, gradient in self._gradients:
- if gradient is not None:
- assign_ops.append(
- gradient.assign(
- tf.zeros_like(gradient),
- use_locking=self._use_locking,
- read_value=False,
- )
- )
-
- return tf.group(assign_ops)

@property
def inner_optimizer(self):
"""The optimizer that this LossScaleOptimizer is wrapping."""
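For readers skimming the diff, here is a minimal, self-contained sketch of the accumulate-then-release pattern that the _accum_grad gradient transformer above implements. It is an illustrative eager-mode toy, not the tensorflow_addons implementation: the class name ToyAccumulator is made up, only dense gradients are handled (no tf.IndexedSlices branch), plain Python buffers stand in for the "ga" slot variables, and, unlike the transformer above, which hands a zero gradient to the inner optimizer between accumulation boundaries, the toy simply skips the inner update.

import tensorflow as tf

class ToyAccumulator:
    """Eager-mode sketch: sum gradients and release them every `accum_steps` calls."""

    def __init__(self, inner_optimizer, accum_steps=4, reduction="MEAN"):
        self.inner = inner_optimizer
        self.accum_steps = accum_steps
        self.reduction = reduction
        self.step = tf.Variable(0, dtype=tf.int64, trainable=False)
        self._buffers = {}  # variable ref -> accumulation buffer

    def apply_gradients(self, grads_and_vars):
        grads_and_vars = list(grads_and_vars)
        # Lazily create one zero-initialized buffer per variable
        # (the diff instead creates an optimizer slot named "ga" in _create_slots).
        for _, var in grads_and_vars:
            if var.ref() not in self._buffers:
                self._buffers[var.ref()] = tf.Variable(tf.zeros_like(var), trainable=False)

        self.step.assign_add(1)
        for grad, var in grads_and_vars:
            self._buffers[var.ref()].assign_add(grad)

        # Only every `accum_steps` calls is the (optionally averaged) sum handed to the inner optimizer.
        if self.step % self.accum_steps == 0:
            released = []
            for _, var in grads_and_vars:
                buf = self._buffers[var.ref()]
                grad = buf.read_value()
                if self.reduction == "MEAN":
                    grad /= tf.cast(self.accum_steps, grad.dtype)
                buf.assign(tf.zeros_like(buf))  # zero the buffer, as the diff does with read_value=False
                released.append((grad, var))
            self.inner.apply_gradients(released)

# Example: with accum_steps=4 the wrapped SGD only updates variables on every 4th call.
# opt = ToyAccumulator(tf.keras.optimizers.SGD(learning_rate=1.0), accum_steps=4)
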
48 changes: 12 additions & 36 deletions tensorflow_addons/optimizers/tests/gradient_accumulator_test.py
@@ -17,12 +17,12 @@
import numpy as np
import pytest
import tensorflow as tf
- from tensorflow_addons.utils import test_utils

from tensorflow_addons.optimizers import GradientAccumulator


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+ @pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
def test_run():
var0 = tf.Variable([1.0, 2.0])
var1 = tf.Variable([3.0, 4.0])
@@ -35,14 +35,16 @@ def test_run():

opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0), accum_steps)

+ strategy = tf.distribute.get_strategy()
for _ in range(accum_steps + 1):
- opt.apply_gradients(grads_and_vars)
+ strategy.run(opt.apply_gradients, [grads_and_vars])

np.testing.assert_allclose(var0.read_value(), [0.6, 1.6])
np.testing.assert_allclose(var1.read_value(), [2.96, 3.96])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+ @pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])
def test_sparse():
var0 = tf.Variable([[1.0, 2.0, 0.0], [1.0, 2.0, 0.0]])
var1 = tf.Variable([[3.0, 4.0, 0.0]])
@@ -60,38 +62,13 @@ def test_sparse():

grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
+ strategy = tf.distribute.get_strategy()
for _ in range(8):
- opt.apply_gradients(grads_and_vars)
+ strategy.run(opt.apply_gradients, [grads_and_vars])
np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0], [0.2, 1.2, 0.0]])
np.testing.assert_allclose(var1.read_value(), [[2.92, 3.92, 0.0]])


- @pytest.mark.usefixtures("maybe_run_functions_eagerly")
- @pytest.mark.needs_gpu
- def test_sparse_multi_gpus():
- strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
- with strategy.scope():
- var0 = tf.Variable([[1.0, 2.0, 0.0]])
- var1 = tf.Variable([[3.0, 4.0, 0.0]])
-
- grads0 = tf.IndexedSlices(
- tf.constant([[0.1, 0.1, 0.0]]),
- tf.constant([0]),
- tf.constant([1, 3]),
- )
- grads1 = tf.IndexedSlices(
- tf.constant([[0.01, 0.01, 0.0]]),
- tf.constant([0]),
- tf.constant([1, 3]),
- )
-
- grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
- opt = GradientAccumulator(tf.keras.optimizers.SGD(lr=1.0))
- strategy.run(opt.apply_gradients, [grads_and_vars])
- np.testing.assert_allclose(var0.read_value(), [[1.0, 2.0, 0.0]])
- np.testing.assert_allclose(var1.read_value(), [[3.0, 4.0, 0.0]])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_dense():
grad = tf.Variable([[0.1]])
@@ -133,7 +110,7 @@ def test_config():


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
- @pytest.mark.needs_gpu
+ @pytest.mark.with_device([tf.distribute.MirroredStrategy])
def test_fit_simple_linear_model():
seed = 0x2019
np.random.seed(seed)
@@ -142,13 +119,12 @@ def test_fit_simple_linear_model():
x = np.random.standard_normal((num_examples, 3))
w = np.random.standard_normal((3, 1))
y = np.dot(x, w) + np.random.standard_normal((num_examples, 1)) * 1e-4
- strategy = tf.distribute.MirroredStrategy(test_utils.gpus_for_testing())
- with strategy.scope():
- model = tf.keras.models.Sequential()
- model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
-
- opt = GradientAccumulator("sgd")
- model.compile(opt, loss="mse")
+ model = tf.keras.models.Sequential()
+ model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
+
+ opt = GradientAccumulator("sgd")
+ model.compile(opt, loss="mse")

model.fit(x, y, epochs=5)

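For context, a hedged usage sketch mirroring the updated tests above: apply_gradients is driven through the active tf.distribute strategy, and the wrapper can also be passed straight to model.compile. This assumes the GradientAccumulator from this branch is importable as in the test file; the variables, gradient values, and step counts below are illustrative only, not taken from the tests.

import numpy as np
import tensorflow as tf
from tensorflow_addons.optimizers import GradientAccumulator

# Wrap a Keras optimizer; variables are only updated every `accum_steps` calls.
opt = GradientAccumulator(tf.keras.optimizers.SGD(learning_rate=1.0), accum_steps=4)

var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])
strategy = tf.distribute.get_strategy()  # default (no-op) strategy, as in test_run

for _ in range(5):
    strategy.run(opt.apply_gradients, [[(grad, var)]])

print(var.numpy())  # moves once per completed accumulation window

# As in test_fit_simple_linear_model, the wrapper also works with compile/fit.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(input_shape=(3,), units=1))
model.compile(GradientAccumulator("sgd"), loss="mse")

x = np.random.standard_normal((64, 3)).astype("float32")
y = (x @ np.random.standard_normal((3, 1))).astype("float32")
model.fit(x, y, epochs=1, verbose=0)
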
