Training multiple models #7018

Open: this pull request proposes to merge 32 commits into master from olruwase/zero_multi_models. The diff below shows changes from 10 of the 32 commits.

Commits:
f488b46  Support multiple engines (tjruwase, Sep 16, 2024)
6564884  Use module backward prehook (tjruwase, Sep 17, 2024)
7745ae5  Remove pdb (tjruwase, Sep 20, 2024)
a4ea120  Remove dead code (tjruwase, Sep 20, 2024)
3a7a94f  Add module forward hooks (tjruwase, Oct 2, 2024)
1ad8276  Rebase branch (tjruwase, Feb 8, 2025)
1e2595f  Formatting (tjruwase, Feb 8, 2025)
4abfd9f  Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase… (tjruwase, Feb 8, 2025)
204c4dd  Cleanup (tjruwase, Feb 8, 2025)
78e1915  Cleanup (tjruwase, Feb 8, 2025)
16d60bc  Bug fix (tjruwase, Feb 8, 2025)
02477ce  Prepare gradient handling in zero stage 1 & 2 (tjruwase, Feb 8, 2025)
2f84032  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Feb 11, 2025)
5ded3a9  Merge branch 'master' into olruwase/zero_multi_models (loadams, Feb 19, 2025)
69c1489  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Feb 23, 2025)
3b86860  Add unit tests (tjruwase, Feb 25, 2025)
7edabdb  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Feb 25, 2025)
a7744da  Formatting (tjruwase, Feb 25, 2025)
b5c556d  Fix CI failures due to curriculum learning (tjruwase, Feb 26, 2025)
c20e393  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Feb 26, 2025)
19e9c1d  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Feb 28, 2025)
51f7bf6  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Mar 4, 2025)
2d267f8  Update deepspeed/runtime/engine.py (tjruwase, Mar 5, 2025)
51de671  Update deepspeed/runtime/engine.py (tjruwase, Mar 5, 2025)
acc22ee  Multiple models with indepdent loss (legacy case) (tjruwase, Mar 6, 2025)
86f08f8  Merge branch 'olruwase/zero_multi_models' of github.com:microsoft/Dee… (tjruwase, Mar 6, 2025)
e5a4958  Update UT and docs (tjruwase, Mar 6, 2025)
82372c3  Tweak RTD (tjruwase, Mar 6, 2025)
f50f5e6  Tweak RTD (tjruwase, Mar 6, 2025)
e7fe814  Merge branch 'master' into olruwase/zero_multi_models (loadams, Mar 6, 2025)
f7792f9  Merge branch 'olruwase/zero_multi_models' of github.com:microsoft/Dee… (tjruwase, Mar 8, 2025)
3002fce  Merge branch 'master' into olruwase/zero_multi_models (tjruwase, Mar 8, 2025)
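The PR title and commit history ("Support multiple engines", "Multiple models with indepdent loss (legacy case)") indicate support for training several models, each wrapped in its own DeepSpeed engine, within one script. The following is a minimal, hypothetical sketch of that usage pattern; the models, losses, and config values are illustrative assumptions and are not taken from this PR, while deepspeed.initialize, engine.backward, and engine.step are existing DeepSpeed APIs.

import torch
import torch.nn.functional as F
import deepspeed

# Illustrative config; the exact settings are not prescribed by this PR.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
}

# Placeholder models standing in for two independently trained networks.
model_a = torch.nn.Linear(512, 512)
model_b = torch.nn.Linear(512, 10)

engine_a, _, _, _ = deepspeed.initialize(model=model_a,
                                         model_parameters=model_a.parameters(),
                                         config=ds_config)
engine_b, _, _, _ = deepspeed.initialize(model=model_b,
                                         model_parameters=model_b.parameters(),
                                         config=ds_config)

def train_step(batch):
    # Each engine owns its own forward pass, loss, backward pass, and step.
    x = batch["x"].to(engine_a.device, dtype=torch.bfloat16)
    y = batch["y"].to(engine_b.device)

    loss_a = engine_a(x).float().pow(2).mean()   # stand-in loss for model A
    engine_a.backward(loss_a)
    engine_a.step()

    loss_b = F.cross_entropy(engine_b(x).float(), y)
    engine_b.backward(loss_b)
    engine_b.step()

The file-level diffs below restructure the bf16 optimizer and the engine's forward/backward paths so that this kind of multi-engine training composes cleanly.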
19 changes: 11 additions & 8 deletions deepspeed/runtime/bf16_optimizer.py
@@ -44,7 +44,7 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False,
immediate_grad_update=True,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
@@ -313,7 +313,7 @@ def step(self, closure=None):

self.clear_hp_grads()

def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
"""Perform a backward pass and copy the low-precision gradients to the
high-precision copy.

@@ -323,7 +323,7 @@ def backward(self, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
The low-precision grads are deallocated during this procedure.
"""
self.clear_lp_grads()
loss.backward(**bwd_kwargs)
loss.backward(retain_graph=retain_graph, **bwd_kwargs)

if update_hp_grads:
self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@@ -425,9 +425,6 @@ def update_lp_params(self):
fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
partition_id = dist.get_rank(group=self.real_dp_process_group[i])
bf16_partitions[partition_id].data.copy_(fp32_partition.data)
# print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
# if i == 0:
# print_rank_0(f'{fp32_partition[:10]=}', force=True)

all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
partitioned_param_groups=self.bf16_partitioned_groups,
@@ -442,10 +439,12 @@ def clear_hp_grads(self):
for i, group in enumerate(self.fp32_groups_gradients):
self.fp32_groups_has_gradients[i] = [False] * len(group)

def clear_lp_grads(self):
def clear_lp_grads(self, set_to_none=False):

# using zero_() fixed memory address for graph replay
set_to_none = False if self.graph_harvesting else True
if self.graph_harvesting:
assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"

zero_grads_list = []
for group in self.bf16_groups:
for param in group:
@@ -458,6 +457,10 @@ def clear_lp_grads(self):
if not set_to_none and len(zero_grads_list) > 0:
torch._foreach_zero_(zero_grads_list)

def zero_grad(self, set_to_none=True):
self.clear_lp_grads(set_to_none)
self.clear_hp_grads()

def state_dict(self):
state_dict = {}
state_dict[CLIP_GRAD] = self.clip_grad
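A note on the clear_lp_grads change above: when graph_harvesting (CUDA graph capture/replay) is enabled, gradients must be zeroed in place so their memory addresses stay fixed across replays, rather than being set to None. Below is a small standalone sketch of the two clearing strategies; clear_grads is an illustrative helper, not a DeepSpeed API, while torch._foreach_zero_ is the same fused primitive used in the diff.

import torch

def clear_grads(params, set_to_none: bool) -> None:
    # Either drop gradients (frees memory) or zero them in place
    # (keeps tensor addresses stable, which CUDA graph replay requires).
    to_zero = []
    for p in params:
        if p.grad is None:
            continue
        if set_to_none:
            p.grad = None
        else:
            to_zero.append(p.grad)
    if to_zero:
        torch._foreach_zero_(to_zero)   # one fused kernel over all grads

model = torch.nn.Linear(8, 8)
model(torch.randn(2, 8)).sum().backward()
clear_grads(model.parameters(), set_to_none=False)  # graph-replay friendly
clear_grads(model.parameters(), set_to_none=True)   # releases grad memory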
2 changes: 1 addition & 1 deletion deepspeed/runtime/constants.py
@@ -128,7 +128,7 @@

# BFLOAT16 optimizer immediate gradient update
BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True

#########################################
# FP16 support
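The default for immediate_grad_update flips from False to True here. Users who want the previous behavior can presumably set the flag explicitly in the DeepSpeed config; the sketch below assumes the flag lives in the bf16 section, matching the BFLOAT16_* constant prefix, so verify against the config docs for your DeepSpeed version.

# Assumed placement of the flag (bf16 section); illustrative values only.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "bf16": {
        "enabled": True,
        "immediate_grad_update": False,  # opt back into the old default
    },
}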
186 changes: 107 additions & 79 deletions deepspeed/runtime/engine.py
@@ -272,6 +272,10 @@ def __init__(self,
# Configure distributed model
self._configure_distributed_model(model)

self.module_forward_pre_hook = self._create_module_forward_pre_hook()
self.module_forward_post_hook = self._create_module_forward_post_hook()
self.module_backward_pre_hook = self._create_module_backward_pre_hook()

# needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
self.param_names = {param: name for name, param in model.named_parameters()}

@@ -1870,7 +1874,6 @@ def deepspeed_io(self,
GLOBAL_RANK: self.global_rank,
DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
}

return DeepSpeedDataLoader(dataset=dataset,
batch_size=batch_size,
pin_memory=pin_memory,
@@ -1917,17 +1920,30 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):

return scaled_loss

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
def _create_module_backward_pre_hook(self):

if self.autotuning_profile_model_info():
ma = get_ma_status()
else:
def _module_backward_hook(module, grad_output):
if hasattr(self.optimizer, 'backward_prologue'):
self.optimizer.backward_prologue()

return self.module.register_full_backward_pre_hook(_module_backward_hook)

def _create_module_forward_pre_hook(self):

def _module_forward_pre_hook(module, inputs):
self._forward_prologue(inputs)

return self.module.register_forward_pre_hook(_module_forward_pre_hook)

def _create_module_forward_post_hook(self):

def _module_forward_post_hook(module, input, output):
self._forward_epilogue()

return self.module.register_forward_hook(_module_forward_post_hook)

def _forward_prologue(self, inputs, kwargs=None):
if not self.autotuning_profile_model_info():
see_memory_usage("Engine before forward", force=self.memory_breakdown())

flops_profiler_active = (self.flops_profiler_enabled()
@@ -1950,54 +1966,71 @@ def forward(self, *inputs, **kwargs):
if flops_profiler_active:
self.flops_profiler.start_profile(ignore_list=None)

if self.module.training:
if self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())
if kwargs is not None:
if self.module.training:
if self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())

if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
# it's difficult to inject argument in forward pass.
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
if self.__class__.__name__ != "PipelineEngine":
# TODO: The above if condition is a HACK since for PipelineEngine
# it's difficult to inject argument in forward pass.
if self.module.training and self.curriculum_enabled_legacy():
self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})

if self.module.training and self.random_ltd_enabled():
self.random_ltd_scheduler.update_seq(self.global_steps)

if self.training_dataloader is None:
self.tput_timer.start()

self._start_timers(self.engine_timers.forward_timers)

if self.zero_optimization_partition_weights():
# Enable automated discovery of external parameters by indicating that
# we are in a forward pass.
for module in self.module.modules():
module._parameters._in_forward = True

self._start_timers(self.engine_timers.forward_timers)

if self.training_dataloader is None:
self.tput_timer.start()

if self.fp16_auto_cast():
inputs = self._cast_inputs_half(inputs)

loss = self.module(*inputs, **kwargs)

def _forward_epilogue(self):
if self.zero_optimization_partition_weights():
# Disable automated discovery of external parameters
for module in self.module.modules():
module._parameters._in_forward = False

self._stop_timers(self.engine_timers.forward_timers)

flops_profiler_active = (self.flops_profiler_enabled()
and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)

if flops_profiler_active:
self.flops_profiler.stop_profile()

if not self.autotuning_profile_model_info():
see_memory_usage("Engine after forward", force=self.memory_breakdown())

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation
Arguments:
*inputs: Variable length input list
**kwargs: variable length keyword arguments
"""
if self.autotuning_profile_model_info():
ma = get_ma_status()

loss = self.module(*inputs, **kwargs)

if self.autotuning_profile_model_info():
activation_mem = get_ma_status() - ma
self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
exit()
else:
see_memory_usage("Engine after forward", force=self.memory_breakdown())

return loss

def _cast_inputs_half(self, inputs):
@@ -2056,43 +2089,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
grads = None
self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

@contextmanager
def no_sync(self):
r"""
Context manager to disable gradient reduction during backward pass.
This context manager has the following effects on other DeepSpeed features.
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
2. It is illegal to call engine.step() within the context manager.
3. Tracking of gradient accumulation steps is disabled.
"""
assert not self.zero_optimization_partition_gradients(), \
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

self.inside_no_sync_ctxt = True
try:
yield
finally:
self.inside_no_sync_ctxt = False

@instrument_w_nvtx
def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""

def _backward_prologue(self, loss):
see_memory_usage("Engine before backward", force=self.memory_breakdown())

if self.scale_wrt_gas is not None:
scale_wrt_gas = self.scale_wrt_gas

do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt

# scale loss w.r.t. gradient accumulation if reduction is not disabled
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())

@@ -2109,13 +2112,18 @@ def backward(self, release_loss=False, retain_graph=False, scale_wrt_gas=T
)]
self.monitor.write_events(self.summary_events)

self._start_timers(self.engine_timers.backward_timers)
return loss

assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use backward"
def _backward_epilogue(self):
self._start_timers(self.engine_timers.backward_reduce_timers)
if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
self._stop_timers(self.engine_timers.backward_reduce_timers)
see_memory_usage("Engine after backward", force=self.memory_breakdown())

def _do_optimizer_backward(self, loss, retain_graph):
self._start_timers(self.engine_timers.backward_inner_timers)

if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
self.optimizer.backward(loss, retain_graph=retain_graph)
@@ -2131,30 +2139,50 @@ def backward(self, release_loss=False, retain_graph=False, scale_wrt_gas=T
else:
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.bfloat16_enabled():
self.optimizer.backward(loss)
self.optimizer.backward(loss, retain_graph=retain_graph)
else:
if self.eigenvalue_enabled():
loss.backward(create_graph=True, retain_graph=True)
else:
loss.backward(retain_graph=retain_graph)

self._stop_timers(self.engine_timers.backward_inner_timers)

self._start_timers(self.engine_timers.backward_reduce_timers)

if do_gradient_reduction:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
@contextmanager
def no_sync(self):
r"""
Context manager to disable gradient reduction during backward pass.
This context manager has the following effects on other DeepSpeed features.
1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
2. It is illegal to call engine.step() within the context manager.
3. Tracking of gradient accumulation steps is disabled.
"""
assert not self.zero_optimization_partition_gradients(), \
f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

self._stop_timers(self.engine_timers.backward_reduce_timers)
assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

self._stop_timers(self.engine_timers.backward_timers)
self.inside_no_sync_ctxt = True
try:
yield
finally:
self.inside_no_sync_ctxt = False

if release_loss:
# loss.data = None
pass
@instrument_w_nvtx
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""
assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
"must provide optimizer during init in order to use backward"

see_memory_usage("Engine after backward", force=self.memory_breakdown())
self._start_timers(self.engine_timers.backward_timers)
loss = self._backward_prologue(loss)
self._do_optimizer_backward(loss, retain_graph)
self._backward_epilogue()
self._stop_timers(self.engine_timers.backward_timers)

return loss

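The engine refactor above moves forward/backward prologue and epilogue work out of engine.forward and into standard PyTorch module hooks, so the bookkeeping also runs when the wrapped module is invoked directly, for example when several engines contribute to a single loss. A standalone sketch of that hook pattern follows; HookedModule and the printed messages are illustrative, while the register_* calls are the same PyTorch APIs used in the diff.

import torch

class HookedModule:
    # Attach prologue/epilogue callbacks to a module via PyTorch hooks,
    # mirroring the pattern used by _create_module_*_hook above.
    def __init__(self, module: torch.nn.Module):
        self.module = module
        self._fwd_pre = module.register_forward_pre_hook(self._forward_prologue)
        self._fwd_post = module.register_forward_hook(self._forward_epilogue)
        self._bwd_pre = module.register_full_backward_pre_hook(self._backward_prologue)

    def _forward_prologue(self, module, inputs):
        print("before forward")    # e.g. start timers, cast inputs

    def _forward_epilogue(self, module, inputs, output):
        print("after forward")     # e.g. stop timers, profiling

    def _backward_prologue(self, module, grad_output):
        print("before backward")   # e.g. optimizer backward_prologue()

    def remove(self):
        for handle in (self._fwd_pre, self._fwd_post, self._bwd_pre):
            handle.remove()

net = torch.nn.Linear(4, 4)
hooks = HookedModule(net)
net(torch.randn(2, 4)).sum().backward()   # prints before/after forward, then before backward
hooks.remove()

With the prologue and epilogue factored out this way, the engine's backward reduces to prologue, optimizer backward, and epilogue, as the refactored backward method in the diff shows.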