Commit

Merge branch 'main' into integrates/iree
renxida authored Feb 28, 2025
2 parents b4bfb08 + e905798 commit 21424ef
Showing 8 changed files with 392 additions and 193 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci_eval.yaml
@@ -65,7 +65,7 @@ jobs:
- name: Run perplexity test with IREE
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html --log-cli-level=INFO
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa --html=out/llm/llama/perplexity/iree_perplexity/index.html --log-cli-level=INFO
ls -lha ${{ github.workspace }}/perplexity_ci_artifacts
@@ -121,7 +121,7 @@ jobs:
- name: Run perplexity test with Torch
run: |
source ${VENV_DIR}/bin/activate
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --run-nightly-llama-tests --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --run-nightly-llama-tests --bs=100 --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
16 changes: 8 additions & 8 deletions sharktank/conftest.py
@@ -142,11 +142,11 @@ def pytest_addoption(parser):
)

parser.addoption(
"--llama3-8b-fp8-model-path",
"--llama3-8b-f8-model-path",
type=Path,
action="store",
default=None,
help="Llama3.1 8b fp8 model path",
help="Llama3.1 8b f8 model path",
)

parser.addoption(
@@ -164,11 +164,11 @@ def pytest_addoption(parser):
)

parser.addoption(
"--llama3-405b-fp8-model-path",
"--llama3-405b-f8-model-path",
type=Path,
action="store",
default=None,
help="Llama3.1 405b fp8 model path",
help="Llama3.1 405b f8 model path",
)

# To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py.
@@ -316,17 +316,17 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_8b_f16_model_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model"
)
model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model"
model_path["llama3_8b_f8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-8b-f8-model-path", "llama3_8b_f8_model"
)
model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer"
)
model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model"
)
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
model_path["llama3_405b_f8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-f8-model-path", "llama3_405b_f8_model"
)
model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option(
request,
73 changes: 63 additions & 10 deletions sharktank/sharktank/evaluate/perplexity_iree.py
@@ -34,6 +34,7 @@
from sharktank.utils.load_llm import *
from sharktank.utils.create_cache import *
from sharktank.utils.export_artifacts import *
from sharktank.utils.iree import iree_to_torch

log_levels = {
"info": logging.INFO,
@@ -69,8 +70,10 @@ def __init__(
attention_kernel,
block_seq_stride,
use_attention_mask,
activation_dtype=torch.float16,
attention_dtype=torch.float16,
activation_dtype,
attention_dtype,
kv_cache_dtype,
use_hf,
):
self.torch_device = torch_device
self.iree_device = iree_device
@@ -79,9 +82,15 @@ def __init__(
self.block_seq_stride = block_seq_stride
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.kv_cache_dtype = kv_cache_dtype
self.tensor_parallelism_size = tensor_parallelism_size
self.attention_kernel = attention_kernel
self.use_attention_mask = use_attention_mask
self.use_hf = use_hf
self.halelementtype_map = {
torch.float8_e4m3fnuz: ireert.HalElementType.FLOAT_8_E4M3_FNUZ,
torch.bfloat16: ireert.HalElementType.BFLOAT_16,
}

def timeit(func):
def wrapper(*args, **kwargs):
@@ -133,6 +142,9 @@ def compile_model(self, weight_path_str, mlir_path, json_path, vmfb_path):

logger.info(f" Model: {self.weight_path_str}")

if self.kv_cache_dtype is None:
self.kv_cache_dtype = self.attention_dtype

if vmfb_path:
self.vmfb_path = vmfb_path
logger.info(f" Using pre-compiled vmfb: {self.vmfb_path}")
@@ -146,6 +158,10 @@ def compile_model(self, weight_path_str, mlir_path, json_path, vmfb_path):
tensor_parallelism_size=self.tensor_parallelism_size,
block_seq_stride=self.block_seq_stride,
use_attention_mask=self.use_attention_mask,
activation_dtype=str(self.activation_dtype).split(".")[-1],
attention_dtype=str(self.attention_dtype).split(".")[-1],
kv_cache_dtype=str(self.kv_cache_dtype).split(".")[-1],
use_hf=self.use_hf,
mlir_path=mlir_path,
json_path=json_path,
)
@@ -160,7 +176,9 @@ def load_model(self, weight_path, tokenizer):
device=self.torch_device,
activation_dtype=self.activation_dtype,
attention_dtype=self.attention_dtype,
kv_cache_dtype=self.kv_cache_dtype,
tensor_parallelism_size=self.tensor_parallelism_size,
use_hf=self.use_hf,
)

if self.config.tensor_parallelism_size > 1:
@@ -219,7 +237,7 @@ def prefill_vmfb(self, token_batch, i):
seq_block_ids,
self.cache_state,
)

prefill_logits = iree_to_torch(prefill_logits)[0]
prefill_logits = torch.tensor(prefill_logits[:, :, :])

tokens = torch.tensor(
@@ -251,6 +269,7 @@ def decode_vmfb(self, token_batch, i):
seq_block_ids,
self.cache_state,
)
decode_logits = iree_to_torch(decode_logits)[0]

decode_logits = torch.tensor(decode_logits[:, :, :])

@@ -304,9 +323,27 @@ def get_logits(self, page_cache_size):
page_cache_size=page_cache_size,
)

self.cache_state = ireert.asdevicearray(
self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
)
if self.kv_cache_dtype in self.halelementtype_map.keys():

cache_state = self.batch.cache_state[0]

cache_as_int16 = cache_state.to(dtype=torch.int16)

device_array_as_int16 = ireert.asdevicearray(
self.haldevice, unbox_tensor(cache_as_int16).to("cpu").numpy()
)

buffer_view = ireert.HalBufferView(
buffer=device_array_as_int16._buffer_view.get_buffer(),
shape=device_array_as_int16._buffer_view.shape,
element_type=self.halelementtype_map[self.kv_cache_dtype],
)
self.cache_state = ireert.DeviceArray(self.haldevice, buffer_view)

else:
self.cache_state = ireert.asdevicearray(
self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
)

prefill_logits = self.prefill_vmfb(token_batch, i)
self.out_logits = prefill_logits[:, -1:, :]
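
For reference, a minimal sketch (not the repository's code) of the buffer-view retyping introduced above: numpy has no float8/bfloat16 dtype, so the cache is cast to int16 for the host-to-device copy and the resulting HAL buffer view is then relabelled with the intended element type. Assumptions: haldevice is an existing HAL device handle, and the cache is still zero-filled at this point so the int16 cast loses nothing.

import torch
import iree.runtime as ireert

def cache_to_device_array(haldevice, cache_state, kv_cache_dtype):
    # torch dtypes that numpy cannot express, mapped to HAL element types.
    hal_element_types = {
        torch.float8_e4m3fnuz: ireert.HalElementType.FLOAT_8_E4M3_FNUZ,
        torch.bfloat16: ireert.HalElementType.BFLOAT_16,
    }
    if kv_cache_dtype not in hal_element_types:
        # numpy-compatible dtypes can be uploaded directly.
        return ireert.asdevicearray(haldevice, cache_state.to("cpu").numpy())
    # Cast to int16 (lossless for a zero-filled cache) and upload.
    as_int16 = cache_state.to(dtype=torch.int16)
    device_array = ireert.asdevicearray(haldevice, as_int16.to("cpu").numpy())
    # Wrap the same device buffer in a view carrying the real element type.
    buffer_view = ireert.HalBufferView(
        buffer=device_array._buffer_view.get_buffer(),
        shape=device_array._buffer_view.shape,
        element_type=hal_element_types[kv_cache_dtype],
    )
    return ireert.DeviceArray(haldevice, buffer_view)
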
@@ -405,6 +442,10 @@ def run_perplexity(
num_prompts,
block_seq_stride,
use_attention_mask,
activation_dtype,
attention_dtype,
kv_cache_dtype,
use_hf,
mlir_path,
json_path,
vmfb_path,
@@ -419,6 +460,10 @@ def run_perplexity(
attention_kernel=attention_kernel,
block_seq_stride=block_seq_stride,
use_attention_mask=use_attention_mask,
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
kv_cache_dtype=kv_cache_dtype,
use_hf=use_hf,
)

perplexity.get_prompts(num_prompts=num_prompts)
@@ -459,6 +504,11 @@ def main(argv):
default=100,
help="Number of prompts for perplexity test (1 to 100)",
)
parser.add_argument(
"--use-attention-mask",
help="Generates attention mask during export",
action="store_true",
)
parser.add_argument(
"--mlir-path",
type=str,
@@ -476,16 +526,15 @@ def main(argv):
)

cli.add_model_options(parser)
cli.add_tokenizer_options(parser)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)

args = cli.parse(parser, args=argv)

torch_device = torch.device(args.device) if args.device else None
weight_path = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)

use_attention_mask = True

if args.mlir_path or args.json_path:
assert (
args.json_path is not None and args.mlir_path is not None
@@ -510,7 +559,11 @@
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
block_seq_stride=args.block_seq_stride,
use_attention_mask=use_attention_mask,
use_attention_mask=args.use_attention_mask,
attention_dtype=args.attention_dtype,
activation_dtype=args.activation_dtype,
kv_cache_dtype=args.kv_cache_dtype,
use_hf=args.use_hf,
mlir_path=args.mlir_path,
json_path=args.json_path,
vmfb_path=args.vmfb_path,
33 changes: 31 additions & 2 deletions sharktank/sharktank/evaluate/perplexity_torch.py
@@ -58,12 +58,18 @@ class Perplexity_torch:
def __init__(
self,
device,
use_hf,
fake_quant,
activation_dtype=torch.float32,
attention_dtype=torch.float32,
kv_cache_dtype=torch.float32,
):
self.device = device
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.kv_cache_dtype = kv_cache_dtype
self.use_hf = use_hf
self.fake_quant = fake_quant

def timeit(func):
def wrapper(*args, **kwargs):
@@ -112,11 +118,13 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern

self.config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
device=self.device,
activation_dtype=self.activation_dtype,
attention_dtype=self.attention_dtype,
kv_cache_dtype=self.kv_cache_dtype,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=self.use_hf,
fake_quant=self.fake_quant,
)

if self.config.tensor_parallelism_size > 1:
@@ -298,10 +306,23 @@ def run_perplexity_torch(
tensor_parallelism_size,
attention_kernel,
num_prompts,
activation_dtype,
attention_dtype,
kv_cache_dtype,
use_hf,
fake_quant,
):
start = time.time()

perplexity = Perplexity_torch(device=device)
perplexity = Perplexity_torch(
device=device,
activation_dtype=activation_dtype,
attention_dtype=attention_dtype,
kv_cache_dtype=kv_cache_dtype,
fake_quant=fake_quant,
use_hf=use_hf,
)

perplexity.get_prompts(num_prompts=num_prompts)
perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
ppl = perplexity.get_perplexity()
@@ -330,11 +351,14 @@ def main(argv):
cli.add_model_options(parser)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
cli.add_quantization_options(parser)

args = cli.parse(parser, args=argv)

device = torch.device(args.device) if args.device else None
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)

# Override flag if dataset disagrees
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
@@ -349,6 +373,11 @@
tensor_parallelism_size=tensor_parallelism_size,
attention_kernel=args.attention_kernel,
num_prompts=args.num_prompts,
attention_dtype=args.attention_dtype,
activation_dtype=args.activation_dtype,
kv_cache_dtype=args.kv_cache_dtype,
use_hf=args.use_hf,
fake_quant=args.fake_quant,
)

logger.info(f"\n{json.dumps(ppl, indent=2)}")
22 changes: 20 additions & 2 deletions sharktank/sharktank/utils/export_artifacts.py
@@ -93,6 +93,7 @@ def __init__(
block_seq_stride: int,
iree_hal_target_device: str,
use_attention_mask: bool = False,
use_hf: bool = False,
activation_dtype: str = "float16",
attention_dtype: str = "float16",
kv_cache_dtype: Optional[str] = None,
@@ -115,6 +116,7 @@ def __init__(
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.kv_cache_dtype = kv_cache_dtype
self.use_hf = use_hf

def timeit(func):
def wrapper(*args, **kwargs):
@@ -195,15 +197,17 @@ def export_to_mlir(
f"--attention-dtype={self.attention_dtype}",
f"--activation-dtype={self.activation_dtype}",
]

if self.kv_cache_dtype is not None:
export_args.append(f"--kv-cache-dtype={self.kv_cache_dtype}")
if skip_decode:
export_args.append("--skip-decode")
if self.attention_kernel in ["decomposed", "torch"]:
export_args.append("--attention-kernel")
export_args.append(self.attention_kernel)
export_args.append(f"--attention-kernel={self.attention_kernel}")
if self.use_attention_mask:
export_args.append("--use-attention-mask")
if self.use_hf:
export_args.append("--use-hf")

cwd = self.sharktank_dir
cmd = subprocess.list2cmdline(export_args)
@@ -228,6 +232,7 @@ def compile_to_vmfb(
hal_dump_path: Optional[Path] = None,
args: Optional[List[str]] = None,
):

# TODO: Control flag to enable multiple backends
compile_args = [
f"iree-compile",
@@ -252,6 +257,19 @@
# Append optional arguments if provided
if args:
compile_args += args
else:
compile_args += [
"--iree-dispatch-creation-enable-aggressive-fusion=true",
"--iree-global-opt-propagate-transposes=true",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-opt-data-tiling=false",
"--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'",
"--iree-stream-resource-memory-model=discrete",
"--iree-hal-indirect-command-buffers=true",
"--iree-hal-memoization=true",
"--iree-opt-strip-assertions",
]

cmd = subprocess.list2cmdline(compile_args)

logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}")
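
When the caller passes no explicit compile arguments, the defaults above are appended. A rough sketch of the resulting invocation (the input/output paths and the hip/gfx942 target flags are illustrative assumptions, mirroring the CI job earlier in this diff rather than this function's actual argument handling):

import subprocess

compile_args = [
    "iree-compile",
    "llama3_8b_fp16.mlir",           # assumed input path
    "-o", "llama3_8b_fp16.vmfb",     # assumed output path
    "--iree-hal-target-device=hip",  # assumed target flags, as in the CI job
    "--iree-hip-target=gfx942",
    # defaults appended when no explicit args are provided:
    "--iree-dispatch-creation-enable-aggressive-fusion=true",
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-aggressively-propagate-transposes=true",
    "--iree-opt-data-tiling=false",
    "--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'",
    "--iree-stream-resource-memory-model=discrete",
    "--iree-hal-indirect-command-buffers=true",
    "--iree-hal-memoization=true",
    "--iree-opt-strip-assertions",
]
print(subprocess.list2cmdline(compile_args))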