Skip to content

Commit

Permalink
add autoTP training zero2 tests (deepspeedai#7049)
Browse files Browse the repository at this point in the history
- add zero2 test
- minor fix with transformer version update & ds master merge.

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Max Kovalenko <mkovalenko@habana.ai>
  • Loading branch information
2 people authored and deepcharm committed Feb 27, 2025
1 parent 3205186 commit 52fbbc1
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
4 changes: 4 additions & 0 deletions deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
return new_module

def set_lm_head(module):
if is_autotp_training_mode():
# we need to handle autoTP training mode separately.
return

embedding_weight = None
for n, p in module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def _configure_tensor_parallel_states(self, model):
# sanity check
# currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
assert self.zero_optimization_stage(
) <= 1, "Currently, the compatibility between 'autotp' and 'zero_stage > 1' has not been validated"
) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated"

self.mpu = groups
self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size())
Expand Down
5 changes: 3 additions & 2 deletions deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,9 +1134,10 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
if inputs1.keys() != inputs2.keys():
return False
for key in inputs1:
val1 = inputs1[key].to(get_accelerator().current_device())
val2 = inputs2[key].to(get_accelerator().current_device())
val1, val2 = inputs1[key], inputs2[key]
if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
val1 = val1.to(get_accelerator().current_device())
val2 = val2.to(get_accelerator().current_device())
if not torch.equal(val1, val2):
return False
elif val1 != val2:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/model_parallelism/test_autotp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
return model, base_model


@pytest.mark.parametrize("zero_stage", [0, 1])
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
@pytest.mark.parametrize("tp_size", [2, 4])
class TestSave(DistributedTest):

Expand Down Expand Up @@ -492,7 +492,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
compare_lr_scheduler_states(trained_model, loaded_model)


@pytest.mark.parametrize("zero_stage", [0, 1])
@pytest.mark.parametrize("zero_stage", [0, 1, 2])
@pytest.mark.parametrize("tp_size", [2, 4])
class TestTpGradNorm(DistributedTest):

Expand Down

0 comments on commit 52fbbc1

Please sign in to comment.