Cp 28 for xpu #8812

Merged: 4 commits, Jul 31, 2024
Changes from all commits
paddlenlp/generation/utils.py: 2 additions & 0 deletions
@@ -1208,6 +1208,8 @@ def sample(
                 probs = TopKProcess(probs, top_k, min_tokens_to_keep)
             if top_p is not None and top_p < 1.0:
                 probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+            if paddle.device.is_compiled_with_xpu():
+                probs = paddle.cast(probs, "float32")
 
             # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
             next_tokens = paddle.multinomial(probs)
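
For context, a minimal runnable sketch of this pattern; the shapes and the softmax are illustrative assumptions, not taken from the PR:

import paddle

logits = paddle.randn([2, 100])               # stand-in next-token logits
probs = paddle.nn.functional.softmax(logits)  # may be fp16/bf16 under mixed precision
if paddle.device.is_compiled_with_xpu():
    # Assumption: the XPU multinomial kernel needs float32 probabilities,
    # hence the upcast before sampling.
    probs = paddle.cast(probs, "float32")
next_tokens = paddle.multinomial(probs)       # one sampled token id per row, shape [2, 1]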
paddlenlp/peft/lora/lora_model.py: 6 additions & 2 deletions
@@ -56,9 +56,10 @@ class RowSequenceParallelLinear:
     load_state_dict,
 )
 from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
-from ...utils.distributed import distributed_gather
+from ...utils.distributed import distributed_allgather, distributed_gather
 from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
 from ...utils.log import logger
+from ...utils.tools import get_env_device
 from .lora_config import LoRAConfig
 from .lora_layers import (
     ColumnParallelLoRALinear,
@@ -301,7 +302,10 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict):
         for key in trainable_state_dict:
             tensor = trainable_state_dict[key]
             if key in trainable_name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = trainable_name_action_mappings[key]
                 if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                     ret = paddle.to_tensor(ret)
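
The same dispatch recurs in the checkpoint and conversion paths below. A sketch of the shared idea — the helper name is hypothetical, and the stated reason (that the XPU collective backend lacks a gather primitive) is an assumption, not something the PR states:

from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
from paddlenlp.utils.tools import get_env_device

def gather_tp_shards(tensor, group, offload=True):
    # On XPU, fall back to all_gather: every rank then holds all shards,
    # and non-destination ranks simply discard theirs.
    if get_env_device() == "xpu":
        return distributed_allgather(tensor, group=group, offload=offload)
    return distributed_gather(tensor, group=group, offload=offload)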
paddlenlp/trainer/plugins/unified_checkpoint.py: 10 additions & 3 deletions
@@ -41,7 +41,7 @@
     get_checkpoint_shard_files,
     is_safetensors_available,
 )
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_INDEX_NAME,
@@ -64,6 +64,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy, nested_copy_place
+from paddlenlp.utils.tools import get_env_device
 
 if is_safetensors_available():
     # from safetensors import safe_open
@@ -1747,7 +1748,10 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
             key = filter_keys[i]
             tensor = state_dict[key]
             if key in tp_actions:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
@@ -1784,7 +1788,10 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
             if tensor.numel().item() == 1:
                 tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
             else:
-                ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                else:
+                    ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                 action = tp_actions[model_key]
                 tensor = action(ret) if is_dst else None
             else:
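
Note the trade-off being accepted here: distributed_gather(dst=j) materializes the shard list on one rank only, while all_gather returns it on every rank. A tiny sketch of that difference, assuming an initialized multi-process Paddle run, with rank 0 playing the is_dst role:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
shard = paddle.full([2], float(dist.get_rank()))

shards = []
dist.all_gather(shards, shard)      # every rank now holds all shards
if dist.get_rank() == 0:            # mimic the is_dst check above
    merged = paddle.concat(shards)  # other ranks just drop their copies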
paddlenlp/transformers/conversion_utils.py: 6 additions & 2 deletions
@@ -37,7 +37,7 @@
 from paddle import Tensor
 from paddle.nn import Layer
 
-from paddlenlp.utils.distributed import distributed_gather
+from paddlenlp.utils.distributed import distributed_allgather, distributed_gather
 from paddlenlp.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME
 from paddlenlp.utils.import_utils import (
     is_package_available,
@@ -46,6 +46,7 @@
 )
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.serialization import load_torch
+from paddlenlp.utils.tools import get_env_device
 
 if TYPE_CHECKING:
     from paddlenlp.transformers import PretrainedConfig, PretrainedModel
@@ -1269,7 +1270,10 @@ def merge_tensor_parallel(cls, state_dict, config) -> None:
         for key in state_dict.keys():
             tensor = state_dict[key]
             if key in name_action_mappings:
-                ret = distributed_gather(tensor, group=mp_group, offload=True)
+                if get_env_device() == "xpu":
+                    ret = distributed_allgather(tensor, group=mp_group, offload=True)
+                else:
+                    ret = distributed_gather(tensor, group=mp_group, offload=True)
                 action = name_action_mappings.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
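
For orientation, the actions in name_action_mappings consume the gathered shard list on the destination rank. A hypothetical merge action of that shape — the concatenation axis is illustrative, not taken from the PR:

import numpy as np

def merge_column_parallel(shards, axis=-1):
    # Stitch tensor-parallel shards of a column-split weight back together.
    return np.concatenate([np.asarray(s) for s in shards], axis=axis)

# On the destination rank: tensor = merge_column_parallel(ret)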
paddlenlp/utils/distributed.py: 1 addition & 1 deletion
@@ -214,7 +214,7 @@ def distributed_allgather(tensor: Any, group=None, offload=False):
             x.reshape_(origin_shape)
 
     else:
-        distributed.all_gather(output_tensors, tensor)
+        distributed.all_gather(output_tensors, tensor, group=group)
 
     return output_tensors

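This last fix matters because, without group=, the collective runs on the default global group rather than the caller's subgroup. A usage sketch — the two-rank subgroup is hypothetical, and a multi-process launch is assumed:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
mp_group = dist.new_group(ranks=[0, 1])  # hypothetical model-parallel subgroup

out = []
if dist.get_rank() in (0, 1):
    # Runs the collective on mp_group instead of the default global group.
    dist.all_gather(out, paddle.to_tensor([dist.get_rank()]), group=mp_group)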