From ce63623421750c95276971ba1086e89dcdc0d112 Mon Sep 17 00:00:00 2001
From: XiaobingZhang
Date: Mon, 10 Feb 2025 22:31:15 +0800
Subject: [PATCH] works for fp8 with deepspeed (#3361)

* works for fp8 with deepspeed

* Add tests

---------

Co-authored-by: [[ -z $EMAIL ]] && read -e -p "Enter your email (for git configuration): " EMAIL
---
 .../usage_guides/low_precision_training.md | 12 +--
 examples/config_yaml_templates/fp8.yaml    |  4 +-
 src/accelerate/accelerator.py              | 12 ++-
 src/accelerate/commands/launch.py          |  5 +
 src/accelerate/utils/deepspeed.py          |  9 +-
 src/accelerate/utils/launch.py             |  8 +-
 tests/test_configs/0_34_0_fp8.yaml         |  6 +-
 tests/test_fp8.py                          | 97 +++++++++++++++++++
 8 files changed, 135 insertions(+), 18 deletions(-)
 create mode 100644 tests/test_fp8.py

diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md
index 80dad01525c..c730136e1ce 100644
--- a/docs/source/usage_guides/low_precision_training.md
+++ b/docs/source/usage_guides/low_precision_training.md
@@ -53,13 +53,13 @@ accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
 ```{yaml}
 mixed_precision: fp8
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: HYBRID
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 ```
 
@@ -114,13 +114,13 @@ Similarly this can be set in your `config.yaml`:
 ```{yaml}
 mixed_precision: fp8
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: HYBRID
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 ```
 
diff --git a/examples/config_yaml_templates/fp8.yaml b/examples/config_yaml_templates/fp8.yaml
index 4e81ac8e9fb..256bca970d2 100644
--- a/examples/config_yaml_templates/fp8.yaml
+++ b/examples/config_yaml_templates/fp8.yaml
@@ -7,11 +7,11 @@ fp8_config:
   backend: TE # Can be TE | MS-AMP
   # The following are TE specific arguments.
   # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
-  amax_history_length: 1024
+  amax_history_len: 1024
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 800b0965f2f..a483f0d1a39 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -462,13 +462,14 @@ def __init__(
             **kwargs,
         )
 
-        if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
+        self._mixed_precision = mixed_precision
+        if mixed_precision == "fp8" and self.fp8_recipe_handler is None:
             self.fp8_recipe_handler = FP8RecipeKwargs()
 
         self.delayed_fp8_autocast = False
         if self.fp8_recipe_handler is not None:
             # We already check if FP8 is available during `self.state`
-            if self.state.mixed_precision != "fp8" and (
+            if mixed_precision != "fp8" and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
@@ -536,7 +537,10 @@ def __init__(
 
             if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
                 raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
-        elif self.state.mixed_precision == "fp8":
+        # for DeepSpeed, self.state.mixed_precision is always "bf16",
+        # see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
+        # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.
+        elif mixed_precision == "fp8" or self.state.mixed_precision == "fp8":
             # We always enable `native_amp` for FP8
             self.native_amp = True
             if self.fp8_backend == "MSAMP":
@@ -3643,7 +3647,7 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
     @property
     def fp8_backend(self):
         "Returns the configured backend for training in FP8"
-        if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
+        if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
             return self.fp8_recipe_handler.backend
         elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
             return "MSAMP"
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 9a5ece87675..45e98a8c9ac 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -1060,6 +1060,11 @@ def _validate_launch_command(args):
                     setattr(args, k, defaults.ipex_config[k])
                 for k in defaults.mpirun_config:
                     setattr(args, k, defaults.mpirun_config[k])
+                for k in defaults.fp8_config:
+                    arg_to_set = k
+                    if "fp8" not in arg_to_set:
+                        arg_to_set = "fp8_" + arg_to_set
+                    setattr(args, arg_to_set, defaults.fp8_config[k])
                 continue
 
             # Those args are handled separately
diff --git a/src/accelerate/utils/deepspeed.py b/src/accelerate/utils/deepspeed.py
index 3942b696f3e..32e4d4842e9 100644
--- a/src/accelerate/utils/deepspeed.py
+++ b/src/accelerate/utils/deepspeed.py
@@ -143,8 +143,13 @@ def __init__(self, config_file_or_dict):
                 config = json.load(f)
         else:
             try:
-                config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
-                config = json.loads(config_decoded)
+                try:
+                    # First try parsing as JSON directly
+                    config = json.loads(config_file_or_dict)
+                except json.JSONDecodeError:
+                    # If that fails, try base64 decoding
+                    config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
+                    config = json.loads(config_decoded)
             except (UnicodeDecodeError, AttributeError, ValueError):
                 raise ValueError(
                     f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}"
                 )
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index 68b6355912a..f5a14ef102c 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -83,7 +83,13 @@ def setup_fp8_env(args: argparse.Namespace, current_env: Dict[str, str]):
         if arg.startswith("fp8_"):
             value = getattr(args, arg)
             if value is not None:
-                current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
+                if arg == "fp8_override_linear_precision":
+                    values = value.strip("()").split(",")
+                    current_env[prefix + "FP8_OVERRIDE_FPROP"] = values[0].strip()
+                    current_env[prefix + "FP8_OVERRIDE_DGRAD"] = values[1].strip()
+                    current_env[prefix + "FP8_OVERRIDE_WGRAD"] = values[2].strip()
+                else:
+                    current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
     return current_env
 
 
diff --git a/tests/test_configs/0_34_0_fp8.yaml b/tests/test_configs/0_34_0_fp8.yaml
index 21bce1d93dc..4c414b00478 100644
--- a/tests/test_configs/0_34_0_fp8.yaml
+++ b/tests/test_configs/0_34_0_fp8.yaml
@@ -4,13 +4,13 @@ distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 gpu_ids: all
 machine_rank: 0
diff --git a/tests/test_fp8.py b/tests/test_fp8.py
new file mode 100644
index 00000000000..eb35f183b6a
--- /dev/null
+++ b/tests/test_fp8.py
@@ -0,0 +1,97 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import unittest
+
+import torch
+
+from accelerate import Accelerator
+from accelerate.state import AcceleratorState
+from accelerate.test_utils import get_launch_command, require_cuda, require_multi_gpu, require_transformer_engine
+from accelerate.test_utils.testing import require_deepspeed, run_command
+from accelerate.utils import FP8RecipeKwargs, has_transformer_engine_layers
+
+
+def can_convert_model():
+    print("Starting basic_fp8_test")
+    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
+    accelerator = Accelerator(**accelerator_kwargs)
+    dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
+    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
+
+    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+    assert has_transformer_engine_layers(model)
+
+
+def maintain_proper_deepspeed_config(expected_version):
+    assert (
+        AcceleratorState().deepspeed_plugin.zero_stage == expected_version
+    ), f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}"
+
+
+@require_transformer_engine
+class TestTransformerEngine(unittest.TestCase):
+    @require_cuda
+    def test_can_prepare_model_single_gpu(self):
+        command = get_launch_command(num_processes=1, monitor_interval=0.1)
+        command += ["-m", "tests.test_fp8"]
+        run_command(command)
+
+    @require_multi_gpu
+    def test_can_prepare_model_multi_gpu(self):
+        command = get_launch_command(num_processes=2, monitor_interval=0.1)
+        command += ["-m", "tests.test_fp8"]
+        run_command(command)
+
+    @require_deepspeed
+    @require_multi_gpu
+    def test_can_prepare_model_multigpu_deepspeed(self):
+        for zero_stage in [1, 2, 3]:
+            os.environ["ZERO_STAGE"] = str(zero_stage)
+            ds_config = {
+                "bf16": {"enabled": True},
+                "zero_optimization": {
+                    "stage": zero_stage,
+                    "allgather_partitions": True,
+                    "allgather_bucket_size": 2e8,
+                    "overlap_comm": True,
+                    "reduce_scatter": True,
+                    "reduce_bucket_size": 2e8,
+                    "contiguous_gradients": True,
+                },
+                "gradient_accumulation_steps": 1,
+                "gradient_clipping": "auto",
+                "steps_per_print": 2000,
+                "train_batch_size": "auto",
+                "train_micro_batch_size_per_gpu": "auto",
+                "wall_clock_breakdown": False,
+            }
+
+            ds_config = json.dumps(ds_config)
+
+            command = get_launch_command(
+                num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
+            )
+            command += ["-m", "tests.test_fp8"]
+            run_command(command)
+
+
+if __name__ == "__main__":
+    can_convert_model()
+    if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+        maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
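For context, the usage pattern this patch targets (FP8 via the TransformerEngine backend while running under DeepSpeed) looks roughly like the sketch below. It follows `can_convert_model` from the new test; the toy model, optimizer, and hyperparameters are illustrative only, not part of the patch.

```python
# Minimal sketch of FP8 (TE backend) set-up that this patch makes work under DeepSpeed.
# Toy model/optimizer values are illustrative only.
import torch

from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(backend="TE")],
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)

# When launched with `accelerate launch --use_deepspeed ...`, state.mixed_precision
# reports "bf16", which is why the Accelerator now also checks the `mixed_precision`
# argument that was passed in directly before swapping in TE layers.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```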
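The `utils/deepspeed.py` change lets a DeepSpeed config arrive as a plain JSON string (as the new test does with `json.dumps(ds_config)`) in addition to a file path or a base64-encoded string. A standalone sketch of that parse order, mirroring the patched logic rather than importing it:

```python
# Standalone sketch of the config parsing order added in utils/deepspeed.py:
# try plain JSON first, then fall back to base64-encoded JSON.
import base64
import json


def parse_ds_config(config_str: str) -> dict:
    try:
        # First try parsing as JSON directly
        return json.loads(config_str)
    except json.JSONDecodeError:
        # If that fails, try base64 decoding
        decoded = base64.urlsafe_b64decode(config_str).decode("utf-8")
        return json.loads(decoded)


raw = json.dumps({"zero_optimization": {"stage": 2}, "train_batch_size": "auto"})
encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).decode("utf-8")
assert parse_ds_config(raw) == parse_ds_config(encoded)
```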
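`override_linear_precision` is a tuple of three flags (fprop, dgrad, wgrad), which is why the config files now spell it as `(false, false, false)` and why `setup_fp8_env` splits it into three environment variables instead of exporting the tuple as one string. A sketch of that expansion, assuming the launcher's `ACCELERATE_` prefix and using illustrative values:

```python
# Sketch of how setup_fp8_env expands override_linear_precision into three
# per-pass environment variables. The prefix is an assumption about the launcher's
# naming; the dict stands in for the launcher's environment.
prefix = "ACCELERATE_"
value = "(false, false, false)"  # as written in the YAML configs above

current_env = {}
values = value.strip("()").split(",")
current_env[prefix + "FP8_OVERRIDE_FPROP"] = values[0].strip()
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = values[1].strip()
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = values[2].strip()

print(current_env)
# {'ACCELERATE_FP8_OVERRIDE_FPROP': 'false',
#  'ACCELERATE_FP8_OVERRIDE_DGRAD': 'false',
#  'ACCELERATE_FP8_OVERRIDE_WGRAD': 'false'}
```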