[GCU] Support llama for GCU #8445

Merged
merged 1 commit on May 17, 2024
12 changes: 9 additions & 3 deletions examples/benchmark/wiki_lambada/eval.py
@@ -57,8 +57,8 @@ def get_parser():
"--device",
type=str,
default="gpu",
choices=["cpu", "eval_pathgpu", "xpu", "npu"],
help="select cpu, gpu, xpu devices.",
choices=["cpu", "gpu", "xpu", "npu", "gcu"],
help="select cpu, gpu, xpu, gcu devices.",
)
parser.add_argument(
"--dtype",
@@ -67,6 +67,12 @@ def get_parser():
choices=["bfloat16", "float16", "float32"],
help="set the dtype of model",
)
parser.add_argument(
"--use_flash_attention",
type=bool,
default=False,
help="Whether to use flash attention",
)

# load autodist name files, eg: bloom-176b
parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
@@ -316,7 +322,7 @@ def do_generation():
tensor_parallel_output=False,
tensor_parallel_degree=args.tensor_parallel_degree,
tensor_parallel_rank=paddle.distributed.get_rank(),
use_flash_attention=False,
use_flash_attention=args.use_flash_attention,
dtype=args.dtype, # todo enable set dtype to avoid additional mem usage
)

2 changes: 2 additions & 0 deletions paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@
probs = TopKProcess(probs, top_k, min_tokens_to_keep)
if top_p is not None and top_p < 1.0:
probs = TopPProcess(probs, top_p, min_tokens_to_keep)
if paddle.device.is_compiled_with_custom_device("gcu"):
probs = paddle.cast(probs, "float32")

Codecov / codecov/patch warning: added line #L1212 was not covered by tests.

# multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
next_tokens = paddle.multinomial(probs)
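For context, a minimal standalone sketch of the sampling step this hunk touches (illustrative only; the cast matters on devices whose multinomial kernel does not accept fp16/bf16 probabilities, which is the assumption for GCU here):

import paddle

def sample_next_tokens(probs):
    # Mirror the hunk above: fall back to float32 probabilities on GCU before sampling.
    if paddle.device.is_compiled_with_custom_device("gcu"):
        probs = paddle.cast(probs, "float32")
    return paddle.multinomial(probs)

# Toy usage: draw one token id per row from an 8-token vocabulary.
logits = paddle.randn([2, 8], dtype="float32")
probs = paddle.nn.functional.softmax(logits, axis=-1)
print(sample_next_tokens(probs))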
24 changes: 21 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -41,7 +41,7 @@
except ImportError:
fused_rotary_position_embedding = None
try:
if get_env_device() == "npu":
if get_env_device() in ["npu", "gcu"]:
from paddle.base import core

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -53,13 +53,18 @@


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
assert past_key_value is None, "fuse rotary not support cache kv for now"
if get_env_device() != "gcu":
assert past_key_value is None, "fuse rotary not support cache kv for now"

Codecov / codecov/patch warning: added lines #L56-L57 were not covered by tests.
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)

Codecov / codecov/patch warning: added line #L60 was not covered by tests.
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
elif get_env_device() == "gcu":
query_states, key_states = core.eager._run_custom_op(

Codecov / codecov/patch warning: added lines #L64-L65 were not covered by tests.
"fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
)
else:
# paddle version > 2.6 or develop support q and k/v with different num_heads
paddle_version = float(paddle.__version__[:3])
@@ -103,6 +108,8 @@
def fusion_rms_norm(hidden_states, weight, variance_epsilon):
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
elif get_env_device() == "gcu":
return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]

Codecov / codecov/patch warning: added lines #L111-L112 were not covered by tests.
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821
@@ -158,6 +165,17 @@
False,
npu_is_casual,
)[0]
elif get_env_device() == "gcu":
attn_output = core.eager._run_custom_op(

Codecov / codecov/patch warning: added lines #L168-L169 were not covered by tests.
"fused_sdp_flash_attention_gcu",
query_states,
key_states,
value_states,
attention_mask,
0.0,
attention_mask is None,
True,
)[0]
else:
attn_output = F.scaled_dot_product_attention(
query_states,
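For reference, the computation that the device-specific rms_norm custom ops above fuse is standard RMSNorm; a plain-Paddle sketch of the equivalent unfused version (not the kernels' code) looks like this:

import paddle

def rms_norm_reference(hidden_states, weight, variance_epsilon):
    # RMSNorm: x / sqrt(mean(x^2) + eps) * weight, accumulated in float32 for stability.
    input_dtype = hidden_states.dtype
    x = hidden_states.astype("float32")
    variance = x.pow(2).mean(axis=-1, keepdim=True)
    x = x * paddle.rsqrt(variance + variance_epsilon)
    return x.astype(input_dtype) * weight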
25 changes: 17 additions & 8 deletions paddlenlp/transformers/llama/modeling.py
@@ -79,7 +79,7 @@
)

try:
if get_env_device() == "npu":
if get_env_device() in ["npu", "gcu"]:

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
@@ -410,6 +410,7 @@
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
@@ -418,6 +419,9 @@
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
self.cos_sin_table.cast(x.dtype)
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
else self.cos_sin_table,
)


@@ -439,6 +443,7 @@
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)

Codecov / codecov/patch warning: added line #L446 was not covered by tests.


class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -471,19 +476,23 @@
# [1, seqlen, 1, dim]
scale_cos = emb.cos()[None, :, None, :]
scale_sin = emb.sin()[None, :, None, :]
return scale_cos, scale_sin
scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return scale_cos, scale_sin, scale_cos_sin

Codecov / codecov/patch warning: added lines #L479-L480 were not covered by tests.

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_position_embeddings:
scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len)
scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)

Codecov / codecov/patch warning: added line #L485 was not covered by tests.
else:
scale_cos, scale_sin = self.cos_cached, self.sin_cached
scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table

Codecov / codecov/patch warning: added line #L487 was not covered by tests.
cos = scale_cos[:, :seq_len, :, ...]
sin = scale_sin[:, :seq_len, :, ...]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
scale_cos_sin.cast(x.dtype)
if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
else scale_cos_sin,
)


@@ -638,7 +647,7 @@
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -934,7 +943,7 @@
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
Collaborator:

Do we really need the cos_sin optimization? The code change is large, and it adds overhead out of nowhere that will degrade performance on other devices.

Alternatively, you could build a cos_sin table yourselves only when you need it.

Contributor (author):

This is mainly because the operator implementation follows the paper and vLLM, which use a different sin/cos layout than the one here. As for the overhead on other devices: most of the table computation happens only at initialization, and, following the suggestion in the first comment, we compute it only on the specific device, so other devices merely return an extra None.


query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
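To make the trade-off discussed in the thread above concrete, here is a rough sketch (a hypothetical helper, not the PR's exact code) of precomputing the fused cos_sin table once at initialization and only on the device that needs it, so that other devices merely carry a None placeholder:

import paddle

def build_rope_cache(head_dim, max_position_embeddings, base=10000.0, is_gcu=False):
    # Standard RoPE tables: inv_freq[i] = base^(-2i/d), freqs[t, i] = t * inv_freq[i].
    inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
    t = paddle.arange(max_position_embeddings, dtype="float32")
    freqs = paddle.einsum("i,j->ij", t, inv_freq)        # [seq_len, head_dim // 2]
    emb = paddle.concat([freqs, freqs], axis=-1)         # [seq_len, head_dim]

    cos_cached = emb.cos()[None, :, None, :]             # [1, seq_len, 1, head_dim]
    sin_cached = emb.sin()[None, :, None, :]

    # Only the GCU fused rotary kernel consumes the concatenated cos/sin halves;
    # on every other device the extra return value is just None.
    cos_sin_table = paddle.concat([freqs.cos(), freqs.sin()], axis=-1) if is_gcu else None
    return cos_cached, sin_cached, cos_sin_table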

@@ -1398,7 +1407,7 @@
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = expanded_attn_mask.astype("float32")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
elif get_env_device() in ["xpu", "gcu"]:
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
@@ -1528,7 +1537,7 @@
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
is_casual = False
if self.config.use_flash_attention:
if self.config.use_flash_attention and get_env_device() != "gcu":
Collaborator:

What is different about how GCU handles the attention mask here?

Contributor (author):

Because of how the flash attention kernel is implemented on GCU, even in the is_casual case it needs an attention_mask with the same dtype as the inputs, rather than None or a bool mask.

is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
if is_casual and alibi is None:
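As a concrete illustration of the reply above, here is a hedged sketch (a hypothetical helper mirroring the mask-expansion code earlier in this file) of turning a boolean causal mask into the dtype-matched additive mask that the GCU flash-attention kernel expects instead of None or a bool mask:

import paddle

def to_additive_mask(bool_mask, dtype):
    # True (attend) -> 0, False (masked) -> dtype.min, in the activations' dtype.
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    return paddle.where(bool_mask, x, y)

# Toy usage: a causal mask for a length-4 sequence, shaped [bs, 1, seq_len, seq_len].
causal = paddle.tril(paddle.ones([1, 1, 4, 4])).astype("bool")
print(to_additive_mask(causal, paddle.float16))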
2 changes: 2 additions & 0 deletions paddlenlp/utils/tools.py
@@ -124,6 +124,8 @@
return "gpu"
elif "npu" in paddle.device.get_all_custom_device_type():
return "npu"
elif "gcu" in paddle.device.get_all_custom_device_type():
return "gcu"

Codecov / codecov/patch warning: added line #L128 was not covered by tests.
elif paddle.is_compiled_with_rocm():
return "rocm"
elif paddle.is_compiled_with_xpu():
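A brief sketch of how this detection is consumed downstream (assuming the PaddleCustomDevice GCU plugin is installed and registers "gcu" as a custom device type):

import paddle
from paddlenlp.utils.tools import get_env_device

# With the GCU plugin installed, "gcu" shows up among the registered custom devices
# and get_env_device() falls through to the new branch above.
print(paddle.device.get_all_custom_device_type())  # e.g. ['gcu']

if get_env_device() == "gcu":
    paddle.set_device("gcu")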