[GCU] Support llama for GCU
EnflameGCU committed May 16, 2024
1 parent 5170664 commit 32d66ef
Showing 5 changed files with 51 additions and 14 deletions.
12 changes: 9 additions & 3 deletions examples/benchmark/wiki_lambada/eval.py
@@ -57,8 +57,8 @@ def get_parser():
"--device",
type=str,
default="gpu",
choices=["cpu", "eval_pathgpu", "xpu", "npu"],
help="select cpu, gpu, xpu devices.",
choices=["cpu", "gpu", "xpu", "npu", "gcu"],
help="select cpu, gpu, xpu, gcu devices.",
)
parser.add_argument(
"--dtype",
@@ -67,6 +67,12 @@ def get_parser():
choices=["bfloat16", "float16", "float32"],
help="set the dtype of model",
)
parser.add_argument(
"--use_flash_attention",
type=bool,
default=False,
help="Whether to use flash attention",
)

# load autodist name files, eg: bloom-176b
parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")
@@ -316,7 +322,7 @@ def do_generation():
tensor_parallel_output=False,
tensor_parallel_degree=args.tensor_parallel_degree,
tensor_parallel_rank=paddle.distributed.get_rank(),
use_flash_attention=False,
use_flash_attention=args.use_flash_attention,
dtype=args.dtype, # todo enable set dtype to avoid additional mem usage
)

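The new --use_flash_attention flag and the gcu device choice are consumed by the model-loading call above (use_flash_attention=args.use_flash_attention, dtype=args.dtype). A minimal usage sketch, assuming AutoModelForCausalLM is the loader used by eval.py and that the parser also exposes a model_name_or_path argument (both are assumptions for illustration):

    import paddle
    from paddlenlp.transformers import AutoModelForCausalLM

    args = get_parser().parse_args()              # get_parser() is defined in eval.py above
    paddle.set_device(args.device)                # "gcu" is now an accepted --device choice
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,                  # assumed argument name, illustration only
        use_flash_attention=args.use_flash_attention,
        dtype=args.dtype,
    )

Note that because the argument is declared with type=bool, any non-empty value passed on the command line (including the string "False") parses as True; the default of False applies only when the flag is omitted.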
2 changes: 2 additions & 0 deletions paddlenlp/generation/utils.py
@@ -1208,6 +1208,8 @@ def sample(
probs = TopKProcess(probs, top_k, min_tokens_to_keep)
if top_p is not None and top_p < 1.0:
probs = TopPProcess(probs, top_p, min_tokens_to_keep)
if paddle.device.is_compiled_with_custom_device("gcu"):
probs = paddle.cast(probs, "float32")

# multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
next_tokens = paddle.multinomial(probs)
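The added cast keeps sampling on GCU in float32 before paddle.multinomial is called, presumably because the gcu backend's multinomial kernel does not take fp16/bf16 probabilities; other devices keep the half-precision path noted in the comment above. A standalone sketch of the resulting sampling step (the probability tensor is illustrative):

    import paddle

    probs = paddle.rand([1, 32000])                                    # e.g. softmax output over a vocab
    probs = (probs / probs.sum(axis=-1, keepdim=True)).astype("float16")
    if paddle.device.is_compiled_with_custom_device("gcu"):
        probs = paddle.cast(probs, "float32")                          # gcu-only upcast, as in the diff
    next_tokens = paddle.multinomial(probs)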
24 changes: 21 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -41,7 +41,7 @@ def swiglu(x, y=None):
except ImportError:
fused_rotary_position_embedding = None
try:
if get_env_device() == "npu":
if get_env_device() in ["npu", "gcu"]:
from paddle.base import core

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
@@ -53,13 +53,18 @@


def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb):
assert past_key_value is None, "fuse rotary not support cache kv for now"
if get_env_device() != "gcu":
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin, cos_sin = rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
elif get_env_device() == "gcu":
query_states, key_states = core.eager._run_custom_op(
"fused_rotary_embedding_gcu", query_states, key_states, cos_sin, position_ids, True
)
else:
# paddle version > 2.6 or develop support q and k/v with different num_heads
paddle_version = float(paddle.__version__[:3])
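For orientation, the rotary application that the fused_rotary_embedding_gcu op performs in one call corresponds to the standard llama formulation sketched below; this is a reference for the math only, not the GCU kernel, and unlike this sketch the custom op consumes the concatenated cos_sin table and position_ids directly:

    import paddle

    def rotate_half(x):
        x1, x2 = paddle.chunk(x, 2, axis=-1)
        return paddle.concat([-x2, x1], axis=-1)

    def apply_rope_reference(q, k, cos, sin):
        # q, k: [batch, seq_len, num_heads, head_dim]; cos, sin broadcast over batch and heads
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin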
@@ -103,6 +108,8 @@ def rms_norm_fused(x_in, w, eps):
def fusion_rms_norm(hidden_states, weight, variance_epsilon):
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
elif get_env_device() == "gcu":
return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]

Check warning on line 112 in paddlenlp/transformers/llama/fusion_ops.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/llama/fusion_ops.py#L111-L112

Added lines #L111 - L112 were not covered by tests
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821
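The rms_norm_gcu op added above is expected to fuse the usual RMSNorm computation; the sketch below shows the common non-fused formulation for comparison (an assumption about what the kernel computes, not its implementation):

    import paddle

    def rms_norm_reference(hidden_states, weight, eps):
        # accumulate the variance in float32, then cast back to the weight dtype
        variance = hidden_states.astype("float32").pow(2).mean(axis=-1, keepdim=True)
        normed = hidden_states.astype("float32") * paddle.rsqrt(variance + eps)
        return normed.astype(weight.dtype) * weight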
@@ -158,6 +165,17 @@ def fusion_flash_attention(
False,
npu_is_casual,
)[0]
elif get_env_device() == "gcu":
attn_output = core.eager._run_custom_op(
"fused_sdp_flash_attention_gcu",
query_states,
key_states,
value_states,
attention_mask,
0.0,
attention_mask is None,
True,
)[0]
else:
attn_output = F.scaled_dot_product_attention(
query_states,
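On GCU the fused_sdp_flash_attention_gcu op stands in for the F.scaled_dot_product_attention fallback of the else branch; the sketch below shows an equivalent non-fused call for comparison (shapes and dtypes are illustrative, and reading the op's positional arguments, dropout 0.0 and a causal flag derived from attention_mask is None, as matching this call is an assumption):

    import paddle
    import paddle.nn.functional as F

    q = paddle.randn([1, 128, 32, 128]).astype("float16")   # [batch, seq_len, num_heads, head_dim]
    k = paddle.randn([1, 128, 32, 128]).astype("float16")
    v = paddle.randn([1, 128, 32, 128]).astype("float16")

    attn_output = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=True,        # mirrors `attention_mask is None` in the gcu branch above
    )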
25 changes: 17 additions & 8 deletions paddlenlp/transformers/llama/modeling.py
@@ -79,7 +79,7 @@ def swiglu(x, y=None):
)

try:
if get_env_device() == "npu":
if get_env_device() in ["npu", "gcu"]:

for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
if lib.endswith(".so"):
@@ -410,6 +410,7 @@ def _set_cos_sin_cache(self, seq_len):
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
@@ -418,6 +419,9 @@
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
self.cos_sin_table.cast(x.dtype)
if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype
else self.cos_sin_table,
)


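The cos_sin_table cached above only for gcu concatenates the cosine and sine of the raw (un-duplicated) frequency table along the last axis, alongside the existing cos_cached and sin_cached buffers. A standalone sketch of the construction with illustrative hyperparameters (head_dim, base and max length come from the model config in the real class):

    import paddle

    head_dim, base, max_position = 128, 10000.0, 4096
    inv_freq = 1.0 / (base ** (paddle.arange(0, head_dim, 2, dtype="float32") / head_dim))
    t = paddle.arange(max_position, dtype="float32")
    freqs = paddle.einsum("i,j->ij", t, inv_freq)                        # [seq_len, head_dim // 2]
    emb = paddle.concat([freqs, freqs], axis=-1)                         # feeds cos_cached / sin_cached
    cos_sin_table = paddle.concat([freqs.cos(), freqs.sin()], axis=-1)   # [seq_len, head_dim], gcu only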
@@ -439,6 +443,7 @@ def _set_cos_sin_cache(self, seq_len):
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)


class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -471,19 +476,23 @@ def _scale_cos_sin(self, seq_len):
# [1, seqlen, 1, dim]
scale_cos = emb.cos()[None, :, None, :]
scale_sin = emb.sin()[None, :, None, :]
return scale_cos, scale_sin
scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1)
return scale_cos, scale_sin, scale_cos_sin

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_position_embeddings:
scale_cos, scale_sin = self._scale_cos_sin(seq_len=seq_len)
scale_cos, scale_sin, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len)
else:
scale_cos, scale_sin = self.cos_cached, self.sin_cached
scale_cos, scale_sin, scale_cos_sin = self.cos_cached, self.sin_cached, self.cos_sin_table
cos = scale_cos[:, :seq_len, :, ...]
sin = scale_sin[:, :seq_len, :, ...]
return (
cos.cast(x.dtype) if cos.dtype != x.dtype else cos,
sin.cast(x.dtype) if sin.dtype != x.dtype else sin,
scale_cos_sin.cast(x.dtype)
if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype
else scale_cos_sin,
)


@@ -638,7 +647,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
)

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
if self.use_fused_rope and get_env_device() not in ["npu", "xpu", "gcu"]:
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -934,7 +943,7 @@ def forward(
sin.cast(value_states.dtype) if sin.dtype != value_states.dtype else sin,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin, _ = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -1398,7 +1407,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = expanded_attn_mask.astype("float32")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
elif get_env_device() in ["xpu", "gcu"]:
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
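For xpu, and now gcu, the boolean attention mask is converted into an additive mask in the target dtype, with 0 at attended positions and the dtype's minimum at masked ones. A small self-contained sketch of that conversion (shape and dtype illustrative):

    import paddle

    dtype = paddle.float16
    bool_mask = paddle.tril(paddle.ones([1, 1, 4, 4], dtype="int32")).astype("bool")  # causal example
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
    additive_mask = paddle.where(bool_mask, x, y)   # 0 where attended, dtype-min where masked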
@@ -1528,7 +1537,7 @@ def forward(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
is_casual = False
if self.config.use_flash_attention:
if self.config.use_flash_attention and get_env_device() != "gcu":
is_casual = is_casual_mask(attention_mask)
if get_env_device() != "npu":
if is_casual and alibi is None:
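With the last hunk, the is_casual shortcut is never taken on gcu, so the expanded attention mask is always passed to the attention path on that backend even when flash attention is enabled. For orientation only, a rough sketch of what a causal-mask check of this kind does (the actual is_casual_mask helper in paddlenlp may be implemented differently):

    import paddle

    def looks_causal(additive_mask):
        # additive_mask: [batch, 1, q_len, kv_len]; 0 where attended, very negative where masked
        allowed = (additive_mask == 0).astype("int32")
        lower_triangular = paddle.tril(paddle.ones_like(allowed))
        return bool((allowed == lower_triangular).all())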
2 changes: 2 additions & 0 deletions paddlenlp/utils/tools.py
@@ -124,6 +124,8 @@ def get_env_device():
return "gpu"
elif "npu" in paddle.device.get_all_custom_device_type():
return "npu"
elif "gcu" in paddle.device.get_all_custom_device_type():
return "gcu"

Check warning on line 128 in paddlenlp/utils/tools.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/utils/tools.py#L128

Added line #L128 was not covered by tests
elif paddle.is_compiled_with_rocm():
return "rocm"
elif paddle.is_compiled_with_xpu():
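Callers across the llama code branch on this return value to pick the gcu-specific paths shown in the earlier files. A short usage sketch (assuming paddle.set_device accepts the plain custom-device name, as it does for other custom backends):

    import paddle
    from paddlenlp.utils.tools import get_env_device

    device = get_env_device()
    if device == "gcu":
        paddle.set_device("gcu")   # gcu is registered as a paddle custom device
    print(device)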
