Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradient accumulation for Z2+offload #6550

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients'
OPTIMIZER_STEP_TIMER = 'optimizer_step'
OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER]
INITIAL_MICRO_STEP_ID = -1


def input(msg):
Expand Down Expand Up @@ -224,7 +225,7 @@ def __init__(self,
self.gradient_predivide_factor = gradient_predivide_factor
self.postscale_gradients = postscale_gradients
self.gradient_accumulation_steps = gradient_accumulation_steps
self.micro_step_id = 0
self.micro_step_id = INITIAL_MICRO_STEP_ID
self.ignore_unused_parameters = ignore_unused_parameters
self.round_robin_gradients = round_robin_gradients

Expand Down Expand Up @@ -1231,9 +1232,7 @@ def copy_gradients_to_cpu():

if self.micro_step_id > 0:
accumulate_gradients()

# at the boundary we will send 32bit directly
if not self.is_gradient_accumulation_boundary:
else:
copy_gradients_to_cpu()

def set_norm_for_param_grad(self, param):
Expand Down Expand Up @@ -1824,7 +1823,7 @@ def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = -1
self.micro_step_id = INITIAL_MICRO_STEP_ID

see_memory_usage(f"In step before checking overflow")

Expand Down
Loading