[XPU] update is_fused_matmul_bias_supported for xpu (#5820)
houj04 authored May 4, 2023
1 parent 53b2ea0 commit 358ce43
Showing 2 changed files with 4 additions and 4 deletions.
@@ -203,7 +203,7 @@ def register_sequence_parallel_allreduce_hooks(model, accumulation_steps, fuse_s
 
 
 def is_fused_matmul_bias_supported():
-    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
         return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
     else:
         return False
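A note on the new condition: Python's "and" binds more tightly than "or", so the committed check is equivalent to the parenthesized form sketched below. This is an illustrative rewrite, not part of the commit; the "from paddle.framework import core" import is an assumption, since the hunk only references core.eager.ops.legacy.

import paddle
from paddle.framework import core  # assumed import; the hunk only shows core.eager.ops.legacy being used


def is_fused_matmul_bias_supported_equivalent():
    # Equivalent to the committed condition: "and" binds tighter than "or",
    # so XPU builds pass the check regardless of the CUDA/ROCm flags.
    if (paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()) or paddle.is_compiled_with_xpu():
        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
    return False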
@@ -283,7 +283,7 @@ def __init__(
                     "You set fuse_matmul_bias=True in ColumnSequenceParallelLinear, "
                     "however, the paddle you are using not support this operation. "
                     "Please set fuse_matmul_bias=False or use paddle compiled "
-                    "with cuda 11.6 or higher."
+                    "with cuda 11.6 or higher, or use xpu version."
                 )
             from paddle.incubate.nn.functional import fused_linear
 
model_zoo/gpt-3/ppfleetx/models/language_model/utils.py — 4 changes: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 
 
 def is_fused_matmul_bias_supported():
-    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
         return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
     else:
         return False
@@ -70,7 +70,7 @@ def process_model_configs(config):
         configs["fused_linear"] = False
         logging.warning(
             "The flag fused_linear only valid for cuda version higher than 11.6, "
-            "but the paddle is compiled with cuda " + paddle.version.cuda()
+            "but the paddle is compiled with cuda " + paddle.version.cuda() + ", or you can use xpu version."
         )
 
     pp_degree = config.Distributed.pp_degree
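As a usage note, callers typically gate the fused path on this capability check and fall back to the plain linear op otherwise. The helper below is a hypothetical sketch, not code from this commit; it assumes is_fused_matmul_bias_supported is already in scope and reuses the fused_linear import shown in the first file's diff.

import paddle
from paddle.incubate.nn.functional import fused_linear


def linear_maybe_fused(x, weight, bias=None, fuse_matmul_bias=True):
    # Hypothetical helper: use the fused GEMM + bias epilogue when the build supports it,
    # otherwise fall back to the regular matmul + bias-add path.
    if fuse_matmul_bias and is_fused_matmul_bias_supported():
        return fused_linear(x, weight, bias)
    return paddle.nn.functional.linear(x, weight, bias)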
