From cfaa001630256550d7544c36bab659acb6edf582 Mon Sep 17 00:00:00 2001
From: yinwei
Date: Mon, 18 Mar 2024 11:20:30 +0800
Subject: [PATCH] Sequence Parallel Support Overlap (#62284)

* update sequence_parallel_utils.py
---
 .../fleet/utils/sequence_parallel_utils.py    | 186 ++++++++-
 ...arallel_mp_model_with_sequence_parallel.py | 369 ++++++++++++++++++
 2 files changed, 544 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
index 940d7408ff5be..96d511f2dc06c 100644
--- a/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
+++ b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
@@ -227,6 +227,158 @@ def is_fused_matmul_bias_supported():
         return False
 
 
+def is_fused_linear_param_grad_add_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        return hasattr(paddle._C_ops, 'fused_linear_param_grad_add')
+    else:
+        return False
+
+
+class SPInnerOverlapLinear(paddle.autograd.PyLayer):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias,
+        fuse_matmul_bias,
+        mp_fused_linear_param_grad_add,
+        model_parallel_group,
+    ):
+        ctx.mp_fused_linear_param_grad_add = mp_fused_linear_param_grad_add
+        ctx.model_parallel_group = model_parallel_group
+
+        world_size = model_parallel_group.nranks
+        input_parallel = all_gather(x)
+
+        ctx.save_for_backward(x, weight, bias, input_parallel)
+        if not fuse_matmul_bias:
+            output = paddle._C_ops.linear(input_parallel, weight, bias)
+        else:
+            output = paddle._legacy_C_ops.fused_gemm_epilogue(
+                input_parallel, weight, bias
+            )
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, weight, bias, input_parallel = ctx.saved_tensor()
+        parallelism = ctx.model_parallel_group.nranks
+
+        if dy.dtype == weight.dtype:
+            dinput_parallel = paddle.matmul(dy, weight, transpose_y=True)
+        else:
+            dinput_parallel = paddle.matmul(
+                dy, paddle.cast(weight, dtype=dy.dtype), transpose_y=True
+            )
+
+        assert (
+            dinput_parallel.shape[0] % parallelism == 0
+        ), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
+            dinput_parallel.shape[0], parallelism
+        )
+
+        dx_shape = dinput_parallel.shape
+        dx_shape[0] = dx_shape[0] // parallelism
+        dx = paddle.empty(shape=dx_shape, dtype=dinput_parallel.dtype)
+        hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_model_parallel_group()
+        task = dist.stream.reduce_scatter(
+            dx,
+            dinput_parallel,
+            op=dist.ReduceOp.SUM,
+            group=group,
+            sync_op=False,
+        )
+
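+        # NOTE: the reduce_scatter of dx above is issued asynchronously
+        # (sync_op=False), so the parameter-gradient computation below runs
+        # while that communication is still in flight; every branch calls
+        # task.wait() right before returning.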
+        if ctx.mp_fused_linear_param_grad_add:
+            if not is_fused_linear_param_grad_add_supported():
+                raise NotImplementedError(
+                    "You set mp_fused_linear_param_grad_add=True, "
+                    "however, the paddle you are using does not support this operation. "
+                    "Please unset mp_fused_linear_param_grad_add or use paddle compiled "
+                    "with cuda 11.6 or higher."
+                )
+            if bias is None:
+                if hasattr(weight, "main_grad"):
+                    (
+                        weight.main_grad,
+                        _,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        input_parallel, dy, weight.main_grad, None, True, False
+                    )
+                    task.wait()
+                    return dx, None
+                else:
+                    if weight.grad is not None:
+                        (
+                            weight.grad,
+                            _,
+                        ) = paddle._C_ops.fused_linear_param_grad_add(
+                            input_parallel, dy, weight.grad, None, False, False
+                        )
+                        task.wait()
+                        return dx, None
+                    else:
+                        (
+                            dw,
+                            _,
+                        ) = paddle._C_ops.fused_linear_param_grad_add(
+                            input_parallel, dy, None, None, False, False
+                        )
+                        task.wait()
+                        return dx, dw
+
+            if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
+                (
+                    weight.main_grad,
+                    bias.main_grad,
+                ) = paddle._C_ops.fused_linear_param_grad_add(
+                    input_parallel,
+                    dy,
+                    weight.main_grad,
+                    bias.main_grad,
+                    True,
+                    True,
+                )
+                task.wait()
+                return dx, None, None
+            else:
+                if weight.grad is not None:
+                    assert bias.grad is not None
+                    (
+                        weight.grad,
+                        bias.grad,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        input_parallel, dy, weight.grad, bias.grad, False, True
+                    )
+                    task.wait()
+                    return dx, None, None
+                else:
+                    (
+                        dw,
+                        dbias,
+                    ) = paddle._C_ops.fused_linear_param_grad_add(
+                        input_parallel, dy, None, None, False, True
+                    )
+                    task.wait()
+                    return dx, dw, dbias
+        else:
+            dy = dy.reshape([-1, dy.shape[-1]])
+            dw = paddle.matmul(
+                input_parallel.reshape([-1, input_parallel.shape[-1]]),
+                dy,
+                transpose_x=True,
+            )
+            if bias is None:
+                task.wait()
+                return dx, dw
+            else:
+                dbias = paddle.sum(dy, axis=0)
+                task.wait()
+                return dx, dw, dbias
+
+
 class ColumnSequenceParallelLinear(Layer):
     def __init__(
         self,
@@ -250,9 +402,12 @@ def __init__(
             if mp_group is None
             else mp_group.nranks
         )
+        assert (
+            self.world_size > 1
+        ), "tensor parallel degree must be greater than 1 in sequence parallel"
+
         self._name = name
         self.is_mp = self.world_size > 1
-
         assert (
             gather_output is False
         ), "If sequence_parallel is True, \
@@ -285,6 +440,7 @@ def __init__(
             )
 
         self.weight.is_distributed = True if self.is_mp else False
+        self.fuse_matmul_bias = fuse_matmul_bias
 
         if has_bias:
             # initialize bias to zero like Megatron
@@ -312,18 +468,26 @@ def __init__(
 
         self.linear = fused_linear
 
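+        # Read the overlap switches from the hybrid parallel strategy;
+        # mp_fused_linear_param_grad_add only takes effect when
+        # mp_async_allreduce is enabled as well.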
+        mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
+            "mp_configs"
+        ]
+        self.mp_async_allreduce = mp_configs.mp_async_allreduce
+
+        self.mp_fused_linear_param_grad_add = (
+            self.mp_async_allreduce
+            and mp_configs.mp_fused_linear_param_grad_add
+        )
+
     def forward(self, x):
-        # sequence parallelism is same as model parallelism
-        # if sequence parallel is true, input shape is [s, b, h]
-        # else input shape is [b, s, h]
-        if self.is_mp:
-            input_parallel = AllGatherOp.apply(x)
-        else:
-            input_parallel = x
-        output = self.linear(
-            input_parallel, self.weight, self.bias, name=self._name
+        # sequence parallelism is the same as model parallelism; if sequence parallel is true, input shape is [s, b, h], else input shape is [b, s, h]
+        return SPInnerOverlapLinear.apply(
+            x,
+            self.weight,
+            self.bias,
+            self.fuse_matmul_bias,
+            self.mp_fused_linear_param_grad_add,
+            self.model_parallel_group,
         )
-        return output
 
 
 class MPScale(PyLayer):
diff --git a/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py b/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py
index 4ff3c4a87fbb6..13d2a647cf1c2 100644
--- a/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py
+++ b/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py
@@ -21,6 +21,10 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle.distributed.fleet.utils import sequence_parallel_utils as spu
+from paddle.distributed.fleet.utils.mix_precision_utils import (
+    MixPrecisionLayer,
+    MixPrecisionOptimizer,
+)
 
 
 def set_random_seed(seed, dp_id, rank_id):
@@ -475,5 +479,370 @@ def test_mp_model(self):
             )
 
 
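+# TestDistSPTraining2/TestDistSPTraining3 rerun the sequence-parallel vs. plain
+# model comparison with mp_async_allreduce and mp_fused_linear_param_grad_add
+# enabled; TestDistSPTraining2 additionally wraps the model and optimizer in
+# MixPrecisionLayer / MixPrecisionOptimizer so gradients accumulate into
+# main_grad.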
+class TestDistSPTraining2(TestDistSPTraining):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "mp_async_allreduce": True,
+                "mp_fused_linear_param_grad_add": True,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def build_model_optimizer(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleSPNet(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        model_a = MixPrecisionLayer(model_a)
+        optimizer_a = self.build_optimizer(model_a)
+        optimizer_a = MixPrecisionOptimizer(optimizer_a)
+
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+
+class TestDistSPTraining3(TestDistSPTraining):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "mp_async_allreduce": True,
+                "mp_fused_linear_param_grad_add": True,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def build_model_optimizer(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleSPNet(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        optimizer_a = self.build_optimizer(model_a)
+
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+
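+# The *WithoutBias models and tests build the sequence-parallel linears with
+# has_bias=False, exercising the bias-is-None branches of
+# SPInnerOverlapLinear.backward both with and without the overlap options.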
+class SimpleSPNetWithoutBias(paddle.nn.Layer):
+    def __init__(
+        self,
+        vocab_size,
+        hidden_size,
+        inner_size,
+        output_size,
+        np_fc1,
+        np_fc2,
+        mp_id,
+    ):
+        super().__init__()
+
+        if mp_id == 0:
+            init_fc1_data = np_fc1[:, : (inner_size // 2)]
+            init_fc2_data = np_fc2[: (inner_size // 2), :]
+        else:
+            init_fc1_data = np_fc1[:, (inner_size // 2) :]
+            init_fc2_data = np_fc2[(inner_size // 2) :, :]
+
+        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.nn.initializer.Constant(value=0.5),
+        )
+
+        self.linear1 = spu.ColumnSequenceParallelLinear(
+            hidden_size,
+            inner_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(init_fc1_data)
+            ),
+            gather_output=False,
+            has_bias=False,
+        )
+
+        self.linear2 = spu.RowSequenceParallelLinear(
+            inner_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(init_fc2_data)
+            ),
+            input_is_parallel=True,
+            has_bias=False,
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
+        # if sequence parallel is true,
+        # register hook to all_reduce gradient of weight, bias
+        spu.mark_as_sequence_parallel_parameter(self.norm.weight)
+        spu.mark_as_sequence_parallel_parameter(self.norm.bias)
+
+        spu.register_sequence_parallel_allreduce_hooks(self, 1, False)
+
+    def forward(self, x):
+        x = self.embedding(x)
+
+        x = paddle.transpose(x, perm=[1, 0, 2])
+        x = spu.ScatterOp.apply(x)
+
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.norm(x)
+        x = self.linear3(x)
+
+        x = paddle.transpose(x, perm=[1, 0, 2])
+
+        x = parallel_matmul(x, self.embedding.weight, False)
+        return x
+
+
+class SimpleDPNetWithoutBias(paddle.nn.Layer):
+    def __init__(
+        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+    ):
+        super().__init__()
+        self.linear1 = paddle.nn.Linear(
+            hidden_size,
+            inner_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc1)
+            ),
+        )
+
+        self.linear2 = paddle.nn.Linear(
+            inner_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc2)
+            ),
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
+
+        self.embedding = paddle.nn.Embedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.nn.initializer.Constant(value=0.5),
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.norm(x)
+        x = self.linear3(x)
+        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
+        return x
+
+
+class TestDistSPTrainingWithoutBias(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "mp_async_allreduce": False,
+                "mp_fused_linear_param_grad_add": False,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def train_batch(self, batch, model, optimizer, is_mp):
+        output = model(batch)
+        loss = output.mean()
+        loss.backward()  # do backward
+        optimizer.step()  # update parameters
+        optimizer.clear_grad()
+        return loss
+
+    def build_optimizer(self, model):
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.001, parameters=model.parameters()
+        )
+        return optimizer
+
+    def build_model_optimizer(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleSPNetWithoutBias(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        optimizer_a = self.build_optimizer(model_a)
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNetWithoutBias(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+    def test_mp_model(self):
+        (
+            model_a,
+            optimizer_a,
+            model_b,
+            optimizer_b,
+        ) = self.build_model_optimizer()
+
+        for _ in range(5):
+            np_data = np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            batch = paddle.to_tensor(np_data)
+            loss_a = self.train_batch(batch, model_a, optimizer_a, True)
+            loss_b = self.train_batch(batch, model_b, optimizer_b, False)
+
+            np.testing.assert_allclose(
+                loss_a.numpy(), loss_b.numpy(), rtol=1e-5, atol=1e-5
+            )
+
+
+class TestDistSPTrainingWithoutBias2(TestDistSPTrainingWithoutBias):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "mp_async_allreduce": True,
+                "mp_fused_linear_param_grad_add": True,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+
+class TestDistSPTrainingWithoutBias3(TestDistSPTrainingWithoutBias2):
+    def build_model_optimizer(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleSPNetWithoutBias(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        model_a = MixPrecisionLayer(model_a)
+        optimizer_a = self.build_optimizer(model_a)
+        optimizer_a = MixPrecisionOptimizer(optimizer_a)
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNetWithoutBias(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+
 if __name__ == "__main__":
     unittest.main()