From 52fbbc110d4159c3d06a36bf0668e3a3358cdbe2 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Thu, 20 Feb 2025 06:15:03 +0800
Subject: [PATCH] add autoTP training zero2 tests (#7049)

- add ZeRO stage 2 tests for autoTP training
- minor fixes for the transformers version update and the DeepSpeed master merge

Signed-off-by: inkcherry
Co-authored-by: Olatunji Ruwase
Signed-off-by: Max Kovalenko
---
 deepspeed/module_inject/replace_module.py            | 4 ++++
 deepspeed/runtime/engine.py                          | 2 +-
 deepspeed/runtime/utils.py                           | 5 +++--
 tests/unit/model_parallelism/test_autotp_training.py | 4 ++--
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 9510f96b89c6..ed94a5021fee 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -335,6 +335,10 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
         return new_module
 
     def set_lm_head(module):
+        if is_autotp_training_mode():
+            # autoTP training mode is handled separately, so skip the LM head setup here.
+            return
+
         embedding_weight = None
         for n, p in module.named_parameters():
             if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 8575df9d1d5d..4d932f8d5046 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -424,7 +424,7 @@ def _configure_tensor_parallel_states(self, model):
         # sanity check
         # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated
         assert self.zero_optimization_stage(
-        ) <= 1, "Currently, the compatibility between 'autotp' and 'zero_stage > 1' has not been validated"
+        ) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated"
 
         self.mpu = groups
         self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size())
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 91fe7cbdcc96..9fd7a65a53ba 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -1134,9 +1134,10 @@ def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[Lis
         if inputs1.keys() != inputs2.keys():
             return False
         for key in inputs1:
-            val1 = inputs1[key].to(get_accelerator().current_device())
-            val2 = inputs2[key].to(get_accelerator().current_device())
+            val1, val2 = inputs1[key], inputs2[key]
             if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
+                val1 = val1.to(get_accelerator().current_device())
+                val2 = val2.to(get_accelerator().current_device())
                 if not torch.equal(val1, val2):
                     return False
             elif val1 != val2:
diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py
index 73e61b1d3398..7680b28ce6b5 100644
--- a/tests/unit/model_parallelism/test_autotp_training.py
+++ b/tests/unit/model_parallelism/test_autotp_training.py
@@ -360,7 +360,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
     return model, base_model
 
 
-@pytest.mark.parametrize("zero_stage", [0, 1])
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestSave(DistributedTest):
 
@@ -492,7 +492,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
         compare_lr_scheduler_states(trained_model, loaded_model)
 
 
-@pytest.mark.parametrize("zero_stage", [0, 1])
+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestTpGradNorm(DistributedTest):
 
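
For reference, the snippet below is a minimal sketch (not part of the patch) of a DeepSpeed setup that
exercises the combination this change now allows: autoTP training together with ZeRO stage 2. It assumes
the `tensor_parallel.autotp_size` config key that the autoTP training unit tests pass; the toy model,
batch size, and learning rate are placeholders.

    import torch
    import deepspeed

    # placeholder model; any torch.nn.Module stands in here
    model = torch.nn.Linear(1024, 1024)

    # zero_optimization.stage is a standard DeepSpeed config key;
    # tensor_parallel.autotp_size mirrors what the autoTP training tests use (assumed key).
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {"stage": 2},       # now accepted alongside autoTP by the relaxed assert
        "tensor_parallel": {"autotp_size": 2},   # assumed key, see the unit tests above
    }

    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)

A run like this would typically be launched with the `deepspeed` launcher on at least as many ranks as
the chosen autotp_size.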