works for fp8 with deepspeed (#3361)
* works for fp8 with deepspeed

* Add tests

---------

Co-authored-by: muellerzr <muellerzr@gmail.com>
XiaobingSuper and muellerzr authored Feb 10, 2025
1 parent f19b957 commit ce63623
Showing 8 changed files with 135 additions and 18 deletions.
12 changes: 6 additions & 6 deletions docs/source/usage_guides/low_precision_training.md
@@ -53,13 +53,13 @@ accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
 ```{yaml}
 mixed_precision: fp8
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: HYBRID
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 ```

@@ -114,13 +114,13 @@ Similarly this can be set in your `config.yaml`:
 ```{yaml}
 mixed_precision: fp8
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: HYBRID
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 ```

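The renamed keys line up with Transformer Engine's `DelayedScaling` recipe arguments (`amax_compute_algo`, `amax_history_len`), and `override_linear_precision` becomes a 3-tuple of booleans controlling whether the fprop, dgrad, and wgrad GEMMs run in higher precision. A minimal sketch of the equivalent programmatic setup, assuming `FP8RecipeKwargs` forwards these keyword names to TE:

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Mirrors the fp8_config above; override_linear_precision is
# (fprop, dgrad, wgrad) -- True keeps that GEMM in higher precision.
kwargs = FP8RecipeKwargs(
    backend="TE",
    amax_compute_algo="max",
    amax_history_len=1024,
    fp8_format="HYBRID",
    interval=1,
    margin=0,
    override_linear_precision=(False, False, False),
    use_autocast_during_eval=False,
)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
```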
4 changes: 2 additions & 2 deletions examples/config_yaml_templates/fp8.yaml
@@ -7,11 +7,11 @@ fp8_config:
   backend: TE # Can be TE | MS-AMP
   # The following are TE specific arguments.
   # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
-  amax_history_length: 1024
+  amax_history_len: 1024
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   # Generally this should always be set to `false` to have the most realistic fp8 eval performance
   use_autocast_during_eval: false
   # If using MS-AMP, we ignore all of the prior and set a opt_level
12 changes: 8 additions & 4 deletions src/accelerate/accelerator.py
@@ -462,13 +462,14 @@ def __init__(
             **kwargs,
         )

-        if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
+        self._mixed_precision = mixed_precision
+        if mixed_precision == "fp8" and self.fp8_recipe_handler is None:
             self.fp8_recipe_handler = FP8RecipeKwargs()

         self.delayed_fp8_autocast = False
         if self.fp8_recipe_handler is not None:
             # We already check if FP8 is available during `self.state`
-            if self.state.mixed_precision != "fp8" and (
+            if mixed_precision != "fp8" and (
                 self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
             ):
                 raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
@@ -536,7 +537,10 @@ def __init__(
         if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
             raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

-        elif self.state.mixed_precision == "fp8":
+        # for DeepSpeed, self.state.mixed_precision is always "bf16",
+        # see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
+        # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.
+        elif mixed_precision == "fp8" or self.state.mixed_precision == "fp8":
             # We always enable `native_amp` for FP8
             self.native_amp = True
             if self.fp8_backend == "MSAMP":
@@ -3643,7 +3647,7 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
     @property
     def fp8_backend(self):
         "Returns the configured backend for training in FP8"
-        if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
+        if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
             return self.fp8_recipe_handler.backend
         elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
             return "MSAMP"
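The stored `self._mixed_precision` matters because, as the new comment notes, DeepSpeed normalizes `state.mixed_precision` to "bf16" even when the user asked for fp8, so `fp8_backend` must consult the value passed to `Accelerator` directly. A hedged illustration of the behavior this enables (assuming a DeepSpeed launch with Transformer Engine installed):

```python
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

# Illustration only: under DeepSpeed, state.mixed_precision reports
# "bf16", but the user-requested precision is kept in _mixed_precision
# so the fp8_backend property still resolves to "TE".
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[FP8RecipeKwargs(backend="TE")])
print(accelerator.state.mixed_precision)  # "bf16" under a DeepSpeed launch
print(accelerator.fp8_backend)            # "TE"
```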
5 changes: 5 additions & 0 deletions src/accelerate/commands/launch.py
@@ -1060,6 +1060,11 @@ def _validate_launch_command(args):
             for k in defaults.ipex_config:
                 setattr(args, k, defaults.ipex_config[k])
             for k in defaults.mpirun_config:
                 setattr(args, k, defaults.mpirun_config[k])
+            for k in defaults.fp8_config:
+                arg_to_set = k
+                if "fp8" not in arg_to_set:
+                    arg_to_set = "fp8_" + arg_to_set
+                setattr(args, arg_to_set, defaults.fp8_config[k])
             continue

         # Those args are handled separately
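The new loop copies each saved `fp8_config` entry onto the matching launcher argument, prefixing `fp8_` only when the key does not already contain it (so `fp8_format` is not doubled into `fp8_fp8_format`). A standalone sketch of the mapping, with hypothetical config values:

```python
# Hypothetical fp8_config values for illustration.
fp8_config = {"amax_history_len": 1024, "fp8_format": "HYBRID"}

for k, v in fp8_config.items():
    arg_to_set = k if "fp8" in k else "fp8_" + k
    print(arg_to_set, "=", v)
# fp8_amax_history_len = 1024
# fp8_format = HYBRID
```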
9 changes: 7 additions & 2 deletions src/accelerate/utils/deepspeed.py
@@ -143,8 +143,13 @@ def __init__(self, config_file_or_dict):
                 config = json.load(f)
         else:
             try:
-                config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
-                config = json.loads(config_decoded)
+                try:
+                    # First try parsing as JSON directly
+                    config = json.loads(config_file_or_dict)
+                except json.JSONDecodeError:
+                    # If that fails, try base64 decoding
+                    config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
+                    config = json.loads(config_decoded)
             except (UnicodeDecodeError, AttributeError, ValueError):
                 raise ValueError(
                     f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}"
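With this change the config string may be raw JSON as well as base64-encoded JSON; the new DeepSpeed test below appears to exercise the raw-JSON path by passing `json.dumps(ds_config)` straight through as `deepspeed_config_file`. A quick sketch of the two string forms and their equivalence:

```python
import base64
import json

ds_config = {"zero_optimization": {"stage": 2}, "train_batch_size": "auto"}

raw = json.dumps(ds_config)  # now parsed directly by json.loads
encoded = base64.urlsafe_b64encode(raw.encode("utf-8")).decode("utf-8")  # still supported

# Both forms decode to the same config dict.
assert json.loads(raw) == json.loads(base64.urlsafe_b64decode(encoded).decode("utf-8"))
```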
8 changes: 7 additions & 1 deletion src/accelerate/utils/launch.py
@@ -83,7 +83,13 @@ def setup_fp8_env(args: argparse.Namespace, current_env: Dict[str, str]):
         if arg.startswith("fp8_"):
             value = getattr(args, arg)
             if value is not None:
-                current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
+                if arg == "fp8_override_linear_precision":
+                    values = value.strip("()").split(",")
+                    current_env[prefix + "FP8_OVERRIDE_FPROP"] = values[0].strip()
+                    current_env[prefix + "FP8_OVERRIDE_DGRAD"] = values[1].strip()
+                    current_env[prefix + "FP8_OVERRIDE_WGRAD"] = values[2].strip()
+                else:
+                    current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
     return current_env


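Rather than exporting the tuple as one opaque string, the launcher now splits `fp8_override_linear_precision` into three per-GEMM environment variables. A standalone sketch of the parsing; the `ACCELERATE_` prefix here is an assumption, since the real prefix is defined by the surrounding function:

```python
# Assumed prefix for illustration; setup_fp8_env defines the real one.
prefix = "ACCELERATE_"
value = "(false, true, false)"  # e.g. args.fp8_override_linear_precision

current_env = {}
values = value.strip("()").split(",")
current_env[prefix + "FP8_OVERRIDE_FPROP"] = values[0].strip()
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = values[1].strip()
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = values[2].strip()
print(current_env)
# {'ACCELERATE_FP8_OVERRIDE_FPROP': 'false',
#  'ACCELERATE_FP8_OVERRIDE_DGRAD': 'true',
#  'ACCELERATE_FP8_OVERRIDE_WGRAD': 'false'}
```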
6 changes: 3 additions & 3 deletions tests/test_configs/0_34_0_fp8.yaml
@@ -4,13 +4,13 @@ distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 enable_cpu_affinity: false
 fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
   backend: TE
   fp8_format: E4M3
   interval: 1
   margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
   use_autocast_during_eval: false
 gpu_ids: all
 machine_rank: 0
97 changes: 97 additions & 0 deletions tests/test_fp8.py
@@ -0,0 +1,97 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import unittest

import torch

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.test_utils import get_launch_command, require_cuda, require_multi_gpu, require_transformer_engine
from accelerate.test_utils.testing import require_deepspeed, run_command
from accelerate.utils import FP8RecipeKwargs, has_transformer_engine_layers


def can_convert_model():
    print("Starting basic_fp8_test")
    accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
    accelerator = Accelerator(**accelerator_kwargs)
    dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
    model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    assert has_transformer_engine_layers(model)


def maintain_proper_deepspeed_config(expected_version):
    assert (
        AcceleratorState().deepspeed_plugin.zero_stage == expected_version
    ), f"Expected zero stage {expected_version} but got {AcceleratorState().deepspeed_plugin.zero_stage}"


@require_transformer_engine
class TestTransformerEngine(unittest.TestCase):
    @require_cuda
    def test_can_prepare_model_single_gpu(self):
        command = get_launch_command(num_processes=1, monitor_interval=0.1)
        command += ["-m", "tests.test_fp8"]
        run_command(command)

    @require_multi_gpu
    def test_can_prepare_model_multi_gpu(self):
        command = get_launch_command(num_processes=2, monitor_interval=0.1)
        command += ["-m", "tests.test_fp8"]
        run_command(command)

    @require_deepspeed
    @require_multi_gpu
    def test_can_prepare_model_multigpu_deepspeed(self):
        for zero_stage in [1, 2, 3]:
            os.environ["ZERO_STAGE"] = str(zero_stage)
            ds_config = {
                "bf16": {"enabled": True},
                "zero_optimization": {
                    "stage": zero_stage,
                    "allgather_partitions": True,
                    "allgather_bucket_size": 2e8,
                    "overlap_comm": True,
                    "reduce_scatter": True,
                    "reduce_bucket_size": 2e8,
                    "contiguous_gradients": True,
                },
                "gradient_accumulation_steps": 1,
                "gradient_clipping": "auto",
                "steps_per_print": 2000,
                "train_batch_size": "auto",
                "train_micro_batch_size_per_gpu": "auto",
                "wall_clock_breakdown": False,
            }

            ds_config = json.dumps(ds_config)

            command = get_launch_command(
                num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
            )
            command += ["-m", "tests.test_fp8"]
            run_command(command)


if __name__ == "__main__":
can_convert_model()
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
