
Commit

Merge branch 'master' into olruwase/update_nvme_offload_states
loadams authored Mar 10, 2025
2 parents 044db61 + 8ec1af5 commit b0f1391
Showing 2 changed files with 59 additions and 38 deletions.
9 changes: 3 additions & 6 deletions deepspeed/module_inject/auto_tp.py
@@ -11,7 +11,7 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from .layers import *
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
@@ -211,7 +211,7 @@ def __init__(self,
self.orig_layer_impl = orig_layer_impl
self.linear_policies = None
self.conv_linear_layer = False
self.keep_module_on_host = keep_module_on_host
TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

def in_module_list(module, module_list):
for item in module_list:
@@ -350,10 +350,7 @@ def _replace(self, child, name, conv_linear_layer):
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
return_new_copy = not self.keep_module_on_host

weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
# For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
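
The auto_tp.py hunks above stop storing keep_module_on_host on the AutoTP replacer instance and instead push the flag onto the layer class through TensorParallel_Layer.set_keep_module_on_host, so every replaced layer reads the same setting. Below is a minimal, self-contained sketch of that class-level-flag pattern; the names ParallelLayerBase, RowParallelLinear and target_device are illustrative (not DeepSpeed APIs), and "cuda" stands in for get_accelerator().current_device_name():

class ParallelLayerBase:
    # Shared across all subclasses and instances; checkpoints are loaded to
    # the host first, so keeping modules there avoids an extra device copy.
    keep_module_on_host: bool = False

    @classmethod
    def set_keep_module_on_host(cls, value: bool) -> None:
        cls.keep_module_on_host = value

    def target_device(self) -> str:
        # Instances read the shared flag instead of carrying their own copy.
        return "cpu" if self.__class__.keep_module_on_host else "cuda"


class RowParallelLinear(ParallelLayerBase):
    pass


if __name__ == "__main__":
    ParallelLayerBase.set_keep_module_on_host(True)  # one call configures every layer
    print(RowParallelLinear().target_device())       # prints: cpu

One consequence of this design is that the flag is process-global: two replacer instances in the same process now share a single keep_module_on_host setting.
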
88 changes: 56 additions & 32 deletions deepspeed/module_inject/layers.py
@@ -17,6 +17,11 @@
from copy import deepcopy
from typing import Union

__all__ = [
"TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
"Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
]

DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
@@ -43,26 +48,6 @@ def set_autotp_mode(training=False):
DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE


def move(tensor, device):
# TODO: consider the timing of deletion
# to save host resources when DP > 1。

if tensor.is_meta:
# Keep tensor in meta device if tensor is meta.
return tensor
else:
# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
# Using copy=True instead of clone() will help in case of cpu --> cpu.
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
cloned_tensor = tensor.to(device, copy=True)

# free the memory of the original tensor to reduce memory peak
# Equivalent to directly deleting the tensor reference outside the function.
# see https://github.com/microsoft/DeepSpeed/pull/4353
tensor.data = torch.empty(0, device=tensor.device)
return cloned_tensor


class RowParallel(torch.autograd.Function):
"""
A custom autograd function for performing row-wise parallelism.
@@ -140,6 +125,10 @@ class TensorParallel_Layer(nn.Module, ABC):
name (Optional[str]): The name of the layer, if provided.
"""

# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
keep_module_on_host: bool = False

def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
"""
Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -163,6 +152,16 @@ def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
if kwargs.get('name') is not None:
self.name = kwargs.get('name') # Set the layer name if provided.

@classmethod
def set_keep_module_on_host(cls, value: bool):
"""
Set the static variable keep_module_on_host.
Args:
value (bool): The new value for keep_module_on_host.
"""
cls.keep_module_on_host = value

@abstractmethod
def forward(self, input):
"""
@@ -235,6 +234,31 @@ def extra_repr(self):
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str

def move(self, tensor):
# TODO: consider the timing of deletion
# to save host resources when DP > 1。

# keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
# cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
if tensor.is_meta:
# Keep tensor in meta device if tensor is meta.
return tensor
else:
device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
return_new_copy = not self.__class__.keep_module_on_host

# Using new tensors help in freeing memory (after split for example) was done before by calling clone().
# Using copy=True instead of clone() will help in case of cpu --> cpu.
# Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
cloned_tensor = tensor.to(device, copy=return_new_copy)

if return_new_copy:
# free the memory of the original tensor to reduce memory peak
# Equivalent to directly deleting the tensor reference outside the function.
# see https://github.com/microsoft/DeepSpeed/pull/4353
tensor.data = torch.empty(0, device=tensor.device)
return cloned_tensor


class GatherReplacedLayerParams:
"""
Expand Down Expand Up @@ -349,7 +373,7 @@ def _tp_partition(self, params_list):
return
_partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

params_list[idx].data = _partition

@@ -363,7 +387,7 @@ def uneven_partition(self, params_list):
self.name),
dim=1)[self.tp_index]

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()
params_list[idx].data = _partition


@@ -414,7 +438,7 @@ def _tp_partition(self, params_list):
#split bias if provide
_partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

params_list[idx].data = _partition

@@ -429,7 +453,7 @@ def uneven_partition(self, params_list):
self.name),
dim=0)[self.tp_index]

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

params_list[idx].data = _partition

@@ -475,7 +499,7 @@ def _tp_partition(self, params_list):

_partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

params_list[idx].data = _partition

@@ -492,13 +516,13 @@ def _tp_partition(self, params_list):
weight, bias = params_list[0], params_list[1]
_partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
dim=1)[self.tp_index]
_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()
weight.data = _partition

if bias is not None:
_partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
dim=0)[self.tp_index]
_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

bias.data = _partition

Expand All @@ -522,19 +546,19 @@ class Yuan_LinearLayer(LinearLayer):
def _tp_partition(self, params_list):
weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
self.tp_world_size, True)
params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
params_list[0].data = self.move(weight).detach()
if bias is not None:
params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
params_list[1].data = self.move(bias).detach()


class GateUpPack_LinearLayer(LinearLayer):
# chatGLM2, chatGLM2
@torch.no_grad()
def _tp_partition(self, params_list):
weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
params_list[0].data = self.move(weight).detach()
if bias is not None:
params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach()
params_list[1].data = self.move(bias).detach()


class Conv_LinearALlreduce(LinearAllreduce):
@@ -549,7 +573,7 @@ def _tp_partition(self, params_list):
_partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
dim=1)[self.tp_index]

_partition = move(_partition, get_accelerator().current_device_name()).detach()
_partition = self.move(_partition).detach()

params_list[idx].data = _partition

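
The layers.py hunks above fold the old module-level move() helper into a TensorParallel_Layer.move() method: the target device now comes from the shared keep_module_on_host flag, and a fresh copy is forced (and the source storage released) only when return_new_copy is true. Below is a standalone sketch of that copy-and-release idiom; move_and_release is an illustrative name rather than a DeepSpeed API, and plain device strings stand in for get_accelerator().current_device_name():

import torch


def move_and_release(tensor: torch.Tensor, device: str, return_new_copy: bool = True) -> torch.Tensor:
    if tensor.is_meta:
        # Meta tensors carry no storage, so there is nothing to copy or free.
        return tensor

    # copy=True forces a fresh allocation even for cpu -> cpu, so the result
    # does not alias (and keep alive) the full-sized source tensor.
    cloned = tensor.to(device, copy=return_new_copy)

    if return_new_copy:
        # Detach this reference from the original storage so it can be freed
        # as soon as no other reference holds it, lowering the memory peak.
        tensor.data = torch.empty(0, device=tensor.device)
    return cloned


if __name__ == "__main__":
    weight = torch.randn(4, 8)
    shard = torch.chunk(weight, 2, dim=-1)[0]   # a view into the full weight
    shard = move_and_release(shard, "cpu")      # now backed by its own storage
    print(shard.shape)                          # prints: torch.Size([4, 4])

When keep_module_on_host is set, move() keeps tensors on the CPU and skips both the copy and the release, since the checkpoint is already resident on the host.
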
