Sequence Parallel Support Overlap (#62284)
* update sequence_parallel_utils.py
iosmers authored Mar 18, 2024
1 parent d57a869 commit cfaa001
Showing 2 changed files with 544 additions and 11 deletions.
186 changes: 175 additions & 11 deletions python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
@@ -227,6 +227,158 @@ def is_fused_matmul_bias_supported():
return False


def is_fused_linear_param_grad_add_supported():
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')
else:
return False


class SPInnerOverlapLinear(paddle.autograd.PyLayer):
@staticmethod
def forward(
ctx,
x,
weight,
bias,
fuse_matmul_bias,
mp_fused_linear_param_grad_add,
model_parallel_group,
):
ctx.mp_fused_linear_param_grad_add = mp_fused_linear_param_grad_add
ctx.model_parallel_group = model_parallel_group

world_size = model_parallel_group.nranks
input_parallel = all_gather(x)

ctx.save_for_backward(x, weight, bias, input_parallel)
if not fuse_matmul_bias:
output = paddle._C_ops.linear(input_parallel, weight, bias)
else:
output = paddle._legacy_C_ops.fused_gemm_epilogue(
input_parallel, weight, bias
)
return output

@staticmethod
def backward(ctx, dy):
x, weight, bias, input_parallel = ctx.saved_tensor()
parallelism = ctx.model_parallel_group.nranks

if dy.dtype == weight.dtype:
dinput_parallel = paddle.matmul(dy, weight, transpose_y=True)
else:
dinput_parallel = paddle.matmul(
dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True
)

assert (
dinput_parallel.shape[0] % parallelism == 0
), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
dinput_parallel.shape[0], parallelism
)

dx_shape = dinput_parallel.shape
dx_shape[0] = dx_shape[0] // parallelism
dx = paddle.empty(shape=dx_shape, dtype=dinput_parallel.dtype)
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
task = dist.stream.reduce_scatter(
dx,
dinput_parallel,
op=dist.ReduceOp.SUM,
group=group,
sync_op=False,
)

if ctx.mp_fused_linear_param_grad_add:
if not is_fused_linear_param_grad_add_supported():
raise NotImplementedError(
"You set mp_fused_linear_param_grad_add=True, "
"however, the paddle you are using not support this operation. "
"Please unset fused_linear_param_grad_add or use paddle compiled "
"with cuda 11.6 or higher."
)
if bias is None:
if hasattr(weight, "main_grad"):
(
weight.main_grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel, dy, weight.main_grad, None, True, False
)
task.wait()
return dx, None
else:
if weight.grad is not None:
(
weight.grad,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel, dy, weight.grad, None, False, False
)
task.wait()
return dx, None
else:
(
dw,
_,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel, dy, None, None, False, False
)
task.wait()
return dx, dw

if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
(
weight.main_grad,
bias.main_grad,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel,
dy,
weight.main_grad,
bias.main_grad,
True,
True,
)
task.wait()
return dx, None, None
else:
if weight.grad is not None:
assert bias.grad is not None
(
weight.grad,
bias.grad,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel, dy, weight.grad, bias.grad, False, True
)
task.wait()
return dx, None, None
else:
(
dw,
dbias,
) = paddle._C_ops.fused_linear_param_grad_add(
input_parallel, dy, None, None, False, True
)
task.wait()
return dx, dw, dbias
else:
dy = dy.reshape([-1, dy.shape[-1]])
dw = paddle.matmul(
input_parallel.reshape([-1, input_parallel.shape[-1]]),
dy,
transpose_x=True,
)
if bias is None:
task.wait()
return dx, dw
else:
dbias = paddle.sum(dy, axis=0)
task.wait()
return dx, dw, dbias
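
The backward pass above overlaps communication with computation: the reduce-scatter that produces dx is launched asynchronously (sync_op=False), the weight and bias gradients are computed while that collective is in flight, and task.wait() is only called immediately before returning. As a reading aid, the comment-only sketch below spells out the fused_linear_param_grad_add call as it can be inferred from the call sites above; the argument names and meanings are assumptions drawn from this diff, not documented API.

# Inferred from the call sites above (an assumption, not documented API):
# paddle._C_ops.fused_linear_param_grad_add(
#     x,                # saved activation (input_parallel here)
#     dout,             # upstream gradient dy
#     dweight,          # existing weight grad / main_grad to accumulate into, or None
#     dbias,            # existing bias grad to accumulate into, or None
#     multi_precision,  # True in the main_grad branches (fp32 master gradients)
#     has_bias,         # True when a bias gradient should also be produced
# ) -> (dweight_out, dbias_out)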


class ColumnSequenceParallelLinear(Layer):
def __init__(
self,
@@ -250,9 +402,12 @@ def __init__(
if mp_group is None
else mp_group.nranks
)
assert (
self.world_size > 1
), "tensor parallel degree must be greater than 1 in sequence parallel"

self._name = name
self.is_mp = self.world_size > 1

assert (
gather_output is False
), "If sequence_parallel is True, \
@@ -285,6 +440,7 @@ def __init__(
)

self.weight.is_distributed = True if self.is_mp else False
self.fuse_matmul_bias = fuse_matmul_bias

if has_bias:
# initialize bias to zero like Megatron
@@ -312,18 +468,26 @@ def __init__(

self.linear = fused_linear

mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
self.mp_async_allreduce = mp_configs.mp_async_allreduce

self.mp_fused_linear_param_grad_add = (
self.mp_async_allreduce
and mp_configs.mp_fused_linear_param_grad_add
)

def forward(self, x):
# sequence parallelism is same as model parallelism
# if sequence parallel is true, input shape is [s, b, h]
# else input shape is [b, s, h]
if self.is_mp:
input_parallel = AllGatherOp.apply(x)
else:
input_parallel = x
output = self.linear(
input_parallel, self.weight, self.bias, name=self._name
        # sequence parallelism is the same as model parallelism; if sequence
        # parallel is true, the input shape is [s, b, h], else it is [b, s, h]
return SPInnerOverlapLinear.apply(
x,
self.weight,
self.bias,
self.fuse_matmul_bias,
self.mp_fused_linear_param_grad_add,
self.model_parallel_group,
)
return output


class MPScale(PyLayer):
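
For completeness, a minimal configuration sketch (not part of this commit) showing how the flags read in ColumnSequenceParallelLinear.__init__ above are typically switched on through fleet's hybrid configuration. The parallel degrees are illustrative, and the mp_configs keys are assumed from the attributes the new code queries.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 2,  # sequence parallel requires a tensor parallel degree > 1
    "pp_degree": 1,
    "mp_configs": {
        "mp_async_allreduce": True,              # enables the overlapped backward path
        "mp_fused_linear_param_grad_add": True,  # needs a paddle build with CUDA 11.6+
    },
}
fleet.init(is_collective=True, strategy=strategy)

Run this under python -m paddle.distributed.launch (or an equivalent launcher) so that the model parallel group spans more than one rank; otherwise the assertion added in __init__ above will fail.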
