Skip to content

Commit

Permalink
Fix gradient accumulation for Z2+offload (#6550)
Browse files Browse the repository at this point in the history
The ZeRO 1/2 optimizer performs incorrect gradient accumulation in the
path for ZeRO2 + Offloading. This issue is caused by two main reasons:

1) The micro_step_id in the ZeRO 1/2 optimizer is:

- Initialized to 0 in the constructor.
- Reset to -1 during the backward pass.

For example, given a gradient accumulation step of 4, the micro_step_id
changes as follows:

- For the first global step: 1, 2, 3, 4.
- Subsequently: 0, 1, 2, 3.

2) Gradients are copied to the buffer on the first micro step and
accumulated in the buffer during the following micro steps. However, the
current code incorrectly copies gradients at steps that are not at the
accumulation boundary.

This PR aligns the micro_step_id initialization in both the constructor
and the backward pass, and corrects the condition for copying and
accumulating gradients.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 26, 2024
1 parent 0fbe96a commit c85c870
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients'
OPTIMIZER_STEP_TIMER = 'optimizer_step'
OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER]
INITIAL_MICRO_STEP_ID = -1


def input(msg):
Expand Down Expand Up @@ -224,7 +225,7 @@ def __init__(self,
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.micro_step_id = INITIAL_MICRO_STEP_ID
self.ignore_unused_parameters = ignore_unused_parameters
self.round_robin_gradients = round_robin_gradients

Expand Down Expand Up @@ -1231,9 +1232,7 @@ def copy_gradients_to_cpu():

if self.micro_step_id > 0:
accumulate_gradients()

# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
else:
copy_gradients_to_cpu()

def set_norm_for_param_grad(self, param):
Expand Down Expand Up @@ -1824,7 +1823,7 @@ def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
self.micro_step_id = INITIAL_MICRO_STEP_ID

see_memory_usage(f"In step before checking overflow")

Expand Down

0 comments on commit c85c870

Please sign in to comment.