@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 
 import os
+import queue
 import sys
 from collections import defaultdict
 
@@ -768,23 +769,26 @@ class PipelineParallelWithInterleave(PipelineParallel):
     def __init__(self, layers, hcg, strategy):
         super().__init__(layers=layers, hcg=hcg, strategy=strategy)
         assert layers.get_num_virtual_stages() > 1
+        self._check_sanity()
+        self.num_model_chunks = layers.get_num_virtual_stages()
+        self.model_chunks = layers.get_model_chunks()
+        assert self.model_chunks is not None
+        assert len(self.model_chunks) == self.num_model_chunks
+        self._virtual_pp_world_size = self.num_model_chunks
+        self._virtual_pp_rank = 0
+        self._assign_vpp_info(self.model_chunks)
+
+    def _check_sanity(self):
         assert (
             framework.in_dygraph_mode()
         ), "virtual pipeline stage with interleave only support eager dygraph mode"
+
         assert (
             self.num_stages > 2
         ), "virtual pipeline must run under pp degree > 2"
         assert (
             self.accumulate_steps % self.num_stages == 0
         ), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
-        # setup for interleave scheduler
-        self.num_model_chunks = layers.get_num_virtual_stages()
-        self.model_chunks = layers.get_model_chunks()
-        assert self.model_chunks is not None
-        assert len(self.model_chunks) == self.num_model_chunks
-        self._virtual_pp_world_size = self.num_model_chunks
-        self._virtual_pp_rank = 0
-        self._assign_vpp_info(self.model_chunks)
 
     def _assign_vpp_info(self, chunks):
         chunk_num = len(chunks)
@@ -1186,3 +1190,227 @@ def eval_batch(self, data, compute_loss=False):
         self._compute_loss = compute_loss
 
         return self.forward_backward_pipeline(data, None, forward_only=True)
+
+
+class PipelineParallelWithInterleaveFthenB(PipelineParallelWithInterleave):
+    def __init__(self, layers, hcg, strategy):
+        super().__init__(layers=layers, hcg=hcg, strategy=strategy)
+
+    def _check_sanity(self):
+        assert (
+            framework.in_dygraph_mode()
+        ), "virtual pipeline stage with interleave only support eager dygraph mode"
+
+        assert (
+            self.num_stages > 2
+        ), "virtual pipeline must run under pp degree > 2"
+
+    def _get_virtual_pp_rank(self, micro_step, forward):
+
+        virtual_pp_stage = micro_step % (
+            self.accumulate_steps * self.num_model_chunks
+        )
+        virtual_pp_stage = virtual_pp_stage // self.accumulate_steps
+        if not forward:
+            virtual_pp_stage = self.num_model_chunks - virtual_pp_stage - 1
+
+        return virtual_pp_stage
+
+    def _overlap_comm_grads(self):
+        if not self._comm_overlap:
+            return
+        self._backward_step_count += 1
+        sync_step = self._backward_step_count - self.stage_id
+
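+        # every accumulate_steps backward steps (offset by this rank's stage_id),
+        # start gradient communication for one more chunk, going from the last
+        # chunk down toward chunk 0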
+        if sync_step > 0 and sync_step % self.accumulate_steps == 0:
+            chunk_idx = self._virtual_pp_world_size - (
+                sync_step // self.accumulate_steps
+            )
+            for buffer in self._chunk_2_comm_buffers[chunk_idx]:
+                buffer.comm_grads()
+
+        if self.stage_id == 0:
+            return
+
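+        # on stages other than stage 0, chunk 0 is only flushed after the very
+        # last backward step of the batch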
+        if (
+            self._backward_step_count
+            == self.accumulate_steps * self._virtual_pp_world_size
+        ):
+            for buffer in self._chunk_2_comm_buffers[0]:
+                buffer.comm_grads()
+
+    def _sync_overlap_grads(self):
+        if not self._comm_overlap:
+            return
+
+        expected_count = self.accumulate_steps * self._virtual_pp_world_size
+        assert self._backward_step_count == expected_count, (
+            f"backward step count should be equal to accumulate steps * virtual pp world size, "
+            f"but got {self._backward_step_count}, expected result is {expected_count}"
+        )
+
+        for buffers in self._chunk_2_comm_buffers.values():
+            for buffer in buffers:
+                buffer.scale_and_split_grads()
+
+    def forward_backward_pipeline(
+        self, data, scaler, forward_only=False, compute_loss=True
+    ):
+        if not compute_loss:
+            assert (
+                forward_only
+            ), "compute_loss can only be set to False when forward_only is set to True"
+
+        # NOTE(shenliang03): Due to ring_exchange for pipeline with interleave, cache should be enabled
+        assert (
+            self._using_cache
+        ), "cache should be enabled for pipeline with interleave"
+
+        # init some attributes for this batch run
+        self.scaler = scaler
+        self.total_loss = None
+        self.micro_batch_id = 0
+        self._forward_only = forward_only
+
+        assert (
+            self.accumulate_steps >= self.num_stages
+        ), "accumulate_steps({}) should be no less than num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+        assert (
+            self.accumulate_steps < 2 * self.num_stages
+        ), "accumulate_steps({}) should be smaller than 2 * num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+
+        self._backward_step_count = 0
+        skip_steps = self.accumulate_steps - self.num_stages
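+        # holds tensors whose send is deferred on the boundary stages: forward
+        # outputs on the last stage and input grads on the first stage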
+        send_recv_buffer_queue = queue.Queue()
+
+        # init some data buffers for interleave scheduler
+        self.input_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensors = [[] for _ in range(self.num_model_chunks)]
+        self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]
+
+        micro_dataset = self._wrap_data(data)
+        num_steps = self.accumulate_steps * self.num_model_chunks
+
+        self.set_virtual_pipeline_rank(0)
+        self.input_tensors[0].append(
+            self._p2p_helper.recv_forward(
+                self.is_pipeline_first_stage(), sync_recv=False
+            )
+        )
+
+        # run all forward micro steps first (F-then-B schedule)
+        for micro_step in range(num_steps):
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
+            # determine whether to recv the forward tensor for the next micro step
+            next_virtual_pp_rank = self._get_virtual_pp_rank(
+                micro_step + 1, forward=True
+            )
+
+            recv_prev = True
+            if self.is_pipeline_first_stage(ignore_virtual=True):
+                if next_virtual_pp_rank == 0:
+                    # next chunk is the first chunk, no need to pre-recv an input tensor
+                    recv_prev = False
+
+            # last micro step, no next run
+            if micro_step == (num_steps - 1):
+                recv_prev = False
+
+            if self.is_pipeline_last_stage(ignore_virtual=True):
+                # last stage skip send/recv
+                if not self.is_pipeline_last_stage():
+                    send_recv_buffer_queue.put(output_tensor)
+
+                if micro_step < skip_steps or (
+                    self.is_pipeline_last_stage()
+                    and micro_step % self.accumulate_steps >= skip_steps
+                ):
+                    output_tensor = None
+                else:
+                    output_tensor = send_recv_buffer_queue.get()
+
+            input_tensor = self._p2p_helper.send_forward_recv_forward(
+                output_tensor, recv_prev=recv_prev
+            )
+            self.input_tensors[next_virtual_pp_rank].append(input_tensor)
+
+            self._release_output(output_tensor)
+
+        assert (
+            send_recv_buffer_queue.empty()
+        ), "send_recv buffer should be empty"
+
+        # then run all backward micro steps
+        if not forward_only:
+
+            self.output_tensor_grads[self.num_model_chunks - 1].append(
+                self._p2p_helper.recv_backward(
+                    self.is_pipeline_last_stage(), sync_recv=False
+                )
+            )
+
+            for micro_step in range(num_steps):
+                # cooldown loop
+                input_tensor_grad = self._backward_step_helper(micro_step)
+                next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
+                    micro_step + 1, forward=False
+                )
+
+                recv_next = True
+                if self.is_pipeline_last_stage(ignore_virtual=True):
+                    if next_backward_virtual_pp_rank == (
+                        self.num_model_chunks - 1
+                    ):
+                        recv_next = False
+
+                if micro_step == (num_steps - 1):
+                    recv_next = False
+
+                if self.is_pipeline_first_stage(ignore_virtual=True):
+                    if not self.is_pipeline_first_stage():
+                        send_recv_buffer_queue.put(input_tensor_grad)
+
+                    if micro_step < skip_steps or (
+                        self.is_pipeline_first_stage()
+                        and micro_step % self.accumulate_steps >= skip_steps
+                    ):
+                        input_tensor_grad = None
+                    else:
+                        input_tensor_grad = send_recv_buffer_queue.get()
+
+                self.output_tensor_grads[next_backward_virtual_pp_rank].append(
+                    self._p2p_helper.send_backward_recv_backward(
+                        input_tensor_grad, recv_next=recv_next
+                    )
+                )
+
+            assert (
+                send_recv_buffer_queue.empty()
+            ), "send_recv buffer should be empty"
+
+            self._sync_overlap_grads()
+
+            if self._enable_timer:
+                self.timers("allreduce_shared_weight_gradients").start()
+            self._layers.allreduce_shared_weight_gradients()
+            if self._enable_timer:
+                self.timers("allreduce_shared_weight_gradients").stop()
+
+        if compute_loss:
+            # return loss if compute loss
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").start()
+            with paddle.amp.auto_cast(enable=False):
+                train_loss = self._broadcast_final_loss()
+            if self._enable_timer:
+                self.timers("broadcast_final_loss").stop()
+        else:
+            # else just return all intermediate output tensors for all micro steps
+            train_loss = self.output_tensors
+
+        self.timer_printer()
+        return train_loss
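
The F-then-B ordering produced by _get_virtual_pp_rank is easiest to see with concrete numbers. A minimal standalone sketch, not part of the diff, that mirrors the same arithmetic; accumulate_steps=4 and num_model_chunks=2 are assumed example values:

def virtual_pp_rank(micro_step, accumulate_steps, num_model_chunks, forward):
    # same arithmetic as _get_virtual_pp_rank in the diff above
    stage = (micro_step % (accumulate_steps * num_model_chunks)) // accumulate_steps
    return stage if forward else num_model_chunks - 1 - stage

accumulate_steps, num_model_chunks = 4, 2
steps = range(accumulate_steps * num_model_chunks)
print([virtual_pp_rank(s, accumulate_steps, num_model_chunks, forward=True) for s in steps])
# [0, 0, 0, 0, 1, 1, 1, 1]  -- every forward micro step, chunk by chunk
print([virtual_pp_rank(s, accumulate_steps, num_model_chunks, forward=False) for s in steps])
# [1, 1, 1, 1, 0, 0, 0, 0]  -- then every backward micro step, in reverse chunk order

Unlike the 1F1B interleave schedule of the parent class, this variant finishes every forward micro step before the first backward micro step runs, which is why the backward chunk order is simply the forward order reversed.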