
Commit 0a2cd72

Support writing optimizer checkpoint only on rank0 and make UT pass on A100 (#142)
**Description**

Support writing the optimizer checkpoint only on rank 0 and make the unit tests pass on A100.

- Support writing the checkpoint on rank 0. With this PR, we no longer need to change [checkpointing](https://github.com/NVIDIA/Megatron-LM/blob/0609f27fe8376f17ab65c001d3d8f35cd8175950/megatron/checkpointing.py) in [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples/tree/main/gpt3).
- Fix some bugs in the TransformerEngine integration and make the unit tests pass on A100.
- Improve the documentation.
1 parent 51f34ac commit 0a2cd72

File tree

9 files changed

+356 −60 lines changed

.github/workflows/build-image.yaml

+3 −2

@@ -31,6 +31,7 @@ jobs:
       uses: actions/checkout@v2
       with:
         submodules: true
+        path: buildimage
     - name: Free disk space
       run: |
         mkdir -p /tmp/emptydir
@@ -54,7 +55,7 @@ jobs:
         if [[ "${{ github.event_name }}" == "release" ]]; then
           TAGS=$(sed "s/main/${GITHUB_REF##*/}/g" <<< ${TAGS})
         fi
-        DOCKERFILE=dockerfile/${{ matrix.name }}.dockerfile
+        DOCKERFILE=buildimage/dockerfile/${{ matrix.name }}.dockerfile

         CACHE_FROM="type=registry,ref=$(cut -d, -f1 <<< ${TAGS})"
         CACHE_TO=""
@@ -87,7 +88,7 @@ jobs:
       uses: docker/build-push-action@v2
       with:
         platforms: linux/amd64
-        context: .
+        context: ./buildimage
         file: ${{ steps.metadata.outputs.dockerfile }}
         push: ${{ github.event_name != 'pull_request' }}
         tags: ${{ steps.metadata.outputs.tags }}

.github/workflows/unit-tests.yaml

-4
@@ -57,10 +57,6 @@ jobs:
           export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
           cd ${{ matrix.dir }}/
           python3 setup.py test
-      - name: Clean repository
-        if: always()
-        run: |
-          rm -rf ${{ matrix.dir }}/
       # - name: Report coverage results
       #   run: |
       #     bash <(curl -s https://codecov.io/bash)

docs/getting-started/run-msamp.md

+4
@@ -52,4 +52,8 @@ deepspeed cifar10_deepspeed.py --deepspeed --deepspeed_config ds_config_zero_msa
 deepspeed cifar10_deepspeed_te.py --deepspeed --deepspeed_config ds_config_zero_te_msamp.json
 ```

+:::note Note
+If you get "ModuleNotFoundError: No module named 'timm'" error when running this example, you need to install timm using `pip install timm`.
+:::
+
 For more comprehensive examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).

docs/introduction.md

+1 −1

@@ -36,7 +36,7 @@ Here are the results for GPT-3, Swin-T, DeiT-S and RoBERTa-B.

 ### System performance

-MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 42% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 17% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP.
+MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 39% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 37% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP.

 Here are the resuls for GPT-3:

msamp/megatron/optimizer/distrib_optimizer.py

+297 −30
Large diffs are not rendered by default.
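
The bulk of this commit lives in this file, but the diff is collapsed. As a rough, hypothetical sketch only (this is not the MS-AMP implementation; the function name and checkpoint layout below are invented), "write the optimizer checkpoint only on rank 0" usually means gathering each rank's shard of optimizer state and letting rank 0 alone touch the filesystem:

```python
# Hypothetical sketch (not the MS-AMP code): gather per-rank optimizer state
# shards and write the checkpoint from rank 0 only.
import torch
import torch.distributed as dist


def save_optimizer_checkpoint_rank0(local_state_shard, path):
    """Collect every rank's shard on rank 0 and save a single checkpoint file."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Gather arbitrary picklable objects; only rank 0 needs the gathered list.
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(local_state_shard, gathered, dst=0)

    if rank == 0:
        # Only rank 0 writes, so no other rank needs filesystem access and the
        # checkpoint layout does not depend on which rank produced which shard.
        torch.save({'optimizer_state_shards': gathered}, path)

    # Keep ranks in sync so nobody reads or overwrites the file prematurely.
    dist.barrier()
```

Writing from a single rank is presumably what lets MS-AMP-Examples reuse Megatron-LM's stock checkpointing.py instead of carrying a patched copy, as noted in the description above.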

msamp/optim/adamw.py

+6 −2

@@ -185,8 +185,12 @@ def adamw_fn( # noqa: C901

         for i, param in enumerate(params):
             grad = grads[i].float() if not maximize else -grads[i].float()
-            exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else 1.0
-            exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else 1.0
+            exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
+            exp_avgs[i].meta.scale_inv.fill_(1.0 / exp_avgs[i].meta.scale)
+            exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else torch.ones(
+                (), device='cuda'
+            )
+            exp_avg_sqs[i].meta.scale_inv.fill_(1.0 / exp_avg_sqs[i].meta.scale)
             # update state
             msamp_adamw.adamw_fp8_stage2_compute(
                 grad, exp_avgs[i].value, _exp_avg_inv_factors[i], exp_avgs[i].meta.scale, beta1,
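
The key point in this hunk is that `scale_inv` is refreshed in the same step that `scale` is rewritten, and the fallback scale becomes a 0-d CUDA tensor rather than the Python float 1.0. A minimal sketch of that invariant, using a made-up stand-in for the scaling metadata object (not MS-AMP's real class):

```python
# Stand-in sketch (hypothetical class, not MS-AMP's metadata type): whenever
# `scale` changes, `scale_inv` must be updated in lockstep, or later
# dequantization with the stale reciprocal silently drifts.
import torch


class FakeScalingMeta:
    """Toy metadata object holding a scaling factor and its reciprocal."""
    def __init__(self, device):
        self.scale = torch.ones((), device=device)      # 0-d tensor, not a float
        self.scale_inv = torch.ones((), device=device)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
meta = FakeScalingMeta(device)

meta.scale = torch.tensor(0.25, device=device)   # update the scaling factor...
meta.scale_inv.fill_(1.0 / meta.scale)           # ...and its reciprocal together

assert torch.allclose(meta.scale * meta.scale_inv, torch.ones_like(meta.scale))
```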

msamp/te/extension.py

+24
@@ -9,6 +9,7 @@

 from msamp.common.dtype import Dtypes
 from msamp.common.tensor import ScalingTensor
+from msamp.nn import ScalingParameter


 class TeExtensionOverrider:
@@ -24,6 +25,7 @@ class TeExtensionOverrider:
     original_fused_cast_transpose = tex.fused_cast_transpose
     original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8
     original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused
+    original_cast_if_needed = te.utils.cast_if_needed

     @staticmethod
     @torch.no_grad()
@@ -119,6 +121,24 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
             return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
         return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

+    @staticmethod
+    def cast_if_needed(tensor, dtype):
+        """Cast tensor to dtype.
+
+        Args:
+            tensor (torch.Tensor or ScalingParameter): Input tensor.
+            dtype (torch.dtype): Output dtype.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
+        with torch.enable_grad():
+            if isinstance(tensor, ScalingParameter):
+                new_tensor = tensor.to(dtype)
+                new_tensor.requires_grad = tensor.requires_grad
+                return new_tensor
+        return TeExtensionOverrider.original_cast_if_needed(tensor, dtype)
+
     @staticmethod
     def override():
         """Override transformer engine extension functions."""
@@ -127,5 +147,9 @@ def override():
         te.module.linear.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8
         te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused

+        te.module.layernorm_linear.cast_if_needed = TeExtensionOverrider.cast_if_needed
+        te.module.linear.cast_if_needed = TeExtensionOverrider.cast_if_needed
+        te.module.layernorm_mlp.cast_if_needed = TeExtensionOverrider.cast_if_needed
+

 TeExtensionOverrider.override()
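
The `cast_if_needed` override follows the same pattern as the other overrides in this file: keep a reference to the original Transformer Engine function, special-case the MS-AMP type (`ScalingParameter`), delegate everything else, and rebind the name in every TE module that imported it, since `from ... import` gives each module its own binding. A self-contained toy sketch of that pattern (the "library" below is simulated with SimpleNamespace; nothing here is the Transformer Engine API):

```python
# Toy sketch of the override pattern: special-case our own type, fall back to the
# original function otherwise, and patch each consumer module's own binding.
from types import SimpleNamespace


class MyScalingType:
    """Stand-in for ScalingParameter: something the stock cast cannot handle."""
    def __init__(self, value):
        self.value = value

    def to(self, dtype):
        return dtype(self.value)


def _original_cast_if_needed(tensor, dtype):
    return dtype(tensor)


# Simulated consumer "modules" that each imported cast_if_needed by name.
linear = SimpleNamespace(cast_if_needed=_original_cast_if_needed)
layernorm_mlp = SimpleNamespace(cast_if_needed=_original_cast_if_needed)

original_cast_if_needed = _original_cast_if_needed  # keep a handle for delegation


def cast_if_needed(tensor, dtype):
    """Handle the custom type ourselves, defer to the original otherwise."""
    if isinstance(tensor, MyScalingType):
        return tensor.to(dtype)
    return original_cast_if_needed(tensor, dtype)


# Each consumer module holds its own reference, so each one must be patched,
# just as the diff patches layernorm_linear, linear and layernorm_mlp.
linear.cast_if_needed = cast_if_needed
layernorm_mlp.cast_if_needed = cast_if_needed

print(linear.cast_if_needed(MyScalingType(3), float))   # 3.0 via the special case
print(layernorm_mlp.cast_if_needed(2, float))           # 2.0: override delegates to the original
```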

msamp/te/modules.py

-2
@@ -229,8 +229,6 @@ def _override_classes(cls):
         te.attention.Linear = MSAMPLinear
         te.attention.LayerNormLinear = MSAMPLayerNormLinear

-        te.transformer.Linear = MSAMPLinear
-        te.transformer.LayerNormLinear = MSAMPLayerNormLinear
         te.transformer.LayerNormMLP = MSAMPLayerNormMLP

     @staticmethod

tests/te/test_te_replacer.py

+21 −19

@@ -5,6 +5,7 @@

 import os
 import unittest
+from contextlib import nullcontext

 import torch
 import torch.distributed as dist
@@ -65,17 +66,16 @@ def _check_model(model):

         scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
         assert len(scaling_params) == 4
-        is_fp8_available = te.fp8.check_fp8_support()
-        if is_fp8_available:
-            # Do a forward pass to make sure the model is working.
-            fp8_format = Format.HYBRID
-            fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
-            x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
-
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                y = model(x, attention_mask=None)
-                assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
-                y.sum().backward()
+        is_fp8_available, _ = te.fp8.check_fp8_support()
+        # Do a forward pass to make sure the model is working.
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+        x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
+
+        with te.fp8_autocast(enabled=is_fp8_available, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
+            y = model(x, attention_mask=None)
+            assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
+            y.sum().backward()

     @decorator.cuda_test
     def test_te_with_deepspeed(self):
@@ -100,12 +100,13 @@ def test_te_with_deepspeed(self):

         fp8_format = Format.HYBRID
         fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
-        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        is_fp8_available, _ = te.fp8.check_fp8_support()
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
             input = torch.randn(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
             output = model(input, attention_mask=None)
-        loss = output.sum()
-        model.backward(loss)
-        model.step()
+            loss = output.sum()
+            model.backward(loss)
+            model.step()


 class TeReplacerDistributedTestCast(MultiProcessTestCase):
@@ -163,9 +164,10 @@ def test_fp8_ddp_with_te(self):
         x = torch.randn(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
         fp8_format = Format.HYBRID
         fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
-        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        is_fp8_available, _ = te.fp8.check_fp8_support()
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
             output = model(x, attention_mask=None, is_first_microbatch=True)
-        output.sum().backward()
-        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            output.sum().backward()
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
             output = model(x, attention_mask=None, is_first_microbatch=False)
-        output.sum().backward()
+            output.sum().backward()
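
The test changes all use the same trick to keep running on hardware where FP8 is unavailable (the A100 case this PR targets): unpack `te.fp8.check_fp8_support()` (the tests unpack a tuple, hence the `, _`) and fall back to `contextlib.nullcontext` so the forward and backward passes still execute outside of FP8 autocast. A small self-contained sketch of that conditional-context-manager pattern (the `dummy_autocast` and `feature_available` helpers are invented for illustration):

```python
# Conditional context manager: use the real context when the feature is
# available, otherwise nullcontext so the body runs anyway.
from contextlib import contextmanager, nullcontext


@contextmanager
def dummy_autocast(enabled=True):
    """Stand-in for te.fp8_autocast: only marks that the block is active."""
    print(f'autocast enabled={enabled}')
    yield


def feature_available():
    # Stands in for the `is_fp8_available, _ = te.fp8.check_fp8_support()` check.
    return False


available = feature_available()

# The conditional expression is evaluated first, so the unavailable branch never
# even constructs the real autocast object.
with dummy_autocast(enabled=True) if available else nullcontext():
    result = sum(range(5))  # runs under either context

print(result)
```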
