
Commit d33acaf

inkcherry authored and loadams committed
autotp training(fix dco) (#7004)
Same as [this PR](#6922) ([affeb88](affeb88)). I noticed the CI recently updated the DCO check. Using the suggested rebase method for sign-off would reintroduce many conflicts, so I opted for a squash merge with sign-off instead. Thanks :)

Signed-off-by: inkcherry <mingzhi.liu@intel.com>
Signed-off-by: Logan Adams <loadams@microsoft.com>
1 parent 4557ab8 commit d33acaf

17 files changed: +1662 −164 lines


deepspeed/__init__.py

+32 −1
@@ -37,7 +37,7 @@
 from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
 from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from .module_inject import replace_transformer_layer, revert_transformer_layer
+from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode
 
 from .utils import log_dist, OnDevice, logger
 from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
     engine = InferenceEngine(model, config=ds_inference_config)
 
     return engine
+
+
+def tp_model_init(model, tp_size, dtype):
+    """
+    Initialize the model for tensor parallelism.
+
+    Args:
+        model (torch.nn.Module): The model to be initialized.
+        tp_size (int): The tensor parallelism size.
+        dtype (torch.dtype): The data type to be used for the model.
+
+    Returns:
+        torch.nn.Module: The initialized model with tensor parallelism.
+    """
+    # avoid re-entry
+    assert not hasattr(
+        model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."
+
+    set_autotp_mode(training=True)
+
+    from deepspeed.runtime.tensor_parallel import TpTrainingManager
+    # The expected usage here is for it to be invoked by transformers package.
+
+    #TODO: We should provide a custom TP mapping solution without using autoTP
+    #as modifying the autoTP logic may be more difficult for users compared to configuring it
+
+    model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module
+
+    setattr(model, 'ds_autotp_parsed', True)
+
+    return model
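
For context, a rough usage sketch of the new deepspeed.tp_model_init entry point. The docstring above notes that the expected caller is the transformers package; the toy model, tp_size, and dtype below are placeholders, and a normal DeepSpeed distributed launch is assumed.

import torch
import deepspeed

# Assumes the process group has already been set up (e.g. via a launcher).
deepspeed.init_distributed()

# Placeholder model; in practice this would be a Hugging Face transformer
# handed over by the transformers integration.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

# Enable AutoTP training mode and shard the model across 2 tensor-parallel ranks in bf16.
model = deepspeed.tp_model_init(model, tp_size=2, dtype=torch.bfloat16)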

deepspeed/comm/comm.py

+6
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
     return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
 
 
+@timed_op
+def broadcast_object_list(object_list, src, group=None, device=None):
+    global cdb
+    return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)
+
+
 @timed_op
 def all_gather(tensor_list,
                tensor,
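
For reference, a minimal sketch of how the new deepspeed.comm.broadcast_object_list wrapper could be called from user code. It assumes the distributed backend is already initialized; the payload dictionary is an arbitrary placeholder.

import deepspeed
from deepspeed import comm as dist

deepspeed.init_distributed()

# Rank 0 supplies the object(s); other ranks pass placeholders of the same length.
objects = [{"tp_size": 4, "dtype": "bf16"}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(objects, src=0)
# After the call, objects[0] on every rank holds rank 0's dictionary.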

deepspeed/comm/torch.py

+4
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
         else:
             return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
 
+    @disable_compiler_collective
+    def broadcast_object_list(self, object_list, src, group=None, device=None):
+        return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)
+
     @disable_compiler_collective
     def all_gather(self, tensor_list, tensor, group=None, async_op=False):
         if DS_COMM_ALL_GATHER_OFF:

deepspeed/inference/engine.py

-1
@@ -15,7 +15,6 @@
 from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
 from deepspeed.utils.timer import SynchronizedWallClockTimer
 from deepspeed.runtime.compiler import is_compile_supported
-
 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
 from ..module_inject import replace_transformer_layer, generic_injection

deepspeed/module_inject/__init__.py

+1 −1
@@ -6,5 +6,5 @@
 from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
 from .module_quantize import quantize_transformer_layer
 from .replace_policy import HFBertLayerPolicy
-from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
+from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
 from .policy import DSPolicy

deepspeed/module_inject/auto_tp.py

+30 −59
@@ -11,10 +11,12 @@
 from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
 from deepspeed.accelerator import get_accelerator
-from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
+from .fusedqkv_utils import require_tp_fused_qkvw
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+from deepspeed.utils import groups
+from deepspeed.module_inject.layers import is_autotp_training_mode
 
 
 def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
         return policy_list
 
     def set_tensor_parallel_config(self, mp_size, mp_group):
+
+        if is_autotp_training_mode():
+            self.mp_group = groups.get_tensor_model_parallel_group()
+            self.mp_size = groups.get_tensor_model_parallel_world_size()
+            return
+
         self.mp_size = mp_size
         self.mp_group = mp_group
 
     def _replace(self, child, name, conv_linear_layer):
+        # This function should clearly define the routing rules for specific layers
+        # and avoid any complex shard-related logic.
        if getattr(child, "replaced", False) == True:
            return
        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer):
         # For Yuan model
         if 'Yuan' in str(self.module):
             if 'v_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), True)
-                return LinearLayer(weight=weight, bias=bias)
+                return Yuan_LinearLayer(child, self.mp_group)
+
             elif 'o_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), False)
-                return LinearAllreduce(weight, bias, self.mp_group)
-        # For Arctic model, bypass to all_reduce replacement for w2 weights
+                return Yuan_LinearAllreduce(child, self.mp_group)
+
+        # For MLP including chunk layer.
+        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
+            return GateUpPack_LinearLayer(child, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
         arctic_w2_all_reduce_linear = False
         if 'Arctic' in str(self.module) and 'w2' in name:
             arctic_w2_all_reduce_linear = True
         # For MoE MLP model, e.g., deepseek and jamba
         down_proj = False
         if 'down_proj' in name:
             down_proj = True
-        # For MLP including chunk layer.
-        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
-            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
-            return LinearLayer(weight=weight, bias=bias)
         if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0], weight_shape[1] // mp_size]
 
+            setattr(child, "replaced", True)
             if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-            data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
-                                           dim=1)
-            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-            del data
+                return Conv_LinearALlreduce(child, self.mp_group, name=name)
+            elif name == "lm_head" or name == 'embed_out':
+                return LmHeadLinearAllreduce(child, self.mp_group)
 
-            setattr(child, "replaced", True)
-            if name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(
-                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
-                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias, device_name, return_new_copy)), self.mp_group)
-            return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
+            return LinearAllreduce(child, self.mp_group, name=name)
         else:
 
-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0] // mp_size, weight_shape[1]]
+            setattr(child, "replaced", True)
             if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-
-            if require_tp_fused_qkvw(name, self.mp_size):
+                conv_LinearLayer(child, self.mp_group)
+            elif require_tp_fused_qkvw(name, self.mp_size):
                 #Check and handle fused qkv for TP
-                #The copy is a regular copy, The shape of dst and src is the same
-                data_dc = move(
-                    prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-
-                bias_data_dc = None if child.bias is None else move(
-                    prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-            else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
-                                               dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-                del data
-
-                if child.bias is not None:
-                    bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
-                                                      dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
-                    bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
-                    del bias_data
-                else:
-                    bias_data_dc = None
+                return fused_LinearLayer(child, self.mp_group, fused_module=self.module)
 
-            setattr(child, "replaced", True)
-            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
+            return LinearLayer(child, self.mp_group, name=name)
 
     def _slice_embedding(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
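
To summarize the refactor above: _replace now routes each eligible child module to a wrapper class (LinearLayer, LinearAllreduce, and their Yuan/GateUpPack/Conv/fused variants) that receives the original module plus the mp_group, instead of sharding weights inline. The standalone sketch below is illustrative only; it mirrors the visible branch order with plain strings, folds the conv and fused-QKV handling of the final branch into a comment, and is not the actual layers.py implementation.

def route_replacement(name, module_str, conv_linear_layer, all_reduce_linears):
    # Illustrative only: returns the name of the wrapper _replace would pick.
    if 'Yuan' in module_str:
        if 'v_proj' in name:
            return 'Yuan_LinearLayer'
        elif 'o_proj' in name:
            return 'Yuan_LinearAllreduce'
    if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in module_str):
        return 'GateUpPack_LinearLayer'
    arctic_w2 = 'Arctic' in module_str and 'w2' in name
    down_proj = 'down_proj' in name
    if name in all_reduce_linears or arctic_w2 or down_proj:
        if conv_linear_layer:
            return 'Conv_LinearALlreduce'
        if name in ('lm_head', 'embed_out'):
            return 'LmHeadLinearAllreduce'
        return 'LinearAllreduce'
    # Conv and fused-QKV weights get conv_LinearLayer / fused_LinearLayer here.
    return 'LinearLayer'

# Example: a down_proj weight in a Llama-style model is wrapped for all-reduce.
print(route_replacement('down_proj', 'LlamaForCausalLM', False, set()))  # LinearAllreduce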

0 commit comments
