diff --git a/examples/benchmark/wiki_lambada/eval.py b/examples/benchmark/wiki_lambada/eval.py
index cd1c572d1972..e2e32b319784 100644
--- a/examples/benchmark/wiki_lambada/eval.py
+++ b/examples/benchmark/wiki_lambada/eval.py
@@ -57,8 +57,8 @@ def get_parser():
         "--device",
         type=str,
         default="gpu",
-        choices=["cpu", "gpu", "xpu", "npu"],
-        help="select cpu, gpu, xpu devices.",
+        choices=["cpu", "gpu", "xpu", "npu", "gcu"],
+        help="select cpu, gpu, xpu, gcu devices.",
     )
     parser.add_argument(
         "--dtype",
@@ -67,6 +67,12 @@ def get_parser():
         choices=["bfloat16", "float16", "float32"],
         help="set the dtype of model",
     )
+    parser.add_argument(
+        "--use_flash_attention",
+        type=bool,
+        default=False,
+        help="Whether to use flash attention",
+    )
 
     # load autodist name files, eg: bloom-176b
     parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
@@ -316,7 +322,7 @@ def do_generation():
         tensor_parallel_output=False,
         tensor_parallel_degree=args.tensor_parallel_degree,
         tensor_parallel_rank=paddle.distributed.get_rank(),
-        use_flash_attention=False,
+        use_flash_attention=args.use_flash_attention,
         dtype=args.dtype,  # todo enable set dtype to avoid additional mem usage
     )
 
diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py
index 65830c6ca244..0391647ab65e 100644
--- a/paddlenlp/generation/utils.py
+++ b/paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@ def sample(
                 probs = TopKProcess(probs, top_k, min_tokens_to_keep)
             if top_p is not None and top_p < 1.0:
                 probs = TopPProcess(probs, top_p, min_tokens_to_keep)
+            if paddle.device.is_compiled_with_custom_device("gcu"):
+                probs = paddle.cast(probs, "float32")
 
             # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
             next_tokens = paddle.multinomial(probs)
diff --git a/paddlenlp/transformers/llama/fusion_ops.py b/paddlenlp/transformers/llama/fusion_ops.py
index 96f160534fc7..f9cdf8547dfd 100644
--- a/paddlenlp/transformers/llama/fusion_ops.py
+++ b/paddlenlp/transformers/llama/fusion_ops.py
@@ -41,7 +41,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_rotary_position_embedding = None
 try:
-    if get_env_device() == "npu":
+    if get_env_device() in ["npu", "gcu"]:
         from paddle.base import core
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -53,13 +53,18 @@ def swiglu(x, y=None):
 
 
 def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
-    assert past_key_value is None, "fuse rotary not support cache kv for now"
+    if get_env_device() != "gcu":
+        assert past_key_value is None, "fuse rotary not support cache kv for now"
     batch_size, seq_length, num_heads, head_dim = query_states.shape
     _, kv_seq_len, num_key_value_heads, _ = key_states.shape
-    cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
     if get_env_device() == "npu":
         query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
         key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
+    elif get_env_device() == "gcu":
+        query_states, key_states = core.eager._run_custom_op(
+            "fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
+        )
     else:
         # paddle version > 2.6 or develop support q and k/v with different num_heads
         paddle_version = float(paddle.__version__[:3])
@@ -103,6 +108,8 @@ def rms_norm_fused(x_in, w, eps):
 def fusion_rms_norm(hidden_states, weight, variance_epsilon):
     if get_env_device() == "npu":
         return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
+    elif get_env_device() == "gcu":
+        return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
     elif get_env_device() == "xpu":
         try:
             import paddle_xpu_nn  # noqa: F821
@@ -158,6 +165,17 @@ def fusion_flash_attention(
             False,
             npu_is_casual,
         )[0]
+    elif get_env_device() == "gcu":
+        attn_output = core.eager._run_custom_op(
+            "fused_sdp_flash_attention_gcu",
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            0.0,
+            attention_mask is None,
+            True,
+        )[0]
     else:
         attn_output = F.scaled_dot_product_attention(
             query_states,
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 320f6b4b5a54..8f2dd1c36415 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -79,7 +79,7 @@ def swiglu(x, y=None):
 )
 
 try:
-    if get_env_device() == "npu":
+    if get_env_device() in ["npu", "gcu"]:
 
         for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
             if lib.endswith(".so"):
@@ -410,6 +410,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -418,6 +419,9 @@ def forward(self, x, seq_len=None):
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            self.cos_sin_table.cast(x.dtype)
+            if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
+            else self.cos_sin_table,
         )
 
 
@@ -439,6 +443,7 @@ def _set_cos_sin_cache(self, seq_len):
         # [1, seqlen, 1, dim]
         self.cos_cached = emb.cos()[None, :, None, :]
         self.sin_cached = emb.sin()[None, :, None, :]
+        self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
 
 
 class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -471,19 +476,23 @@ def _scale_cos_sin(self, seq_len):
         # [1, seqlen, 1, dim]
         scale_cos = emb.cos()[None, :, None, :]
         scale_sin = emb.sin()[None, :, None, :]
-        return scale_cos, scale_sin
+        scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
+        return scale_cos, scale_sin, scale_cos_sin
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         if seq_len > self.max_position_embeddings:
-            scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len)
+            scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
         else:
-            scale_cos, scale_sin = self.cos_cached, self.sin_cached
+            scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
         cos = scale_cos[:, :seq_len, :, ...]
         sin = scale_sin[:, :seq_len, :, ...]
         return (
             cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
             sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
+            scale_cos_sin.cast(x.dtype)
+            if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
+            else scale_cos_sin,
         )
 
 
@@ -638,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )
 
         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
             if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
@@ -934,7 +943,7 @@ def forward(
                         sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
                     )
                 else:
-                    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                    cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
                 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -1398,7 +1407,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
             expanded_attn_mask = expanded_attn_mask.astype("float32")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
-        elif get_env_device() == "xpu":
+        elif get_env_device() in ["xpu", "gcu"]:
             x = paddle.to_tensor(0.0, dtype=dtype)
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
             expanded_attn_mask = expanded_attn_mask.astype(dtype)
@@ -1528,7 +1537,7 @@ def forward(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
             )  # [bs, 1, seq_len, seq_len]
         is_casual = False
-        if self.config.use_flash_attention:
+        if self.config.use_flash_attention and get_env_device() != "gcu":
            is_casual = is_casual_mask(attention_mask)
             if get_env_device() != "npu":
                 if is_casual and alibi is None:
diff --git a/paddlenlp/utils/tools.py b/paddlenlp/utils/tools.py
index 256381a17de5..8f7b90f1591a 100644
--- a/paddlenlp/utils/tools.py
+++ b/paddlenlp/utils/tools.py
@@ -124,6 +124,8 @@ def get_env_device():
         return "gpu"
     elif "npu" in paddle.device.get_all_custom_device_type():
         return "npu"
+    elif "gcu" in paddle.device.get_all_custom_device_type():
+        return "gcu"
     elif paddle.is_compiled_with_rocm():
         return "rocm"
     elif paddle.is_compiled_with_xpu():
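A note on the new --use_flash_attention flag in eval.py: argparse's type=bool simply applies Python's bool() to the raw command-line string, so any non-empty value, including the literal string "False", parses as True. A minimal standalone sketch (illustrative only, not part of the patch) showing this behavior:

import argparse

# Hypothetical snippet mirroring the flag added to eval.py above.
# type=bool calls bool() on the raw string, so any non-empty value parses as True.
parser = argparse.ArgumentParser()
parser.add_argument("--use_flash_attention", type=bool, default=False)

print(parser.parse_args([]).use_flash_attention)                                  # False (default)
print(parser.parse_args(["--use_flash_attention", "True"]).use_flash_attention)   # True
print(parser.parse_args(["--use_flash_attention", "False"]).use_flash_attention)  # True, not False

In practice the flag acts as an on/off switch: omit it to keep the default False, or pass any non-empty value to enable flash attention, which the patch then routes to the fused_sdp_flash_attention_gcu custom op on GCU devices.

On the rotary embedding change: alongside the existing [1, seq_len, 1, head_dim] cos/sin caches, the GCU path precomputes a compact cos_sin_table by concatenating freqs.cos() and freqs.sin() along the last axis, and passes it to the fused_rotary_embedding_gcu custom op together with position_ids. Assuming the usual rotary construction in which freqs is the outer product of positions and inverse frequencies (shape [seq_len, head_dim // 2]), a shape sketch:

import paddle

# Assumed toy values for illustration only.
seq_len, head_dim, base = 8, 16, 10000.0

inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
t = paddle.arange(seq_len, dtype="float32")
freqs = paddle.einsum("i,j->ij", t, inv_freq)                        # [seq_len, head_dim // 2]

emb = paddle.concat([freqs, freqs], axis=-1)                         # [seq_len, head_dim]
cos_cached = emb.cos()[None, :, None, :]                             # [1, seq_len, 1, head_dim], generic path
sin_cached = emb.sin()[None, :, None, :]

# GCU-only table: cos half and sin half side by side, consumed by the
# fused_rotary_embedding_gcu custom op together with position_ids.
cos_sin_table = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)   # [seq_len, head_dim]
print(cos_cached.shape, sin_cached.shape, cos_sin_table.shape)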