[XPU] use allgather and fp32 multinomial for XPU (#8787)
houj04 authored and FeixLiu committed Jul 26, 2024
1 parent 96c5236 commit 727ea59
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 2 additions & 0 deletions paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@ def sample(
 probs = TopKProcess(probs, top_k, min_tokens_to_keep)
 if top_p is not None and top_p < 1.0:
     probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+if paddle.device.is_compiled_with_xpu():
+    probs = paddle.cast(probs, "float32")

 # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
 next_tokens = paddle.multinomial(probs)
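For context, a minimal standalone sketch of the sampling path this change affects, using only public Paddle APIs. The PaddleNLP top-k / top-p filtering (TopKProcess / TopPProcess) is omitted, and the function name sample_next_tokens is illustrative, not from the repository:

    import paddle
    import paddle.nn.functional as F

    def sample_next_tokens(logits):
        # Turn logits into a categorical distribution; in PaddleNLP this
        # happens after the top-k / top-p filtering shown in the diff above.
        probs = F.softmax(logits, axis=-1)

        # On XPU builds, upcast fp16/bf16 probabilities to float32 before
        # sampling, mirroring the branch added in generation/utils.py.
        if paddle.device.is_compiled_with_xpu():
            probs = paddle.cast(probs, "float32")

        # Draw one token id per row.
        return paddle.multinomial(probs)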
8 changes: 6 additions & 2 deletions paddlenlp/peft/lora/lora_model.py
@@ -56,9 +56,10 @@ class RowSequenceParallelLinear:
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 from .lora_layers import (
     ColumnParallelLoRALinear,
@@ -301,7 +302,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
 for key in trainable_state_dict:
     tensor = trainable_state_dict[key]
     if key in trainable_name_action_mappings:
-        ret = distributed_gather(tensor, group=mp_group, offload=True)
+        if get_env_device() == "xpu":
+            ret = distributed_allgather(tensor, group=mp_group, offload=True)
+        else:
+            ret = distributed_gather(tensor, group=mp_group, offload=True)
         action = trainable_name_action_mappings[key]
         if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
             ret = paddle.to_tensor(ret)
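The same device check drives the gather strategy when merging tensor-parallel LoRA shards. A rough sketch of that dispatch, assuming the PaddleNLP helpers distributed_allgather and distributed_gather behave like standard all-gather and gather-to-one-rank collectives (the wrapper name gather_tp_shards is illustrative):

    from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
    from paddlenlp.utils.tools import get_env_device

    def gather_tp_shards(tensor, mp_group):
        # On XPU, collect the tensor-parallel shards with all-gather so every
        # rank receives them; other devices keep the original
        # gather-to-destination-rank behavior.
        if get_env_device() == "xpu":
            return distributed_allgather(tensor, group=mp_group, offload=True)
        return distributed_gather(tensor, group=mp_group, offload=True)

In _merge_trainable_tensor_parallel above, the gathered shards are then merged according to the per-key action from trainable_name_action_mappings.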
