Avoid graph breaks in torch.compile caused by inner classes in the backward hooks (deepspeedai#7062)
This PR is part of the effort to improve DeepSpeed performance when
using torch.compile.

There is a known [bug](pytorch/pytorch#128942) in torch.compile that
causes a graph break when an inner class is defined inside a method
being compiled. The following then appears in the log:

`[__graph_breaks] torch._dynamo.exc.Unsupported: missing: LOAD_BUILD_CLASS`

This is the case with the inner classes `PreBackwardFunctionForModule`
and `PostBackwardFunctionModule`.
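For illustration, a minimal sketch of the pattern that triggers the
break; the function and class names below are hypothetical and not
taken from DeepSpeed:

```python
import torch

def forward_with_inner_class(x):
    # Defining a class inside a function that Dynamo traces emits the
    # LOAD_BUILD_CLASS bytecode, which torch.compile cannot handle, so it
    # falls back to eager for that part of the frame (a graph break).
    class Double(torch.autograd.Function):

        @staticmethod
        def forward(ctx, t):
            return t * 2

        @staticmethod
        def backward(ctx, grad):
            return grad * 2

    return Double.apply(x)

compiled_fn = torch.compile(forward_with_inner_class)
out = compiled_fn(torch.randn(4, requires_grad=True))
# Running with TORCH_LOGS="graph_breaks" surfaces the message quoted above.
```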

While there is an open PyTorch [PR#133805](pytorch/pytorch#133805)
for this, we can solve the issue by moving the inner classes into the
initialization code.
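Condensed into a sketch, the resulting pattern looks roughly like this
(the helper name `setup_pre_backward_hook` and the single-tensor output
are simplifying assumptions; the actual code in the diff below also
routes through `apply_to_tensors_only` and keeps a reference count for
repeated forwards):

```python
import torch

def setup_pre_backward_hook(module, pre_sub_module_backward_function):
    # Build the autograd.Function once, at hook-registration time, and cache
    # it on the module. The hook body that torch.compile later traces then
    # contains no class definition, hence no LOAD_BUILD_CLASS graph break.
    if not hasattr(module, "pre_bwd_fn"):

        class PreBackwardFunctionForModule(torch.autograd.Function):

            @staticmethod
            def forward(ctx, outputs):
                ctx.module = module
                return outputs.detach()

            @staticmethod
            def backward(ctx, *args):
                pre_sub_module_backward_function(ctx.module)
                return args

        module.pre_bwd_fn = PreBackwardFunctionForModule

    def _pre_backward_module_hook(module, inputs, output):
        # Only the cached apply() runs here; no inner class is defined.
        return module.pre_bwd_fn.apply(output)

    module.register_forward_hook(_pre_backward_module_hook)
```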

With this change, no graph breaks occur and the corresponding log
messages are no longer produced.
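To double-check, one can rerun with `TORCH_LOGS="graph_breaks"` and
confirm the message is gone, or use `torch._dynamo.explain`. A sketch,
assuming a recent PyTorch where `explain` returns an object exposing
`graph_break_count` and `break_reasons` (the wrapper name and inputs
are placeholders):

```python
import torch

def report_graph_breaks(fn, *example_inputs):
    # torch._dynamo.explain traces fn on the example inputs and reports
    # where Dynamo would break the graph, without scanning training logs.
    explanation = torch._dynamo.explain(fn)(*example_inputs)
    print(f"graph break count: {explanation.graph_break_count}")
    for reason in explanation.break_reasons:
        print(reason)
```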

---------

Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: shaomin <wukon1992@gmail.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Signed-off-by: siqi <siqi@tecorigin.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: Wei Wu <wuwei211x@gmail.com>
Signed-off-by: ShellyNR <shelly.nahir@live.biu.ac.il>
Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: wukong1992 <wukong1992@users.noreply.github.com>
Co-authored-by: shaomin <wukon1992@gmail.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: loadams <loadams@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: siqi654321 <siqi202311@163.com>
Co-authored-by: siqi <siqi@tecorigin.com>
Co-authored-by: Wei Wu <45323446+U-rara@users.noreply.github.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: Shelly Nahir <73890534+ShellyNR@users.noreply.github.com>
Co-authored-by: snahir <snahir@habana.ai>
Co-authored-by: Yejing-Lai <yejing.lai@intel.com>
Signed-off-by: yisheng <yi.sheng@intel.com>
16 people authored and ys950902 committed Mar 6, 2025
1 parent ae6d1cb commit 2cb0224
Showing 1 changed file with 68 additions and 68 deletions.
136 changes: 68 additions & 68 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -341,39 +341,6 @@ def _bwd_hook_unexpected_inputs_msg(value):

def _pre_backward_module_hook(module, inputs, output):

if not hasattr(module, "pre_bwd_fn"):

@instrument_w_nvtx
def _run_before_backward_function(sub_module):
# some models (e.g. Albert) may run multiple forwards on the same layer in a loop
# before doing backwards, so each backward will need a pre-fetch - using reference
# counting to support this scenario
#print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
if sub_module.applied_pre_backward_ref_cnt > 0:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(ctx, outputs):
# Capture `module` and _run_before_backward_function
ctx.module = module
ctx.pre_backward_function = _run_before_backward_function
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
ctx.module.applied_pre_backward_ref_cnt = 0
ctx.module.applied_pre_backward_ref_cnt += 1
outputs = outputs.detach()
return outputs

@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return args

module.pre_bwd_fn = PreBackwardFunctionForModule

return apply_to_tensors_only(module.pre_bwd_fn.apply,
output,
warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -402,41 +369,6 @@ def _post_backward_module_hook(module, inputs):
if not hasattr(module, "ds_grads_remaining"):
module.ds_grads_remaining = 0

if not hasattr(module, "post_bwd_fn"):

@instrument_w_nvtx
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.post_backward_function(ctx.module)
return args

module.post_bwd_fn = PostBackwardFunctionModule

return apply_to_tensors_only(module.post_bwd_fn.apply,
inputs,
warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
@@ -448,9 +380,77 @@ def backward(ctx, *args):
self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

# Pre backward hook
if not hasattr(module, "pre_bwd_fn"):

@instrument_w_nvtx
def _run_before_backward_function(sub_module):
# some models (e.g. Albert) may run multiple forwards on the same layer in a loop
# before doing backwards, so each backward will need a pre-fetch - using reference
# counting to support this scenario
#print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
if sub_module.applied_pre_backward_ref_cnt > 0:
self.pre_sub_module_backward_function(sub_module)
sub_module.applied_pre_backward_ref_cnt -= 1
#print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(ctx, outputs):
# Capture `module` and _run_before_backward_function
ctx.module = module
ctx.pre_backward_function = _run_before_backward_function
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
ctx.module.applied_pre_backward_ref_cnt = 0
ctx.module.applied_pre_backward_ref_cnt += 1
outputs = outputs.detach()
return outputs

@staticmethod
def backward(ctx, *args):
ctx.pre_backward_function(ctx.module)
return args

module.pre_bwd_fn = PreBackwardFunctionForModule

self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

# post backward hook
if not hasattr(module, "post_bwd_fn"):

@instrument_w_nvtx
def _run_after_backward_function(sub_module):
if sub_module.ds_grads_remaining == 0:
self.post_sub_module_backward_function(sub_module)

class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
ctx.module = module
if output.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
# ctx.view=True
# print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
#assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
#if module.ds_grads_remaining == 0:
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
if ctx.module.ds_grads_remaining == 0:
ctx.post_backward_function(ctx.module)
return args

module.post_bwd_fn = PostBackwardFunctionModule

self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))

@torch.no_grad()
