from typing import Optional
import torch
from deepspeed import comm as dist
- from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+ from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
- from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
+ from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+ from deepspeed.utils import groups
+ from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
    return policy_list

def set_tensor_parallel_config(self, mp_size, mp_group):
+
+     if is_autotp_training_mode():
+         self.mp_group = groups.get_tensor_model_parallel_group()
+         self.mp_size = groups.get_tensor_model_parallel_world_size()
+         return
+
    self.mp_size = mp_size
    self.mp_group = mp_group

def _replace(self, child, name, conv_linear_layer):
+     # This function should clearly define the routing rules for specific layers
+     # and avoid any complex shard-related logic.
    if getattr(child, "replaced", False) == True:
        return
    device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer):
    # For Yuan model
    if 'Yuan' in str(self.module):
        if 'v_proj' in name:
-             weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                      dist.get_world_size(), True)
-             return LinearLayer(weight=weight, bias=bias)
+             return Yuan_LinearLayer(child, self.mp_group)
+
        elif 'o_proj' in name:
-             weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                      dist.get_world_size(), False)
-             return LinearAllreduce(weight, bias, self.mp_group)
-     # For Arctic model, bypass to all_reduce replacement for w2 weights
+             return Yuan_LinearAllreduce(child, self.mp_group)
+
+     # For MLP including chunk layer.
+     if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
+         return GateUpPack_LinearLayer(child, self.mp_group)
+     # For Arctic model, bypass to all_reduce replacement for w2 weights
    arctic_w2_all_reduce_linear = False
    if 'Arctic' in str(self.module) and 'w2' in name:
        arctic_w2_all_reduce_linear = True
    # For MoE MLP model, e.g., deepseek and jamba
    down_proj = False
    if 'down_proj' in name:
        down_proj = True
-     # For MLP including chunk layer.
-     if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
-         weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
-         return LinearLayer(weight=weight, bias=bias)
    if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
-         # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-         # else [weight_shape[0], weight_shape[1] // mp_size]

+         setattr(child, "replaced", True)
        if self.conv_linear_layer:
-             child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-         data = child.weight.data.split(get_shard_size_list(
-             weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
-                                        dim=1)
-         data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-         del data
+             return Conv_LinearALlreduce(child, self.mp_group, name=name)
+         elif name == "lm_head" or name == 'embed_out':
+             return LmHeadLinearAllreduce(child, self.mp_group)

-         setattr(child, "replaced", True)
-         if name == "lm_head" or name == 'embed_out':
-             return LmHeadLinearAllreduce(
-                 torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
-                 child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                     move(child.bias, device_name, return_new_copy)), self.mp_group)
-         return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                     torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
+         return LinearAllreduce(child, self.mp_group, name=name)
    else:

-         # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-         # else [weight_shape[0] // mp_size, weight_shape[1]]
+         setattr(child, "replaced", True)
        if self.conv_linear_layer:
-             child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-
-         if require_tp_fused_qkvw(name, self.mp_size):
+             conv_LinearLayer(child, self.mp_group)
+         elif require_tp_fused_qkvw(name, self.mp_size):
            #Check and handle fused qkv for TP
-             #The copy is a regular copy, The shape of dst and src is the same
-             data_dc = move(
-                 prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                 device_name, return_new_copy)
-
-             bias_data_dc = None if child.bias is None else move(
-                 prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                 device_name, return_new_copy)
-         else:
-             data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
-                                            dim=1 if self.conv_linear_layer else 0)
-             data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-             del data
-
-             if child.bias is not None:
-                 bias_data = child.bias.data.split(get_shard_size_list(
-                     weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
-                                                   dim=0)
-                 bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
-                 bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
-                 del bias_data
-             else:
-                 bias_data_dc = None
+             return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

-         setattr(child, "replaced", True)
-         return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
+         return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
    if getattr(child, "replaced", False) == True: