Avoid graph breaks in torch.compile caused by inner classes in the backward hooks #7062
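For context: the diff below moves the `torch.autograd.Function` subclasses (`PreBackwardFunctionForModule`, `PostBackwardFunctionModule`) out of the hook bodies and into hook registration, so `torch.compile` never traces a class definition while compiling the forward pass; the traced hook only calls the already-created `module.pre_bwd_fn.apply` / `module.post_bwd_fn.apply`. Below is a minimal sketch of that hoisting pattern (simplified, not DeepSpeed's exact code; `register_pre_backward_hook` and `pre_backward_fn` are illustrative names):

```python
# Minimal sketch of the hoisting pattern (illustrative, not DeepSpeed's exact code).
import torch
import torch.nn as nn


def register_pre_backward_hook(module: nn.Module, pre_backward_fn):
    """Create the per-module autograd.Function once, at registration time,
    so the hook body traced by torch.compile contains no class definition."""

    class PreBackwardFunction(torch.autograd.Function):

        @staticmethod
        def forward(ctx, output):
            ctx.module = module
            return output.detach()

        @staticmethod
        def backward(ctx, *grads):
            pre_backward_fn(ctx.module)  # e.g. prefetch partitioned parameters
            return grads

    module.pre_bwd_fn = PreBackwardFunction

    def hook(mod, inputs, output):
        # The compiled region only sees an .apply() call on an existing class.
        return mod.pre_bwd_fn.apply(output)

    return module.register_forward_hook(hook)
```

The PR applies the same hoisting to the post-backward hook, as the diff below shows.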

Merged
Changes from all commits (22 commits)
07bd88e
Avoid graph breaks caused by inner classes in the backward hooks
deepcharm Feb 20, 2025
3205186
Rename aio_thread_count to intra_op_parallelism (#7056)
tjruwase Feb 19, 2025
52fbbc1
add autoTP training zero2 tests (#7049)
inkcherry Feb 19, 2025
8b5d864
Fix, bf16 optimizer remove dup loop (#7054)
wukong1992 Feb 20, 2025
e7e9d3e
Update version.txt after 0.16.4 release (#7063)
loadams Feb 20, 2025
2c950ee
fix an outdated doc wrt CUDA_VISIBLE_DEVICES (#7058)
stas00 Feb 20, 2025
a4853fc
Tecorigin sdaa accelerator (#6903)
siqi654321 Feb 20, 2025
7d6fc03
Handle special case of libuv for Windows (#7064)
loadams Feb 20, 2025
d39269d
Update README with info on newest accelerator (#7065)
loadams Feb 21, 2025
3cd7a5c
Bug Fix for offload_states API (#7050)
U-rara Feb 21, 2025
3767709
Fix TOCTOU issues, switch to fstat (#7067)
loadams Feb 24, 2025
495606a
config torch to avoid graph breaks caused by logger (#6999)
ShellyNR Feb 24, 2025
20e8509
Fix meta load tensor imcompatible issue (#7073)
Yejing-Lai Feb 24, 2025
ad7b43b
Replace calls to `python setup.py sdist` with `python -m build --sdis…
loadams Feb 24, 2025
672b918
Revert "Handle special case of libuv for Windows (#7064)" (#7076)
loadams Feb 25, 2025
929a09c
Handle special case of libuv for Windows (#7064)
loadams Feb 20, 2025
623f7cf
Revert "Handle special case of libuv for Windows (#7064)" (#7076)
loadams Feb 25, 2025
fd94138
Add DeepseekV3 AutoTP. (#7045)
Yejing-Lai Feb 26, 2025
7357edc
Improve inference tutorial docs (#7083)
loadams Feb 26, 2025
e7b622b
Merge branch 'master' into avoid-graph-break-caused-by-inner-classes
tjruwase Feb 27, 2025
6284374
Merge branch 'master' into avoid-graph-break-caused-by-inner-classes
loadams Mar 3, 2025
a5bfd57
Merge branch 'master' into avoid-graph-break-caused-by-inner-classes
loadams Mar 3, 2025
136 changes: 68 additions & 68 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -341,39 +341,6 @@ def _bwd_hook_unexpected_inputs_msg(value):

def _pre_backward_module_hook(module, inputs, output):

if not hasattr(module, "pre_bwd_fn"):

@instrument_w_nvtx
def _run_before_backward_function(sub_module):
# some models (e.g. Albert) may run multiple forwards on the same layer in a loop
# before doing backwards, so each backward will need a pre-fetch - using reference
# counting to support this scenario
#print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
if sub_module.applied_pre_backward_ref_cnt > 0:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(ctx, outputs):
# Capture `module` and _run_before_backward_function
ctx.module = module
ctx.pre_backward_function = _run_before_backward_function
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
ctx.module.applied_pre_backward_ref_cnt = 0
ctx.module.applied_pre_backward_ref_cnt += 1
outputs = outputs.detach()
return outputs

@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return args

module.pre_bwd_fn = PreBackwardFunctionForModule

return apply_to_tensors_only(module.pre_bwd_fn.apply,
output,
warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -402,41 +369,6 @@ def _post_backward_module_hook(module, inputs):
if not hasattr(module, "ds_grads_remaining"):
module.ds_grads_remaining = 0

if not hasattr(module, "post_bwd_fn"):

@instrument_w_nvtx
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.post_backward_function(ctx.module)
return args

module.post_bwd_fn = PostBackwardFunctionModule

return apply_to_tensors_only(module.post_bwd_fn.apply,
inputs,
warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -448,9 +380,77 @@ def backward(ctx, *args):
self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

# Pre backward hook
if not hasattr(module, "pre_bwd_fn"):

@instrument_w_nvtx
def _run_before_backward_function(sub_module):
# some models (e.g. Albert) may run multiple forwards on the same layer in a loop
# before doing backwards, so each backward will need a pre-fetch - using reference
# counting to support this scenario
#print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
if sub_module.applied_pre_backward_ref_cnt > 0:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(ctx, outputs):
# Capture `module` and _run_before_backward_function
ctx.module = module
ctx.pre_backward_function = _run_before_backward_function
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
ctx.module.applied_pre_backward_ref_cnt = 0
ctx.module.applied_pre_backward_ref_cnt += 1
outputs = outputs.detach()
return outputs

@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return args

module.pre_bwd_fn = PreBackwardFunctionForModule

self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

# post backward hook
if not hasattr(module, "post_bwd_fn"):

@instrument_w_nvtx
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.post_backward_function(ctx.module)
return args

module.post_bwd_fn = PostBackwardFunctionModule

self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

@torch.no_grad()
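A hedged way to observe the graph breaks caused by the old pattern (not part of the PR; it relies on the internal `torch._dynamo.explain` API available in PyTorch 2.x, whose output fields may change between releases):

```python
# Hypothetical check, not part of the PR. Assumes PyTorch 2.x where the internal
# torch._dynamo.explain API and its graph_break_count field are available.
import torch
import torch.nn as nn


class Tiny(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


def hook_with_inner_class(mod, inputs, output):
    # Old pattern: a fresh autograd.Function subclass is created inside the hook,
    # which Dynamo cannot trace through.
    class Inner(torch.autograd.Function):

        @staticmethod
        def forward(ctx, t):
            return t.detach()

        @staticmethod
        def backward(ctx, *grads):
            return grads

    return Inner.apply(output)


model = Tiny()
model.register_forward_hook(hook_with_inner_class)

explanation = torch._dynamo.explain(model)(torch.randn(2, 8))
print(explanation.graph_break_count)  # expected to be non-zero with the inner-class hook
```

With the hoisted registration shown in the diff, the hook body traced by Dynamo reduces to a single `.apply()` call, so this count is expected to drop.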