Avoid graph breaks in torch.compile caused by inner classes in the backward hooks (#7062)

This PR is part of the effort to improve DeepSpeed performance when using torch.compile.

There is a known [bug](pytorch/pytorch#128942) in torch.compile that causes a graph break when an inner class is defined within a method being compiled. When this happens, the following appears in the log:

`[__graph_breaks] torch._dynamo.exc.Unsupported: missing: LOAD_BUILD_CLASS`

This is the case for the inner classes `PreBackwardFunctionForModule` and `PostBackwardFunctionModule`, which are defined inside the backward module hooks.
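
For illustration only (a hypothetical toy function, not DeepSpeed code): on a PyTorch build affected by the issue, this kind of break can be surfaced with `torch._dynamo.explain`, which reports graph breaks and their reasons.

```python
import torch
import torch._dynamo as dynamo


def scale_with_inner_class(x):
    # The class body inside the traced function is executed via LOAD_BUILD_CLASS,
    # which Dynamo does not support, so the graph is split here.
    class Helper:
        factor = 2

    return x * Helper.factor


# explain() runs Dynamo once and summarizes any graph breaks it encountered.
report = dynamo.explain(scale_with_inner_class)(torch.ones(3))
print(report.graph_break_count)
print(report.break_reasons)
```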

While there is an open PyTorch [PR #133805](pytorch/pytorch#133805) that addresses this, we can solve the issue on the DeepSpeed side by moving the inner classes into the initialization code.

With this change, the graph breaks and the corresponding log messages no longer occur.
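
A minimal sketch of the pattern applied here (hypothetical names, not the actual DeepSpeed hooks; assumes a recent PyTorch that can trace simple custom autograd Functions): the `torch.autograd.Function` subclass is created once, outside anything torch.compile traces, and the traced code only calls its `apply`.

```python
import torch


# Built once at initialization time, outside any compiled region, so no
# LOAD_BUILD_CLASS bytecode runs while Dynamo is tracing.
class _Doubler(torch.autograd.Function):

    @staticmethod
    def forward(ctx, t):
        return t * 2

    @staticmethod
    def backward(ctx, grad):
        return grad * 2


def hook_like_function(x):
    # The compiled code only references the prebuilt class; it no longer defines it.
    return _Doubler.apply(x)


compiled = torch.compile(hook_like_function)
out = compiled(torch.randn(4, requires_grad=True))
out.sum().backward()
```

This mirrors what the diff below does: `module.pre_bwd_fn` and `module.post_bwd_fn` are assigned classes that are built during hook registration instead of being defined inside the hooks themselves.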

---------

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: shaomin <wukon1992@gmail.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: siqi <siqi@tecorigin.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Wei Wu <wuwei211x@gmail.com>
Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: wukong1992 <wukong1992@users.noreply.github.com>
Co-authored-by: shaomin <wukon1992@gmail.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: loadams <loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: siqi654321 <siqi202311@163.com>
Co-authored-by: siqi <siqi@tecorigin.com>
Co-authored-by: Wei Wu <45323446+U-rara@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Shelly Nahir <73890534+ShellyNR@users.noreply.github.com>
Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Yejing-Lai <yejing.lai@intel.com>
16 people authored Mar 4, 2025
1 parent a88f56a commit 776822f
Showing 1 changed file with 68 additions and 68 deletions.
136 changes: 68 additions & 68 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -341,39 +341,6 @@ def _bwd_hook_unexpected_inputs_msg(value):

        def _pre_backward_module_hook(module, inputs, output):

            if not hasattr(module, "pre_bwd_fn"):

                @instrument_w_nvtx
                def _run_before_backward_function(sub_module):
                    # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
                    # before doing backwards, so each backward will need a pre-fetch - using reference
                    # counting to support this scenario
                    #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
                    if sub_module.applied_pre_backward_ref_cnt > 0:
                        self.pre_sub_module_backward_function(sub_module)
                        sub_module.applied_pre_backward_ref_cnt -= 1
                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

                class PreBackwardFunctionForModule(torch.autograd.Function):

                    @staticmethod
                    def forward(ctx, outputs):
                        # Capture `module` and _run_before_backward_function
                        ctx.module = module
                        ctx.pre_backward_function = _run_before_backward_function
                        if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                            ctx.module.applied_pre_backward_ref_cnt = 0
                        ctx.module.applied_pre_backward_ref_cnt += 1
                        outputs = outputs.detach()
                        return outputs

                    @staticmethod
                    def backward(ctx, *args):
                        ctx.pre_backward_function(ctx.module)
                        return args

                module.pre_bwd_fn = PreBackwardFunctionForModule

            return apply_to_tensors_only(module.pre_bwd_fn.apply,
                                         output,
                                         warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -402,41 +369,6 @@ def _post_backward_module_hook(module, inputs):
if not hasattr(module, "ds_grads_remaining"):
module.ds_grads_remaining = 0

if not hasattr(module, "post_bwd_fn"):

@instrument_w_nvtx
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.post_backward_function(ctx.module)
return args

module.post_bwd_fn = PostBackwardFunctionModule

return apply_to_tensors_only(module.post_bwd_fn.apply,
inputs,
warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -448,9 +380,77 @@ def backward(ctx, *args):
        self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

        # Pre backward hook
        if not hasattr(module, "pre_bwd_fn"):

            @instrument_w_nvtx
            def _run_before_backward_function(sub_module):
                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
                # before doing backwards, so each backward will need a pre-fetch - using reference
                # counting to support this scenario
                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
                if sub_module.applied_pre_backward_ref_cnt > 0:
                    self.pre_sub_module_backward_function(sub_module)
                    sub_module.applied_pre_backward_ref_cnt -= 1
                #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

            class PreBackwardFunctionForModule(torch.autograd.Function):

                @staticmethod
                def forward(ctx, outputs):
                    # Capture `module` and _run_before_backward_function
                    ctx.module = module
                    ctx.pre_backward_function = _run_before_backward_function
                    if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
                        ctx.module.applied_pre_backward_ref_cnt = 0
                    ctx.module.applied_pre_backward_ref_cnt += 1
                    outputs = outputs.detach()
                    return outputs

                @staticmethod
                def backward(ctx, *args):
                    ctx.pre_backward_function(ctx.module)
                    return args

            module.pre_bwd_fn = PreBackwardFunctionForModule

        self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

        # post backward hook
        if not hasattr(module, "post_bwd_fn"):

            @instrument_w_nvtx
            def _run_after_backward_function(sub_module):
                if sub_module.ds_grads_remaining == 0:
                    self.post_sub_module_backward_function(sub_module)

            class PostBackwardFunctionModule(torch.autograd.Function):

                @staticmethod
                def forward(ctx, output):
                    ctx.module = module
                    if output.requires_grad:
                        #TODO SOME TIMES post backward does not seem to be triggered debug in detail
                        #Should only cause increase in memory not correctness issue
                        #if output.grad_fn.__class__.__name__ == 'ViewBackward':
                        # ctx.view=True
                        # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
                        #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
                        #if module.ds_grads_remaining == 0:
                        # print(f"Before Forward: {ctx.module.__class__.__name__}")
                        module.ds_grads_remaining += 1
                        ctx.post_backward_function = _run_after_backward_function
                    output = output.detach()
                    return output

                @staticmethod
                def backward(ctx, *args):
                    ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
                    if ctx.module.ds_grads_remaining == 0:
                        ctx.post_backward_function(ctx.module)
                    return args

            module.post_bwd_fn = PostBackwardFunctionModule

        self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

@torch.no_grad()
