From 11e5e656a3fe86308c720919cdcce678aca5e8c2 Mon Sep 17 00:00:00 2001
From: heavyrain_lzy <1528794076@qq.com>
Date: Tue, 21 May 2024 21:07:28 +0800
Subject: [PATCH 1/4] update rotary_emb in auto_parallel

---
 paddlenlp/transformers/llama/modeling_auto.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 3c42d5bf1213..0562c701b63e 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -481,7 +481,7 @@ def forward(
                     use_neox_rotary_style=False,
                 )
             else:
-                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 # hack here, because elementwise infer spmd not support broadcast now
                 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 

From 4a964b299ed18bc10b0bc272d726e014b50a270f Mon Sep 17 00:00:00 2001
From: heavyrain_lzy <1528794076@qq.com>
Date: Wed, 22 May 2024 10:33:46 +0800
Subject: [PATCH 2/4] update rotary_emb

---
 paddlenlp/transformers/llama/modeling_auto_static.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling_auto_static.py b/paddlenlp/transformers/llama/modeling_auto_static.py
index d9af478b808c..d5774d381ea4 100644
--- a/paddlenlp/transformers/llama/modeling_auto_static.py
+++ b/paddlenlp/transformers/llama/modeling_auto_static.py
@@ -421,7 +421,7 @@ def forward(
         if self.config.rope:
             if self.use_fused_rope:
                 assert past_key_value is None, "fuse rotary not support cache kv for now"
-                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 query_states, key_states, _ = fused_rotary_position_embedding(
                     query_states,
                     key_states,
@@ -432,7 +432,7 @@ def forward(
                     use_neox_rotary_style=False,
                 )
             else:
-                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
                 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         # [bs, seq_len, num_head, head_dim]

From 86b7d19588747dfec14ae2f49977d54f808d903f Mon Sep 17 00:00:00 2001
From: heavyrain_lzy <1528794076@qq.com>
Date: Wed, 22 May 2024 10:36:18 +0800
Subject: [PATCH 3/4] update rotary_emb

---
 paddlenlp/transformers/llama/modeling_auto.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 0562c701b63e..275d38414803 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -448,7 +448,7 @@ def forward(
                 assert past_key_value is None, "fuse rotary not support cache kv for now"
                 batch_size, seq_length, num_heads, head_dim = query_states.shape
                 _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
                 paddle_version = float(paddle.__version__[:3])
                 if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):

From 89fd4b616ddb420cc676faaec1d9984bafb3fb5a Mon Sep 17 00:00:00 2001
From: heavyrain_lzy <1528794076@qq.com>
Date: Wed, 22 May 2024 16:19:28 +0800
Subject: [PATCH 4/4] add case for auto_parallel

---
 scripts/distribute/run_ci.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/distribute/run_ci.sh b/scripts/distribute/run_ci.sh
index f558cde651b2..da0636186622 100644
--- a/scripts/distribute/run_ci.sh
+++ b/scripts/distribute/run_ci.sh
@@ -30,6 +30,7 @@ target_lists_for_llama=(
     "paddlenlp/trainer/auto_trainer.py"
     "paddlenlp/transformers/llama/modeling_auto_static.py"
     "paddlenlp/transformers/llama/modeling_auto.py"
+    "paddlenlp/transformers/llama/modeling.py"
     "scripts/distribute"
     )
 