Commit ca3fe11

[Distributed] Add unbalanced batch for virtual pp (#58383)
* add unbalanced batch for vpp
* add unbalanced batch for vpp
* add unbalanced batch for vpp
1 parent dcee90b commit ca3fe11

File tree

4 files changed: +259 -13 lines


python/paddle/distributed/fleet/meta_parallel/__init__.py (+3)

@@ -25,6 +25,9 @@
 from .tensor_parallel import TensorParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallel  # noqa: F401
 from .pipeline_parallel import PipelineParallelWithInterleave  # noqa: F401
+from .pipeline_parallel import (
+    PipelineParallelWithInterleaveFthenB,
+)  # noqa: F401
 from .sharding_parallel import ShardingParallel  # noqa: F401
 
 __all__ = []

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py (+236 -8)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import queue
 import sys
 from collections import defaultdict
 
@@ -768,23 +769,26 @@ class PipelineParallelWithInterleave(PipelineParallel):
     def __init__(self, layers, hcg, strategy):
         super().__init__(layers=layers, hcg=hcg, strategy=strategy)
         assert layers.get_num_virtual_stages() > 1
+        self._check_sanity()
+        self.num_model_chunks = layers.get_num_virtual_stages()
+        self.model_chunks = layers.get_model_chunks()
+        assert self.model_chunks is not None
+        assert len(self.model_chunks) == self.num_model_chunks
+        self._virtual_pp_world_size = self.num_model_chunks
+        self._virtual_pp_rank = 0
+        self._assign_vpp_info(self.model_chunks)
+
+    def _check_sanity(self):
         assert (
             framework.in_dygraph_mode()
         ), "virtual pipeline stage with interleave only support eager dygraph mode"
+
         assert (
             self.num_stages > 2
         ), "virtual pipeline must run under pp degree > 2"
         assert (
             self.accumulate_steps % self.num_stages == 0
         ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
-        # setup for interleave scheduler
-        self.num_model_chunks = layers.get_num_virtual_stages()
-        self.model_chunks = layers.get_model_chunks()
-        assert self.model_chunks is not None
-        assert len(self.model_chunks) == self.num_model_chunks
-        self._virtual_pp_world_size = self.num_model_chunks
-        self._virtual_pp_rank = 0
-        self._assign_vpp_info(self.model_chunks)
 
     def _assign_vpp_info(self, chunks):
         chunk_num = len(chunks)
@@ -1186,3 +1190,227 @@ def eval_batch(self, data, compute_loss=False):
         self._compute_loss = compute_loss
 
         return self.forward_backward_pipeline(data, None, forward_only=True)
+
+
+class PipelineParallelWithInterleaveFthenB(PipelineParallelWithInterleave):
+    def __init__(self, layers, hcg, strategy):
+        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
+
+    def _check_sanity(self):
+        assert (
+            framework.in_dygraph_mode()
+        ), "virtual pipeline stage with interleave only support eager dygraph mode"
+
+        assert (
+            self.num_stages > 2
+        ), "virtual pipeline must run under pp degree > 2"
+
+    def _get_virtual_pp_rank(self, micro_step, forward):
+
+        virtual_pp_stage = micro_step % (
+            self.accumulate_steps * self.num_model_chunks
+        )
+        virtual_pp_stage = virtual_pp_stage // self.accumulate_steps
+        if not forward:
+            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
+
+        return virtual_pp_stage
+
+    def _overlap_comm_grads(self):
+        if not self._comm_overlap:
+            return
+        self._backward_step_count += 1
+        sync_step = self._backward_step_count - self.stage_id
+
+        if sync_step > 0 and sync_step % self.accumulate_steps == 0:
+            chunk_idx = self._virtual_pp_world_size - (
+                sync_step // self.accumulate_steps
+            )
+            for buffer in self._chunk_2_comm_buffers[chunk_idx]:
+                buffer.comm_grads()
+
+        if self.stage_id == 0:
+            return
+
+        if (
+            self._backward_step_count
+            == self.accumulate_steps * self._virtual_pp_world_size
+        ):
+            for buffer in self._chunk_2_comm_buffers[0]:
+                buffer.comm_grads()
+
+    def _sync_overlap_grads(self):
+        if not self._comm_overlap:
+            return
+
+        expected_count = self.accumulate_steps * self._virtual_pp_world_size
+        assert self._backward_step_count == expected_count, (
+            f"backward step count should be equal to accumulate steps * virtual pp world size, "
+            f"but got {self._backward_step_count}, expected result is {expected_count}"
+        )
+
+        for buffers in self._chunk_2_comm_buffers.values():
+            for buffer in buffers:
+                buffer.scale_and_split_grads()
+
+    def forward_backward_pipeline(
+        self, data, scaler, forward_only=False, compute_loss=True
+    ):
+        if not compute_loss:
+            assert (
+                not forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
+
+        # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
+
+        # init some attributes for this batch run
+        self.scaler = scaler
+        self.total_loss = None
+        self.micro_batch_id = 0
+        self._forward_only = forward_only
+
+        assert (
+            self.accumulate_steps >= self.num_stages
+        ), "accumulate_steps({}) should be larger than num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+        assert (
+            self.accumulate_steps < 2 * self.num_stages
+        ), "accumulate_steps({}) should be smaller than 2 * num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+
+        self._backward_step_count = 0
+        skip_steps = self.accumulate_steps - self.num_stages
+        send_recv_buffer_queue = queue.Queue()
+
+        # init some data buffers for interleave scheduler
+        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+
+        micro_dataset = self._wrap_data(data)
+        num_steps = self.accumulate_steps * self.num_model_chunks
+
+        self.set_virtual_pipeline_rank(0)
+        self.input_tensors[0].append(
+            self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage(), sync_recv=False
+            )
+        )
+
+        # run startup steps
+        for micro_step in range(num_steps):
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
+            # determine whether recv forward tensor or not
+            next_virtual_pp_rank = self._get_virtual_pp_rank(
+                micro_step + 1, forward=True
+            )
+
+            recv_prev = True
+            if self.is_pipeline_first_stage(ignore_virtual=True):
+                if next_virtual_pp_rank == 0:
+                    # next chunk is the first chunk, not need to pre recv an input tensor
+                    recv_prev = False
+
+            # last micro step, no next run
+            if micro_step == (num_steps - 1):
+                recv_prev = False
+
+            if self.is_pipeline_last_stage(ignore_virtual=True):
+                # last stage skip send/recv
+                if not self.is_pipeline_last_stage():
+                    send_recv_buffer_queue.put(output_tensor)
+
+                if micro_step < skip_steps or (
+                    self.is_pipeline_last_stage()
+                    and micro_step % self.accumulate_steps >= skip_steps
+                ):
+                    output_tensor = None
+                else:
+                    output_tensor = send_recv_buffer_queue.get()
+
+            input_tensor = self._p2p_helper.send_forward_recv_forward(
+                output_tensor, recv_prev=recv_prev
+            )
+            self.input_tensors[next_virtual_pp_rank].append(input_tensor)
+
+            self._release_output(output_tensor)
+
+        assert (
+            send_recv_buffer_queue.empty()
+        ), "send_recv buffer should be empty"
+
+        # remaining backward steps
+        if not forward_only:
+
+            self.output_tensor_grads[self.num_model_chunks - 1].append(
+                self._p2p_helper.recv_backward(
+                    self.is_pipeline_last_stage(), sync_recv=False
+                )
+            )
+
+            for micro_step in range(num_steps):
+                # cooldown loop
+                input_tensor_grad = self._backward_step_helper(micro_step)
+                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                    micro_step + 1, forward=False
+                )
+
+                recv_next = True
+                if self.is_pipeline_last_stage(ignore_virtual=True):
+                    if next_backward_virtual_pp_rank == (
+                        self.num_model_chunks - 1
+                    ):
+                        recv_next = False
+
+                if micro_step == (num_steps - 1):
+                    recv_next = False
+
+                if self.is_pipeline_first_stage(ignore_virtual=True):
+                    if not self.is_pipeline_first_stage():
+                        send_recv_buffer_queue.put(input_tensor_grad)
+
+                    if micro_step < skip_steps or (
+                        self.is_pipeline_first_stage()
+                        and micro_step % self.accumulate_steps >= skip_steps
+                    ):
+                        input_tensor_grad = None
+                    else:
+                        input_tensor_grad = send_recv_buffer_queue.get()
+
+                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
+                    self._p2p_helper.send_backward_recv_backward(
+                        input_tensor_grad, recv_next=recv_next
+                    )
+                )
+
+            assert (
+                send_recv_buffer_queue.empty()
+            ), "send_recv buffer should be empty"
+
+            self._sync_overlap_grads()
+
+            if self._enable_timer:
+                self.timers("allreduce_shared_weight_gradients").start()
+            self._layers.allreduce_shared_weight_gradients()
+            if self._enable_timer:
+                self.timers("allreduce_shared_weight_gradients").stop()
+
+        if compute_loss:
+            # return loss if compute loss
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").start()
+            with paddle.amp.auto_cast(enable=False):
+                train_loss = self._broadcast_final_loss()
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").stop()
+        else:
+            # else just return all intermediate output tensor for all micro steps
+            train_loss = self.output_tensors
+
+        self.timer_printer()
+        return train_loss

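The new schedule is easiest to see in isolation. Below is a minimal, standalone sketch (not part of the commit) of the chunk order that `_get_virtual_pp_rank` above produces for the F-then-B variant; the concrete numbers (accumulate_steps=5, num_model_chunks=2, i.e. pp_degree=4 with an "unbalanced" number of micro-batches) are assumptions for illustration only.

# Standalone sketch of the F-then-B chunk schedule computed by _get_virtual_pp_rank
# above (re-implemented here for illustration; the numbers below are assumed).
def get_virtual_pp_rank(micro_step, forward, accumulate_steps, num_model_chunks):
    # All micro-batches of chunk 0 run before any micro-batch of chunk 1,
    # so the chunk index advances once every `accumulate_steps` micro-steps.
    stage = micro_step % (accumulate_steps * num_model_chunks)
    stage = stage // accumulate_steps
    if not forward:
        # The backward pass walks the chunks in reverse order.
        stage = num_model_chunks - stage - 1
    return stage

accumulate_steps, num_model_chunks = 5, 2  # e.g. pp_degree=4, 5 micro-batches (unbalanced)
num_steps = accumulate_steps * num_model_chunks
forward_order = [
    get_virtual_pp_rank(s, True, accumulate_steps, num_model_chunks)
    for s in range(num_steps)
]
backward_order = [
    get_virtual_pp_rank(s, False, accumulate_steps, num_model_chunks)
    for s in range(num_steps)
]
print(forward_order)   # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] -> all forward chunks first
print(backward_order)  # [1, 1, 1, 1, 1, 0, 0, 0, 0, 0] -> then all backward chunks, reversed

Unlike the interleaved 1F1B schedule, which asserts accumulate_steps % num_stages == 0, this variant runs every forward chunk before any backward chunk and only requires num_stages <= accumulate_steps < 2 * num_stages, which is what the assertions at the top of forward_backward_pipeline above enforce.
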
python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py (+3 -1)

@@ -608,7 +608,9 @@ def send_forward_recv_forward(self, output_tensor, recv_prev):
         if _timers is not None:
             _timers("send_forward_recv_forward").start()
 
-        self._send_meta(output_tensor)
+        if output_tensor is not None:
+            self._send_meta(output_tensor)
+
         if recv_prev:
             self._recv_meta()
 
python/paddle/distributed/fleet/model.py (+17 -4)

@@ -20,6 +20,7 @@
     PipelineLayer,
     PipelineParallel,
     PipelineParallelWithInterleave,
+    PipelineParallelWithInterleaveFthenB,
     ShardingParallel,
     TensorParallel,
 )
@@ -164,9 +165,21 @@ def forward(self, x):
             # 1f1b pipeline
             model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)
         else:
-            # interleave pipeline
-            model = PipelineParallelWithInterleave(
-                model, fleet_env._hcg, strategy=strategy
-            )
+            accumulate_steps = strategy.pipeline_configs['accumulate_steps']
+            pp_degree = fleet_env._hcg.get_pipe_parallel_world_size()
+            if (
+                accumulate_steps >= pp_degree
+                and accumulate_steps < pp_degree * 2
+            ):
+                # NOTE(shenliang03): Hacky for unbalanced pipeline parallel with interleave
+                # Currently, we only support pp_degree <= accumulate_steps < 2 * pp_degree
+                model = PipelineParallelWithInterleaveFthenB(
+                    model, fleet_env._hcg, strategy=strategy
+                )
+            else:
+                # interleave pipeline
+                model = PipelineParallelWithInterleave(
+                    model, fleet_env._hcg, strategy=strategy
+                )
 
     return model
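
Selection of the new scheduler is automatic: `fleet.distributed_model` picks `PipelineParallelWithInterleaveFthenB` whenever the model has more than one virtual stage and pp_degree <= accumulate_steps < 2 * pp_degree. A minimal configuration sketch that would take this branch is shown below (not part of the commit); the parallel degrees, `micro_batch_size`, and the `net` layer object are illustrative assumptions.

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 4,          # num_stages = 4
}
strategy.pipeline_configs = {
    "accumulate_steps": 5,   # 4 <= 5 < 8, so the FthenB scheduler is selected
    "micro_batch_size": 2,   # illustrative value
}

fleet.init(is_collective=True, strategy=strategy)

# `net` is assumed to be a PipelineLayer built with num_virtual_pipeline_stages > 1;
# under this strategy, fleet.distributed_model(net) would return a
# PipelineParallelWithInterleaveFthenB instance.
# model = fleet.distributed_model(net)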
