Skip to content

Commit 2ffacc5

Browse files
ForFisheswentaoyu
authored andcommitted
part-5 cherry from: [Distributed]Add unbalance batch for virtual pp (PaddlePaddle#58383)
* add unbalanced batch for vpp * add unbalanced batch for vpp * add unbalanced batch for vpp
1 parent 69264b9 commit 2ffacc5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,7 @@ def __init__(self, layers, hcg, strategy):
912912
assert layers.get_num_virtual_stages() > 1
913913

914914
# setup for interleave scheduler
915+
self._check_sanity()
915916
self.num_model_chunks = layers.get_num_virtual_stages()
916917
self.model_chunks = layers.get_model_chunks()
917918
assert self.model_chunks is not None
@@ -920,7 +921,7 @@ def __init__(self, layers, hcg, strategy):
920921
self._virtual_pp_rank = 0
921922
self._reset_counter()
922923

923-
self._check_sanity()
924+
self._assign_vpp_info(self.model_chunks)
924925

925926
def _check_sanity(self):
926927
assert (

0 commit comments

Comments
 (0)