diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py index 3c42d5bf1213..275d38414803 100644 --- a/paddlenlp/transformers/llama/modeling_auto.py +++ b/paddlenlp/transformers/llama/modeling_auto.py @@ -448,7 +448,7 @@ def forward( assert past_key_value is None, "fuse rotary not support cache kv for now" batch_size, seq_length, num_heads, head_dim = query_states.shape _, kv_seq_len, num_key_value_heads, _ = key_states.shape - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len) paddle_version = float(paddle.__version__[:3]) if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads): @@ -481,7 +481,7 @@ def forward( use_neox_rotary_style=False, ) else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len) # hack here, because elementwise infer spmd not support broadcast now query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/paddlenlp/transformers/llama/modeling_auto_static.py b/paddlenlp/transformers/llama/modeling_auto_static.py index d9af478b808c..d5774d381ea4 100644 --- a/paddlenlp/transformers/llama/modeling_auto_static.py +++ b/paddlenlp/transformers/llama/modeling_auto_static.py @@ -421,7 +421,7 @@ def forward( if self.config.rope: if self.use_fused_rope: assert past_key_value is None, "fuse rotary not support cache kv for now" - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states, _ = fused_rotary_position_embedding( query_states, key_states, @@ -432,7 +432,7 @@ def forward( use_neox_rotary_style=False, ) else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bs, seq_len, num_head, head_dim] diff --git a/scripts/distribute/run_ci.sh b/scripts/distribute/run_ci.sh index f558cde651b2..da0636186622 100644 --- a/scripts/distribute/run_ci.sh +++ b/scripts/distribute/run_ci.sh @@ -30,6 +30,7 @@ target_lists_for_llama=( "paddlenlp/trainer/auto_trainer.py" "paddlenlp/transformers/llama/modeling_auto_static.py" "paddlenlp/transformers/llama/modeling_auto.py" + "paddlenlp/transformers/llama/modeling.py" "scripts/distribute" )