
Commit

decrease memory usage
fsx950223 committed Jul 20, 2021
1 parent e62cc95 commit 64b70b4
Showing 1 changed file with 4 additions and 16 deletions.
tensorflow_addons/optimizers/gradient_accumulator.py
@@ -57,24 +57,12 @@ def _accum_grad(grads_and_vars):
             with tf.init_scope():
                 if not self._gradients:
                     for grad, var in grads_and_vars:
-                        if tf.distribute.has_strategy():
-                            for v in var.values:
-                                self._gradients[v.ref()] = tf.Variable(
-                                    tf.zeros_like(v), trainable=False
-                                )
-                        else:
-                            self._gradients[var.ref()] = tf.Variable(
-                                tf.zeros_like(var), trainable=False
-                            )
+                        self._gradients[var.ref()] = tf.Variable(
+                            tf.zeros_like(var), trainable=False
+                        )
             new_grads_and_vars = []
             for grad, var in grads_and_vars:
-                if tf.distribute.has_strategy():
-                    replica_id = tf.get_static_value(
-                        tf.distribute.get_replica_context().replica_id_in_sync_group
-                    )
-                    handle = self._gradients[var.values[replica_id].ref()]
-                else:
-                    handle = self._gradients[var.ref()]
+                handle = self._gradients[var.ref()]
 
                 if isinstance(grad, tf.IndexedSlices):
                     handle.scatter_add(grad)
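In the removed code, running under a tf.distribute strategy allocated one extra tf.Variable accumulator per replica copy of every model variable (via var.values) and looked the handle up by replica id. The patched code keys a single accumulator on var.ref() regardless of strategy, which is where the memory saving comes from. Below is a minimal, self-contained sketch of the simplified accumulation path. The standalone names (gradients, accumulate) are illustrative rather than part of the library, and the dense assign_add branch is assumed from the code the diff truncates.

import tensorflow as tf

# Accumulators keyed by variable reference, mirroring self._gradients in the patch.
gradients = {}

def accumulate(grads_and_vars):
    # Lazily create one zero-initialized, non-trainable accumulator per variable.
    if not gradients:
        for _, var in grads_and_vars:
            gradients[var.ref()] = tf.Variable(tf.zeros_like(var), trainable=False)
    for grad, var in grads_and_vars:
        handle = gradients[var.ref()]
        if isinstance(grad, tf.IndexedSlices):
            handle.scatter_add(grad)   # sparse gradient (e.g. from embedding lookups)
        else:
            handle.assign_add(grad)    # dense gradient (assumed branch)

# Usage: accumulate two dense gradients for one weight.
w = tf.Variable([1.0, 2.0])
accumulate([(tf.constant([0.1, 0.1]), w)])
accumulate([(tf.constant([0.2, 0.2]), w)])
print(gradients[w.ref()].numpy())  # -> [0.3 0.3]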
