From 349b2cc365194a4c24a1c9cd6104f8ef2c3e7617 Mon Sep 17 00:00:00 2001 From: ForFishes <2282912238@qq.com> Date: Tue, 19 Dec 2023 12:07:03 +0800 Subject: [PATCH] fix the limitation of fthenb schedule --- .../fleet/meta_parallel/pipeline_parallel.py | 10 +++------- python/paddle/distributed/fleet/model.py | 13 ++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 44e72ad668415c..90c519c07a8719 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -1288,13 +1288,9 @@ def forward_backward_pipeline( self._forward_only = forward_only assert ( - self.accumulate_steps >= self.num_stages - ), "accumulate_steps({}) should be larger than num_stages({}) for pipeline with interleave".format( - self.accumulate_steps, self.num_stages - ) - assert ( - self.accumulate_steps < 2 * self.num_stages - ), "accumulate_steps({}) should be smaller than 2 * num_stages({}) for pipeline with interleave".format( + self.accumulate_steps == self.num_stages + or self.accumulate_steps % self.num_stages != 0 + ), "accumulate_steps({}) and num_stages({}) should be a multiple or accumulate_steps % num_stages == 0".format( self.accumulate_steps, self.num_stages ) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 17f2dc21bf46c4..4bd87e70eee33d 100755 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -168,17 +168,16 @@ def forward(self, x): accumulate_steps = strategy.pipeline_configs['accumulate_steps'] pp_degree = fleet_env._hcg.get_pipe_parallel_world_size() if ( - accumulate_steps >= pp_degree - and accumulate_steps < pp_degree * 2 + accumulate_steps > pp_degree + and accumulate_steps % pp_degree == 0 ): - # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave - # Currently, we only support pp_degree <= accumulate_steps < 2 * pp_degree - model = PipelineParallelWithInterleaveFthenB( + # interleave pipeline + model = PipelineParallelWithInterleave( model, fleet_env._hcg, strategy=strategy ) else: - # interleave pipeline - model = PipelineParallelWithInterleave( + # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave + model = PipelineParallelWithInterleaveFthenB( model, fleet_env._hcg, strategy=strategy )