-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GCU] Support llama for GCU #8445
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,7 +79,7 @@ | |
) | ||
|
||
try: | ||
if get_env_device() == "npu": | ||
if get_env_device() in ["npu", "gcu"]: | ||
|
||
for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): | ||
if lib.endswith(".so"): | ||
|
@@ -410,6 +410,7 @@ | |
# [1, seqlen, 1, dim] | ||
self.cos_cached = emb.cos()[None, :, None, :] | ||
self.sin_cached = emb.sin()[None, :, None, :] | ||
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) | ||
|
||
def forward(self, x, seq_len=None): | ||
# x: [bs, num_attention_heads, seq_len, head_size] | ||
|
@@ -418,6 +419,9 @@ | |
return ( | ||
cos.cast(x.dtype) if cos.dtype != x.dtype else cos, | ||
sin.cast(x.dtype) if sin.dtype != x.dtype else sin, | ||
self.cos_sin_table.cast(x.dtype) | ||
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype | ||
else self.cos_sin_table, | ||
) | ||
|
||
|
||
|
@@ -439,6 +443,7 @@ | |
# [1, seqlen, 1, dim] | ||
self.cos_cached = emb.cos()[None, :, None, :] | ||
self.sin_cached = emb.sin()[None, :, None, :] | ||
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) | ||
|
||
|
||
class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): | ||
|
@@ -471,19 +476,23 @@ | |
# [1, seqlen, 1, dim] | ||
scale_cos = emb.cos()[None, :, None, :] | ||
scale_sin = emb.sin()[None, :, None, :] | ||
return scale_cos, scale_sin | ||
scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) | ||
return scale_cos, scale_sin, scale_cos_sin | ||
|
||
def forward(self, x, seq_len=None): | ||
# x: [bs, num_attention_heads, seq_len, head_size] | ||
if seq_len > self.max_position_embeddings: | ||
scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len) | ||
scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len) | ||
else: | ||
scale_cos, scale_sin = self.cos_cached, self.sin_cached | ||
scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table | ||
cos = scale_cos[:, :seq_len, :, ...] | ||
sin = scale_sin[:, :seq_len, :, ...] | ||
return ( | ||
cos.cast(x.dtype) if cos.dtype != x.dtype else cos, | ||
sin.cast(x.dtype) if sin.dtype != x.dtype else sin, | ||
scale_cos_sin.cast(x.dtype) | ||
if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype | ||
else scale_cos_sin, | ||
) | ||
|
||
|
||
|
@@ -638,7 +647,7 @@ | |
) | ||
|
||
self.use_fused_rope = config.use_fused_rope | ||
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]: | ||
if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]: | ||
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: | ||
warnings.warn( | ||
"Enable fuse rope in the config, but fuse rope is not available. " | ||
|
@@ -934,7 +943,7 @@ | |
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin, | ||
) | ||
else: | ||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len) | ||
|
||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) | ||
|
||
|
@@ -1398,7 +1407,7 @@ | |
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32") | ||
expanded_attn_mask = expanded_attn_mask.astype("float32") | ||
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) | ||
elif get_env_device() == "xpu": | ||
elif get_env_device() in ["xpu", "gcu"]: | ||
x = paddle.to_tensor(0.0, dtype=dtype) | ||
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) | ||
expanded_attn_mask = expanded_attn_mask.astype(dtype) | ||
|
@@ -1528,7 +1537,7 @@ | |
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype | ||
) # [bs, 1, seq_len, seq_len] | ||
is_casual = False | ||
if self.config.use_flash_attention: | ||
if self.config.use_flash_attention and get_env_device() != "gcu": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里在attention mask的处理上,GCU不一样的地方是什么? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 基于 |
||
is_casual = is_casual_mask(attention_mask) | ||
if get_env_device() != "npu": | ||
if is_casual and alibi is None: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一定要加 cos_sin 的优化吗?代码改动很大,而且会导致其他设备性能下降,凭空多了很多开销。
或者你们需要的时候再自己去造一个 cos_sin
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里主要是因为算子的实现与
paper
或者vllm
一致,使用了与这里不同的sin/cos
。关于其他设备性能开销,一方面,应该大多table
的计算只在初始化阶段,另一方面,我们将按照第一个issue
的建议,在特定设备进行计算,这里仅仅只会多返回一个None
。