import torch
import torch.distributed as dist
+from accelerate.utils import is_npu_available

from trl import TrlParser
+from trl.distributed_util import init_process_group
from trl.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available, is_vllm_available


if is_vllm_available():
    from vllm import LLM, SamplingParams
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.parallel_state import get_world_group
-    from vllm.distributed.utils import StatelessProcessGroup
    from vllm.sampling_params import GuidedDecodingParams
    from vllm.worker.worker import Worker
+    if is_npu_available():
+        from vllm_ascend.worker.worker import NPUWorker as Worker
else:
    Worker = object
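With the NPU-aware imports above, the worker can no longer assume an NCCL-only setup, which is why `init_weight_update_group` below takes an explicit `backend` argument. A minimal sketch of how a caller might pick a matching backend; the backend strings are an assumption on my part, not taken from this patch:

```python
# Sketch only: pick a collective backend that matches the device.
# Assumption: "hccl" is the usual backend on Ascend NPUs, "nccl" on CUDA GPUs.
from accelerate.utils import is_npu_available

backend = "hccl" if is_npu_available() else "nccl"
```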
@@ -72,13 +74,13 @@ def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

-        # The following attributes are initialized when `init_communicator` method is called.
-        self.pynccl_comm = None  # Communicator for weight updates
+        # The following attributes are initialized when `init_weight_update_group` method is called.
+        self.weight_update_group = None  # Communicator for weight updates
        self.client_rank = None  # Source rank for broadcasting updated weights

-    def init_communicator(self, host: str, port: int, world_size: int) -> None:
+    def init_weight_update_group(self, host: str, port: int, world_size: int, backend: str) -> None:
        """
-        Initializes the weight update communicator using a stateless process group.
+        Initializes the weight update process group.

        This method creates a `StatelessProcessGroup` that allows external training processes to
        communicate with vLLM workers without interfering with the global torch distributed group.
@@ -90,18 +92,23 @@ def init_communicator(self, host: str, port: int, world_size: int) -> None:
                Port number to be used for communication.
            world_size (`int`):
                Total number of participating processes in the update group.
+            backend (`str`):
+                The backend to use for collective communication.
        """
-        if self.pynccl_comm is not None:
-            raise RuntimeError("Weight update group already initialized. Call close_communicator first.")
+        if self.weight_update_group is not None:
+            raise RuntimeError("Weight update group already initialized. Call close_weight_update_group first.")

        # Get the rank of the current worker in the global world group.
        rank = get_world_group().rank

        # Create a stateless process group to manage communication between training processes and vLLM workers.
-        pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size)
-
-        # Initialize the NCCL-based communicator for weight synchronization.
-        self.pynccl_comm = PyNcclCommunicator(pg, device=self.device)
+        self.weight_update_group = init_process_group(
+            backend=backend,
+            init_method=f"tcp://{host}:{port}",
+            world_size=world_size,
+            rank=rank,
+            group_name="weight_update_group",
+        )

        # The client process that sends updated weights has the highest rank (world_size - 1).
        self.client_rank = world_size - 1
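For context, here is a rough sketch of the trainer-side counterpart this group is meant to talk to. It assumes the same `init_process_group` helper is importable in the training process and that the trainer joins as the highest rank, matching `self.client_rank` above; the host, port, backend, and tensor-parallel size are placeholders, not values from this patch:

```python
# Trainer-side sketch (assumptions noted above; not part of this patch).
import torch
import torch.distributed as dist

from trl.distributed_util import init_process_group  # helper introduced by this PR

host, port, backend = "10.0.0.1", 51216, "nccl"  # placeholders
tensor_parallel_size = 2                         # placeholder
world_size = tensor_parallel_size + 1            # all vLLM workers plus this trainer

group = init_process_group(
    backend=backend,
    init_method=f"tcp://{host}:{port}",
    world_size=world_size,
    rank=world_size - 1,                         # trainer is the broadcast source (client_rank)
    group_name="weight_update_group",
)

# For each updated parameter, broadcast so that `update_named_param` on the
# workers receives it into the tensor it just allocated.
param = torch.zeros(16, 16, device="cuda")       # placeholder tensor
dist.broadcast(param, src=world_size - 1, group=group)
```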
@@ -118,29 +125,28 @@ def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]
            shape (`Sequence[int]`):
                Shape of the weight tensor.
        """
-        if self.pynccl_comm is None:
-            raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")
+        if self.weight_update_group is None:
+            raise RuntimeError("Weight update group not initialized. Call `init_weight_update_group` first.")

        # Allocate memory for the incoming weight tensor on the correct device.
        weight = torch.empty(shape, dtype=dtype, device=self.device)

-        # Use NCCL to broadcast the updated weights from the client (src) to all workers.
-        self.pynccl_comm.broadcast(weight, src=self.client_rank, stream=torch.cuda.current_stream())
-        self.pynccl_comm.group.barrier()
+        # Broadcast the updated weights from the client (src) to all workers.
+        torch.distributed.broadcast(weight, src=self.client_rank, group=self.weight_update_group)

        # Load the received weights into the model.
        self.model_runner.model.load_weights(weights=[(name, weight)])

-    def close_communicator(self) -> None:
+    def close_weight_update_group(self) -> None:
        """
        Closes the communicator when weight synchronization is no longer needed.

        This method deletes the NCCL communicator to release associated resources.
        """
-        if self.pynccl_comm is not None:
-            del self.pynccl_comm
-            self.pynccl_comm = None  # Ensure attribute is reset to None
+        if self.weight_update_group is not None:
+            del self.weight_update_group
+            self.weight_update_group = None  # Ensure attribute is reset to None
        self.client_rank = None  # Ensure attribute is reset to None
@@ -345,13 +351,15 @@ async def generate(request: GenerateRequest):
        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
        return {"completion_ids": completion_ids}

-    class InitCommunicatorRequest(BaseModel):
+    class InitWeightUpdateGroupRequest(BaseModel):
        host: str
        port: int
        world_size: int
+        backend: str
+

-    @app.post("/init_communicator/")
-    async def init_communicator(request: InitCommunicatorRequest, background_tasks: BackgroundTasks):
+    @app.post("/init_weight_update_group/")
+    async def init_weight_update_group(request: InitWeightUpdateGroupRequest, background_tasks: BackgroundTasks):
        """
        Initializes the communicator for synchronizing model weights between a client and multiple server
        workers.
@@ -364,8 +372,8 @@ async def init_communicator(request: InitCommunicatorRequest, background_tasks:
        """
        background_tasks.add_task(
            llm.collective_rpc,
-            "init_communicator",
-            args=(request.host, request.port, script_args.tensor_parallel_size + 1),
+            "init_weight_update_group",
+            args=(request.host, request.port, script_args.tensor_parallel_size + 1, request.backend),
        )
        return {"message": "Request received, initializing communicator"}
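A hedged example of how a client could call the renamed endpoint; the server URL and the values are placeholders, and the JSON keys mirror `InitWeightUpdateGroupRequest`:

```python
# Sketch only: ask the server to set up the weight update group.
import requests

resp = requests.post(
    "http://localhost:8000/init_weight_update_group/",  # placeholder server URL
    json={"host": "10.0.0.1", "port": 51216, "world_size": 3, "backend": "nccl"},
)
print(resp.json())  # {"message": "Request received, initializing communicator"}
```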
@@ -406,13 +414,13 @@ async def reset_prefix_cache():
        success = llm.llm_engine.reset_prefix_cache()
        return {"message": "Request received, resetting prefix cache status: " + str(success)}

-    @app.post("/close_communicator/")
-    async def close_communicator():
+    @app.post("/close_weight_update_group/")
+    async def close_weight_update_group():
        """
        Closes the weight update group and cleans up associated resources.
        """
-        llm.collective_rpc("close_communicator")
-        return {"message": "Request received, closing communicator"}
+        llm.collective_rpc("close_weight_update_group")
+        return {"message": "Request received, closing weight update group"}

    # Start the server
    uvicorn.run(app, host=script_args.host, port=script_args.port)
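And the matching teardown call once weight syncing is finished (same placeholder URL as in the earlier sketch):

```python
# Sketch only: release the weight update group when training is done.
import requests

requests.post("http://localhost:8000/close_weight_update_group/")
```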