Commit bbf0752

update weight update process group

1 parent 9f3702f commit bbf0752

4 files changed: +152 -50 lines changed

trl/distributed_util.py

+89

@@ -0,0 +1,89 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+from typing import Any, Optional, Union
+
+import torch
+import torch.distributed
+from torch.distributed.distributed_c10d import (
+    Backend,
+    PrefixStore,
+    Store,
+    _new_process_group_helper,
+    _world,
+    default_pg_timeout,
+    rendezvous,
+)
+
+
+def init_process_group(
+    backend: Union[str, Backend] = None,
+    init_method: Optional[str] = None,
+    timeout: Optional[timedelta] = None,
+    world_size: int = -1,
+    rank: int = -1,
+    store: Optional[Store] = None,
+    group_name: str = None,
+    pg_options: Optional[Any] = None,
+):
+    """
+    Copy from pytorch to allow creating multiple main groups.
+    https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+    Reference implementation from: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
+    """
+    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    # backward compatible API
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+
+        # Use a PrefixStore to avoid accidental overrides of keys used by
+        # different systems (e.g. RPC) in case the store is multi-tenant.
+        store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    # We need to determine the appropriate parameter name based on PyTorch version
+    pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+    )
+
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+    return pg
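This helper mirrors `torch.distributed.init_process_group` but returns the created group instead of installing it as the default, so a training process and a vLLM worker can rendezvous over TCP without disturbing their existing groups. A minimal usage sketch follows; the endpoint, world size, and rank are illustrative values, not taken from the commit:

# Minimal sketch: two processes form a standalone weight-update group.
# The endpoint, world_size, and rank below are assumed for illustration.
from trl.distributed_util import init_process_group

pg = init_process_group(
    backend="nccl",                       # "hccl" on Ascend NPUs
    init_method="tcp://192.0.2.1:51216",  # both peers must use the same endpoint
    world_size=2,                         # e.g. one vLLM worker + one client
    rank=0,                               # the peer process would pass rank=1
    group_name="weight_update_group",
)

The call blocks until all `world_size` peers have joined the rendezvous, which is why the client and server sides of this commit initiate it at (roughly) the same time.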

trl/extras/vllm_client.py

+22 -18

@@ -20,6 +20,7 @@
 import torch
 from torch import nn

+from ..distributed_util import init_process_group
 from ..import_utils import is_requests_available, is_vllm_available


@@ -28,11 +29,6 @@
 from requests import ConnectionError


-if is_vllm_available():
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
-
-
 logger = logging.getLogger(__name__)


@@ -53,6 +49,8 @@ class VLLMClient:
         connection_timeout (`float`, *optional*, defaults to `0.0`):
             Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
+        backend (`str`, *optional*, defaults to `"nccl"`):
+            The backend to use for collective communication.

     Examples:
         Run the vLLM server with the model `Qwen/Qwen2.5-7B`:
@@ -80,7 +78,7 @@ class VLLMClient:
     """

     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0, backend: str = "nccl"
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
@@ -91,9 +89,10 @@ def __init__(
         self.host = host
         self.server_port = server_port
         self.group_port = group_port
+        self.backend = backend
         self.check_server(connection_timeout)  # check server and fail after timeout
-        self.init_communicator()
-        atexit.register(self.close_communicator)  # when the client object is deleted, close the weight update group
+        self.init_weight_update_group()
+        atexit.register(self.close_weight_update_group)  # when the client object is deleted, close the weight update group

     def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
         """
@@ -188,7 +187,7 @@ def generate(
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

-    def init_communicator(self):
+    def init_weight_update_group(self):
         """
         Initializes the weight update group in a distributed setup for model synchronization.
         """
@@ -204,15 +203,20 @@ def init_communicator(self):
         self.rank = tensor_parallel_size  # The client's rank is the last process

         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"http://{self.host}:{self.server_port}/init_weight_update_group/"
         # In the server side, the host is set to 0.0.0.0
-        response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
+        response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size, "backend": self.backend})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

         # Set up the communication group for weight broadcasting
-        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
-        self.pynccl_comm = PyNcclCommunicator(pg, device="cuda:0")
+        self.weight_update_group = init_process_group(
+            backend=self.backend,
+            init_method=f"tcp://{self.host}:{self.group_port}",
+            world_size=world_size,
+            rank=self.rank,
+            group_name="weight_update_group",
+        )

     def update_named_param(self, name: str, weights: torch.Tensor):
         """
@@ -231,8 +235,8 @@ def update_named_param(self, name: str, weights: torch.Tensor):
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

         # Broadcast the weights to the other processes
-        self.pynccl_comm.broadcast(weights, src=self.rank, stream=torch.cuda.current_stream())
-        self.pynccl_comm.group.barrier()
+        torch.distributed.broadcast(weights, src=self.rank, group=self.weight_update_group)
+        torch.distributed.barrier()

     def update_model_params(self, model: nn.Module):
         """
@@ -255,11 +259,11 @@ def reset_prefix_cache(self):
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

-    def close_communicator(self):
+    def close_weight_update_group(self):
         """
-        Closes the weight update group and cleans up the communication group.
+        Closes the weight update group and cleans up associated resources.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"http://{self.host}:{self.server_port}/close_weight_update_group/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
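With these changes the client opens the weight-update group in `__init__` and broadcasts over plain `torch.distributed` calls instead of a `PyNcclCommunicator`. A hedged usage sketch of the updated client, assuming a vLLM server is already running; the model name and addresses are illustrative:

# Sketch: push current trainer weights to a running TRL vLLM server.
# Host, port, and model below are assumed values for illustration.
from transformers import AutoModelForCausalLM
from trl.extras.vllm_client import VLLMClient

client = VLLMClient(host="0.0.0.0", server_port=8000, backend="nccl")  # init_weight_update_group() runs in __init__
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
client.update_model_params(model)  # broadcasts each named parameter over the group

Note the client's rank is `tensor_parallel_size` (the last rank in the group), so it acts as the broadcast source while every vLLM worker receives.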

trl/scripts/vllm_serve.py

+38 -30

@@ -20,8 +20,10 @@

 import torch
 import torch.distributed as dist
+from accelerate.utils import is_npu_available

 from trl import TrlParser
+from trl.distributed_util import init_process_group
 from trl.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available, is_vllm_available


@@ -39,11 +41,11 @@

 if is_vllm_available():
     from vllm import LLM, SamplingParams
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
     from vllm.distributed.parallel_state import get_world_group
-    from vllm.distributed.utils import StatelessProcessGroup
     from vllm.sampling_params import GuidedDecodingParams
     from vllm.worker.worker import Worker
+    if is_npu_available():
+        from vllm_ascend.worker.worker import NPUWorker as Worker
 else:
     Worker = object

@@ -72,13 +74,13 @@ def __init__(self, *args, **kwargs):

         super().__init__(*args, **kwargs)

-        # The following attributes are initialized when `init_communicator` method is called.
-        self.pynccl_comm = None  # Communicator for weight updates
+        # The following attributes are initialized when the `init_weight_update_group` method is called.
+        self.weight_update_group = None  # Process group for weight updates
         self.client_rank = None  # Source rank for broadcasting updated weights

-    def init_communicator(self, host: str, port: int, world_size: int) -> None:
+    def init_weight_update_group(self, host: str, port: int, world_size: int, backend: str) -> None:
         """
-        Initializes the weight update communicator using a stateless process group.
+        Initializes the weight update process group.

         This method creates a `StatelessProcessGroup` that allows external training processes to
         communicate with vLLM workers without interfering with the global torch distributed group.
@@ -90,18 +92,23 @@ def init_communicator(self, host: str, port: int, world_size: int) -> None:
             Port number to be used for communication.
         world_size (`int`):
             Total number of participating processes in the update group.
+        backend (`str`):
+            The backend to use for collective communication.
         """
-        if self.pynccl_comm is not None:
-            raise RuntimeError("Weight update group already initialized. Call close_communicator first.")
+        if self.weight_update_group is not None:
+            raise RuntimeError("Weight update group already initialized. Call close_weight_update_group first.")

         # Get the rank of the current worker in the global world group.
         rank = get_world_group().rank

         # Create a stateless process group to manage communication between training processes and vLLM workers.
-        pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size)
-
-        # Initialize the NCCL-based communicator for weight synchronization.
-        self.pynccl_comm = PyNcclCommunicator(pg, device=self.device)
+        self.weight_update_group = init_process_group(
+            backend=backend,
+            init_method=f"tcp://{host}:{port}",
+            world_size=world_size,
+            rank=rank,
+            group_name="weight_update_group",
+        )

         # The client process that sends updated weights has the highest rank (world_size - 1).
         self.client_rank = world_size - 1
@@ -118,29 +125,28 @@ def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]
             shape (`Sequence[int]`):
                 Shape of the weight tensor.
         """
-        if self.pynccl_comm is None:
-            raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")
+        if self.weight_update_group is None:
+            raise RuntimeError("Weight update group not initialized. Call `init_weight_update_group` first.")

         # Allocate memory for the incoming weight tensor on the correct device.
         weight = torch.empty(shape, dtype=dtype, device=self.device)

-        # Use NCCL to broadcast the updated weights from the client (src) to all workers.
-        self.pynccl_comm.broadcast(weight, src=self.client_rank, stream=torch.cuda.current_stream())
-        self.pynccl_comm.group.barrier()
+        # Broadcast the updated weights from the client (src) to all workers.
+        torch.distributed.broadcast(weight, src=self.client_rank, group=self.weight_update_group)

         # Load the received weights into the model.
         self.model_runner.model.load_weights(weights=[(name, weight)])

-    def close_communicator(self) -> None:
+    def close_weight_update_group(self) -> None:
         """
         Closes the communicator when weight synchronization is no longer needed.

         This method deletes the NCCL communicator to release associated resources.
         """

-        if self.pynccl_comm is not None:
-            del self.pynccl_comm
-            self.pynccl_comm = None  # Ensure attribute is reset to None
+        if self.weight_update_group is not None:
+            del self.weight_update_group
+            self.weight_update_group = None  # Ensure attribute is reset to None
         self.client_rank = None  # Ensure attribute is reset to None


@@ -345,13 +351,15 @@ async def generate(request: GenerateRequest):
         completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
         return {"completion_ids": completion_ids}

-    class InitCommunicatorRequest(BaseModel):
+    class InitWeightUpdateGroupRequest(BaseModel):
         host: str
         port: int
         world_size: int
+        backend: str
+

-    @app.post("/init_communicator/")
-    async def init_communicator(request: InitCommunicatorRequest, background_tasks: BackgroundTasks):
+    @app.post("/init_weight_update_group/")
+    async def init_weight_update_group(request: InitWeightUpdateGroupRequest, background_tasks: BackgroundTasks):
         """
         Initializes the communicator for synchronizing model weights between a client and multiple server
         workers.
@@ -364,8 +372,8 @@ async def init_communicator(request: InitCommunicatorRequest, background_tasks:
         """
         background_tasks.add_task(
             llm.collective_rpc,
-            "init_communicator",
-            args=(request.host, request.port, script_args.tensor_parallel_size + 1),
+            "init_weight_update_group",
+            args=(request.host, request.port, script_args.tensor_parallel_size + 1, request.backend),
         )
         return {"message": "Request received, initializing communicator"}

@@ -406,13 +414,13 @@ async def reset_prefix_cache():
         success = llm.llm_engine.reset_prefix_cache()
         return {"message": "Request received, resetting prefix cache status: " + str(success)}

-    @app.post("/close_communicator/")
-    async def close_communicator():
+    @app.post("/close_weight_update_group/")
+    async def close_weight_update_group():
         """
         Closes the weight update group and cleans up associated resources.
         """
-        llm.collective_rpc("close_communicator")
-        return {"message": "Request received, closing communicator"}
+        llm.collective_rpc("close_weight_update_group")
+        return {"message": "Request received, closing weight update group"}

     # Start the server
     uvicorn.run(app, host=script_args.host, port=script_args.port)
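Because the endpoint names and request schema changed, any out-of-tree client must now send the new `backend` field. A sketch of the raw HTTP call that `VLLMClient` issues under the hood; the addresses and sizes are illustrative (the server itself derives the group size as `tensor_parallel_size + 1`, ignoring the posted `world_size`):

# Sketch: hand-rolled request against the renamed endpoint.
# For a server running with tensor parallel size 2, the group has 3 members
# (two vLLM workers plus the one client process).
import requests

response = requests.post(
    "http://0.0.0.0:8000/init_weight_update_group/",
    json={"host": "0.0.0.0", "port": 51216, "world_size": 3, "backend": "nccl"},
)
print(response.json())  # expected: {"message": "Request received, initializing communicator"}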

trl/trainer/grpo_trainer.py

+3 -2

@@ -22,7 +22,7 @@
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_npu_available, is_peft_model, set_seed
 from datasets import Dataset, IterableDataset
 from packaging import version
 from torch import nn
@@ -474,8 +474,9 @@ def data_collator(features):  # No data collation is needed in GRPO
            )

        if self.accelerator.is_main_process:
+            backend = "hccl" if is_npu_available() else "nccl"
            self.vllm_client = VLLMClient(
-                args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
+                args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout, backend=backend
            )

        # vLLM specific sampling arguments
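The trainer picks the collective backend from the detected hardware, which is what makes the NPU path work end to end. The same one-liner could be reused in a custom training loop; a minimal sketch with illustrative host/port/timeout values:

# Sketch: hardware-aware backend selection, mirroring the trainer change above.
from accelerate.utils import is_npu_available
from trl.extras.vllm_client import VLLMClient

backend = "hccl" if is_npu_available() else "nccl"  # Ascend NPUs use HCCL
client = VLLMClient("0.0.0.0", 8000, connection_timeout=120.0, backend=backend)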
