Commit 98a9cc2

fix rotary_emb for llama
EnflameGCU committed May 20, 2024
1 parent b36b6a0 commit 98a9cc2
Showing 2 changed files with 4 additions and 4 deletions.
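Both files previously unpacked two values from `self.rotary_emb(...)`; the rotary embedding layer now returns a third value as well, so each call site unpacks three and discards the last. Below is a minimal sketch of a rotary embedding whose `forward` returns such a 3-tuple (illustrative only — the layer name and the concrete third return value are assumptions, not the PaddleNLP implementation):

```python
import paddle


class RotaryEmbeddingSketch(paddle.nn.Layer):
    """Illustrative rotary embedding; not the PaddleNLP implementation."""

    def __init__(self, head_dim, base=10000):
        super().__init__()
        # inv_freq: [head_dim // 2], one inverse frequency per rotary pair
        exponents = paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
        self.inv_freq = 1.0 / paddle.pow(paddle.to_tensor(float(base)), exponents)

    def forward(self, x, seq_len):
        # t: [seq_len] absolute positions; freqs: [seq_len, head_dim // 2]
        t = paddle.arange(seq_len, dtype="float32")
        freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
        # emb: [seq_len, head_dim], each half carries the same angles
        emb = paddle.concat([freqs, freqs], axis=-1)
        # Returning a third value (here the raw angles) mirrors the 3-tuple
        # signature that the call sites in this commit now unpack and ignore.
        return paddle.cos(emb), paddle.sin(emb), emb


# Call sites after this commit unpack three values and drop the last one:
# cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
```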
4 changes: 2 additions & 2 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -448,7 +448,7 @@ def forward(
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)

(Codecov warning: added line 451 in paddlenlp/transformers/llama/modeling_auto.py was not covered by tests)

paddle_version = float(paddle.__version__[:3])
if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
@@ -481,7 +481,7 @@ def forward(
use_neox_rotary_style=False,
)
else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)

(Codecov warning: added line 484 in paddlenlp/transformers/llama/modeling_auto.py was not covered by tests)
# hack here, because elementwise infer spmd not support broadcast now
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
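The non-fused branch above falls back to `apply_rotary_pos_emb`, passing `position_ids` explicitly because, per the in-code comment, elementwise infer spmd does not support broadcast yet. A rough sketch of the standard rotate-half formulation that such a helper implements (the shapes and the 1-D `position_ids` are simplifying assumptions, not the PaddleNLP signature):

```python
import paddle


def rotate_half(x):
    # Split the last dimension into two halves and swap them with a sign flip.
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids):
    # q, k: [batch, seq_len, num_heads, head_dim]
    # cos, sin: [max_seq_len, head_dim]; position_ids: 1-D int64 [seq_len]
    cos = paddle.index_select(cos, position_ids, axis=0)  # [seq_len, head_dim]
    sin = paddle.index_select(sin, position_ids, axis=0)
    # Add explicit batch and head axes so the multiplies below broadcast
    # deterministically rather than relying on implicit broadcasting.
    cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, head_dim]
    sin = sin.unsqueeze(0).unsqueeze(2)
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out
```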

4 changes: 2 additions & 2 deletions paddlenlp/transformers/llama/modeling_auto_static.py
@@ -421,7 +421,7 @@ def forward(
if self.config.rope:
if self.use_fused_rope:
assert past_key_value is None, "fuse rotary not support cache kv for now"
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)

(Codecov warning: added line 424 in paddlenlp/transformers/llama/modeling_auto_static.py was not covered by tests)
query_states, key_states, _ = fused_rotary_position_embedding(
query_states,
key_states,
@@ -432,7 +432,7 @@ def forward(
use_neox_rotary_style=False,
)
else:
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)

(Codecov warning: added line 435 in paddlenlp/transformers/llama/modeling_auto_static.py was not covered by tests)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# [bs, seq_len, num_head, head_dim]

0 comments on commit 98a9cc2
