From e9057986aa3c744ab360b175b09ecc12f249d1ad Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:37:27 -0800 Subject: [PATCH] [sharktank] Enable f8 model perplexity tests (#1001) Enable f8 model perplexity tests in sharktank --- .github/workflows/ci_eval.yaml | 4 +- sharktank/conftest.py | 16 +- .../sharktank/evaluate/perplexity_iree.py | 73 +++++- .../sharktank/evaluate/perplexity_torch.py | 33 ++- sharktank/sharktank/utils/export_artifacts.py | 22 +- .../evaluate/baseline_perplexity_scores.json | 211 +++++++++++++++++- .../tests/evaluate/perplexity_iree_test.py | 90 +------- .../tests/evaluate/perplexity_torch_test.py | 136 ++++------- 8 files changed, 392 insertions(+), 193 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 6ccc8c15d..105cc3604 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -65,7 +65,7 @@ jobs: - name: Run perplexity test with IREE run: | source ${VENV_DIR}/bin/activate - pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/iree_perplexity/index.html --log-cli-level=INFO + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --run-nightly-llama-tests --bs=100 --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa --html=out/llm/llama/perplexity/iree_perplexity/index.html --log-cli-level=INFO ls -lha ${{ github.workspace }}/perplexity_ci_artifacts @@ -121,7 +121,7 @@ jobs: - name: Run perplexity test with Torch run: | source ${VENV_DIR}/bin/activate - pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --run-nightly-llama-tests --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --run-nightly-llama-tests --bs=100 --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/sharktank/conftest.py b/sharktank/conftest.py index dfbc18f08..f12fa2e8e 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -142,11 +142,11 @@ def pytest_addoption(parser): ) parser.addoption( - "--llama3-8b-fp8-model-path", + "--llama3-8b-f8-model-path", type=Path, action="store", default=None, - help="Llama3.1 8b fp8 model path", + help="Llama3.1 8b f8 model path", ) parser.addoption( @@ -164,11 +164,11 @@ def pytest_addoption(parser): ) parser.addoption( - "--llama3-405b-fp8-model-path", + 
"--llama3-405b-f8-model-path", type=Path, action="store", default=None, - help="Llama3.1 405b fp8 model path", + help="Llama3.1 405b f8 model path", ) # To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py. @@ -316,8 +316,8 @@ def get_model_artifacts(request: FixtureRequest): model_path["llama3_8b_f16_model_path"] = set_fixture_from_cli_option( request, "--llama3-8b-f16-model-path", "llama3_8b_f16_model" ) - model_path["llama3_8b_fp8_model_path"] = set_fixture_from_cli_option( - request, "--llama3-8b-fp8-model-path", "llama3_8b_fp8_model" + model_path["llama3_8b_f8_model_path"] = set_fixture_from_cli_option( + request, "--llama3-8b-f8-model-path", "llama3_8b_f8_model" ) model_path["llama3_405b_tokenizer_path"] = set_fixture_from_cli_option( request, "--llama3-405b-tokenizer-path", "llama3_405b_tokenizer" @@ -325,8 +325,8 @@ def get_model_artifacts(request: FixtureRequest): model_path["llama3_405b_f16_model_path"] = set_fixture_from_cli_option( request, "--llama3-405b-f16-model-path", "llama3_405b_f16_model" ) - model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option( - request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model" + model_path["llama3_405b_f8_model_path"] = set_fixture_from_cli_option( + request, "--llama3-405b-f8-model-path", "llama3_405b_f8_model" ) model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option( request, diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py index fe620fd30..65beef779 100644 --- a/sharktank/sharktank/evaluate/perplexity_iree.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -34,6 +34,7 @@ from sharktank.utils.load_llm import * from sharktank.utils.create_cache import * from sharktank.utils.export_artifacts import * +from sharktank.utils.iree import iree_to_torch log_levels = { "info": logging.INFO, @@ -69,8 +70,10 @@ def __init__( attention_kernel, block_seq_stride, use_attention_mask, - activation_dtype=torch.float16, - attention_dtype=torch.float16, + activation_dtype, + attention_dtype, + kv_cache_dtype, + use_hf, ): self.torch_device = torch_device self.iree_device = iree_device @@ -79,9 +82,15 @@ def __init__( self.block_seq_stride = block_seq_stride self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype + self.kv_cache_dtype = kv_cache_dtype self.tensor_parallelism_size = tensor_parallelism_size self.attention_kernel = attention_kernel self.use_attention_mask = use_attention_mask + self.use_hf = use_hf + self.halelementtype_map = { + torch.float8_e4m3fnuz: ireert.HalElementType.FLOAT_8_E4M3_FNUZ, + torch.bfloat16: ireert.HalElementType.BFLOAT_16, + } def timeit(func): def wrapper(*args, **kwargs): @@ -133,6 +142,9 @@ def compile_model(self, weight_path_str, mlir_path, json_path, vmfb_path): logger.info(f" Model: {self.weight_path_str}") + if self.kv_cache_dtype is None: + self.kv_cache_dtype = self.attention_dtype + if vmfb_path: self.vmfb_path = vmfb_path logger.info(f" Using pre-compiled vmfb: {self.vmfb_path}") @@ -146,6 +158,10 @@ def compile_model(self, weight_path_str, mlir_path, json_path, vmfb_path): tensor_parallelism_size=self.tensor_parallelism_size, block_seq_stride=self.block_seq_stride, use_attention_mask=self.use_attention_mask, + activation_dtype=str(self.activation_dtype).split(".")[-1], + attention_dtype=str(self.attention_dtype).split(".")[-1], + kv_cache_dtype=str(self.kv_cache_dtype).split(".")[-1], + use_hf=self.use_hf, mlir_path=mlir_path, json_path=json_path, ) @@ 
-160,7 +176,9 @@ def load_model(self, weight_path, tokenizer): device=self.torch_device, activation_dtype=self.activation_dtype, attention_dtype=self.attention_dtype, + kv_cache_dtype=self.kv_cache_dtype, tensor_parallelism_size=self.tensor_parallelism_size, + use_hf=self.use_hf, ) if self.config.tensor_parallelism_size > 1: @@ -219,7 +237,7 @@ def prefill_vmfb(self, token_batch, i): seq_block_ids, self.cache_state, ) - + prefill_logits = iree_to_torch(prefill_logits)[0] prefill_logits = torch.tensor(prefill_logits[:, :, :]) tokens = torch.tensor( @@ -251,6 +269,7 @@ def decode_vmfb(self, token_batch, i): seq_block_ids, self.cache_state, ) + decode_logits = iree_to_torch(decode_logits)[0] decode_logits = torch.tensor(decode_logits[:, :, :]) @@ -304,9 +323,27 @@ def get_logits(self, page_cache_size): page_cache_size=page_cache_size, ) - self.cache_state = ireert.asdevicearray( - self.haldevice, self.batch.cache_state[0].to("cpu").numpy() - ) + if self.kv_cache_dtype in self.halelementtype_map.keys(): + + cache_state = self.batch.cache_state[0] + + cache_as_int16 = cache_state.to(dtype=torch.int16) + + device_array_as_int16 = ireert.asdevicearray( + self.haldevice, unbox_tensor(cache_as_int16).to("cpu").numpy() + ) + + buffer_view = ireert.HalBufferView( + buffer=device_array_as_int16._buffer_view.get_buffer(), + shape=device_array_as_int16._buffer_view.shape, + element_type=self.halelementtype_map[self.kv_cache_dtype], + ) + self.cache_state = ireert.DeviceArray(self.haldevice, buffer_view) + + else: + self.cache_state = ireert.asdevicearray( + self.haldevice, self.batch.cache_state[0].to("cpu").numpy() + ) prefill_logits = self.prefill_vmfb(token_batch, i) self.out_logits = prefill_logits[:, -1:, :] @@ -405,6 +442,10 @@ def run_perplexity( num_prompts, block_seq_stride, use_attention_mask, + activation_dtype, + attention_dtype, + kv_cache_dtype, + use_hf, mlir_path, json_path, vmfb_path, @@ -419,6 +460,10 @@ def run_perplexity( attention_kernel=attention_kernel, block_seq_stride=block_seq_stride, use_attention_mask=use_attention_mask, + activation_dtype=activation_dtype, + attention_dtype=attention_dtype, + kv_cache_dtype=kv_cache_dtype, + use_hf=use_hf, ) perplexity.get_prompts(num_prompts=num_prompts) @@ -459,6 +504,11 @@ def main(argv): default=100, help="Number of prompts for perplexity test (1 to 100)", ) + parser.add_argument( + "--use-attention-mask", + help="Generates attention mask during export", + action="store_true", + ) parser.add_argument( "--mlir-path", type=str, @@ -476,16 +526,15 @@ def main(argv): ) cli.add_model_options(parser) - cli.add_tokenizer_options(parser) cli.add_input_dataset_options(parser) + cli.add_tokenizer_options(parser) + args = cli.parse(parser, args=argv) torch_device = torch.device(args.device) if args.device else None weight_path = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) - use_attention_mask = True - if args.mlir_path or args.json_path: assert ( args.json_path is not None and args.mlir_path is not None @@ -510,7 +559,11 @@ def main(argv): attention_kernel=args.attention_kernel, num_prompts=args.num_prompts, block_seq_stride=args.block_seq_stride, - use_attention_mask=use_attention_mask, + use_attention_mask=args.use_attention_mask, + attention_dtype=args.attention_dtype, + activation_dtype=args.activation_dtype, + kv_cache_dtype=args.kv_cache_dtype, + use_hf=args.use_hf, mlir_path=args.mlir_path, json_path=args.json_path, vmfb_path=args.vmfb_path, diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py 
b/sharktank/sharktank/evaluate/perplexity_torch.py index 4974de1c2..ffeec78fa 100644 --- a/sharktank/sharktank/evaluate/perplexity_torch.py +++ b/sharktank/sharktank/evaluate/perplexity_torch.py @@ -58,12 +58,18 @@ class Perplexity_torch: def __init__( self, device, + use_hf, + fake_quant, activation_dtype=torch.float32, attention_dtype=torch.float32, + kv_cache_dtype=torch.float32, ): self.device = device self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype + self.kv_cache_dtype = kv_cache_dtype + self.use_hf = use_hf + self.fake_quant = fake_quant def timeit(func): def wrapper(*args, **kwargs): @@ -112,11 +118,13 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), - block_seq_stride=16, device=self.device, activation_dtype=self.activation_dtype, attention_dtype=self.attention_dtype, + kv_cache_dtype=self.kv_cache_dtype, tensor_parallelism_size=tensor_parallelism_size, + use_hf=self.use_hf, + fake_quant=self.fake_quant, ) if self.config.tensor_parallelism_size > 1: @@ -298,10 +306,23 @@ def run_perplexity_torch( tensor_parallelism_size, attention_kernel, num_prompts, + activation_dtype, + attention_dtype, + kv_cache_dtype, + use_hf, + fake_quant, ): start = time.time() - perplexity = Perplexity_torch(device=device) + perplexity = Perplexity_torch( + device=device, + activation_dtype=activation_dtype, + attention_dtype=attention_dtype, + kv_cache_dtype=kv_cache_dtype, + fake_quant=fake_quant, + use_hf=use_hf, + ) + perplexity.get_prompts(num_prompts=num_prompts) perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel) ppl = perplexity.get_perplexity() @@ -330,11 +351,14 @@ def main(argv): cli.add_model_options(parser) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) + cli.add_quantization_options(parser) + args = cli.parse(parser, args=argv) device = torch.device(args.device) if args.device else None dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) + # Override flag if dataset disagrees tensor_parallelism_size = ( dataset.properties["tensor_parallelism_size"] @@ -349,6 +373,11 @@ def main(argv): tensor_parallelism_size=tensor_parallelism_size, attention_kernel=args.attention_kernel, num_prompts=args.num_prompts, + attention_dtype=args.attention_dtype, + activation_dtype=args.activation_dtype, + kv_cache_dtype=args.kv_cache_dtype, + use_hf=args.use_hf, + fake_quant=args.fake_quant, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 2fbdf035e..7d838badc 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -93,6 +93,7 @@ def __init__( block_seq_stride: int, iree_hal_target_device: str, use_attention_mask: bool = False, + use_hf: bool = False, activation_dtype: str = "float16", attention_dtype: str = "float16", kv_cache_dtype: Optional[str] = None, @@ -115,6 +116,7 @@ def __init__( self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype self.kv_cache_dtype = kv_cache_dtype + self.use_hf = use_hf def timeit(func): def wrapper(*args, **kwargs): @@ -195,15 +197,17 @@ def export_to_mlir( f"--attention-dtype={self.attention_dtype}", f"--activation-dtype={self.activation_dtype}", ] + if self.kv_cache_dtype is not None: export_args.append(f"--kv-cache-dtype={self.kv_cache_dtype}") if 
skip_decode: export_args.append("--skip-decode") if self.attention_kernel in ["decomposed", "torch"]: - export_args.append("--attention-kernel") - export_args.append(self.attention_kernel) + export_args.append(f"--attention-kernel={self.attention_kernel}") if self.use_attention_mask: export_args.append("--use-attention-mask") + if self.use_hf: + export_args.append("--use-hf") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) @@ -228,6 +232,7 @@ def compile_to_vmfb( hal_dump_path: Optional[Path] = None, args: Optional[List[str]] = None, ): + # TODO: Control flag to enable multiple backends compile_args = [ f"iree-compile", @@ -252,6 +257,19 @@ def compile_to_vmfb( # Append optional arguments if provided if args: compile_args += args + else: + compile_args += [ + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-data-tiling=false", + "--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'", + "--iree-stream-resource-memory-model=discrete", + "--iree-hal-indirect-command-buffers=true", + "--iree-hal-memoization=true", + "--iree-opt-strip-assertions", + ] + cmd = subprocess.list2cmdline(compile_args) logger.info(f" Launching compile command:\n" f"cd {cwd} && {cmd}") diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index 4686c29fe..2d68a0736 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -104,7 +104,111 @@ ], "mean_perplexity": 20.303255 }, - + "llama3_8B_f8_torch": { + "perplexities": [ + 7.987008, + 31.637552, + 18.19816, + 20.363277, + 16.565462, + 8.606422, + 12.116635, + 26.269148, + 9.418326, + 23.361168, + 10.572304, + 16.878731, + 11.465238, + 7.992308, + 19.318066, + 8.522518, + 8.776274, + 8.397976, + 12.157149, + 10.401586, + 11.473348, + 50.941158, + 12.210995, + 89.215263, + 51.855133, + 24.01277, + 25.304604, + 14.883923, + 11.769134, + 7.640766, + 13.312871, + 28.035353, + 17.975294, + 11.286968, + 7.158895, + 16.413784, + 6.992403, + 20.976494, + 10.470797, + 11.711576, + 9.612861, + 21.268475, + 13.660164, + 4.428795, + 14.667919, + 11.658493, + 12.790344, + 26.909279, + 57.743385, + 32.719883, + 16.844536, + 15.799707, + 9.557705, + 12.123102, + 18.276352, + 13.760847, + 6.18409, + 18.14222, + 9.32373, + 15.762177, + 10.868351, + 8.007616, + 4.479449, + 19.759375, + 60.26759, + 9.586048, + 13.494859, + 10.685859, + 11.171473, + 13.379374, + 12.508539, + 24.456158, + 9.929442, + 17.214191, + 35.396305, + 19.028091, + 20.722212, + 19.862995, + 9.629982, + 16.18185, + 13.544695, + 18.490759, + 18.346514, + 14.769532, + 17.279345, + 44.321655, + 17.641129, + 12.118256, + 17.685837, + 18.853605, + 35.741291, + 15.628735, + 37.692272, + 62.550331, + 15.223442, + 21.12615, + 16.560131, + 17.758917, + 14.72166, + 20.553099 + ], + "mean_perplexity": 18.71118 + }, "llama3_405B_f16_torch": { "perplexities": [ 2.170036, @@ -314,5 +418,110 @@ 20.146591 ], "mean_perplexity": 19.786807 + }, + "llama3_8B_f8_iree": { + "perplexities": [ + 8.319567, + 26.466255, + 17.899769, + 26.947678, + 16.996069, + 8.655825, + 12.582002, + 25.489838, + 9.40163, + 23.569002, + 10.249103, + 16.300808, + 11.28682, + 7.989717, + 19.21876, + 8.347802, + 8.758007, + 8.595772, + 12.029103, + 11.022297, + 12.31207, + 45.64912, + 11.912854, + 
89.361259, + 49.282413, + 22.231857, + 22.501842, + 15.007542, + 11.839622, + 7.677801, + 13.295479, + 27.71237, + 18.459564, + 10.973207, + 7.235689, + 16.219749, + 7.276456, + 20.947447, + 10.459918, + 10.827126, + 9.388341, + 22.592453, + 13.162258, + 4.406937, + 14.98263, + 11.668829, + 14.739593, + 26.750174, + 60.646706, + 36.269802, + 17.438318, + 15.362193, + 8.857885, + 12.631617, + 18.473522, + 13.440568, + 6.248583, + 17.176323, + 9.127771, + 15.820963, + 10.712682, + 7.870706, + 4.307317, + 19.970316, + 42.918587, + 9.816761, + 11.956828, + 10.237666, + 11.071741, + 13.513668, + 11.942967, + 24.744184, + 9.270083, + 17.2152, + 31.280437, + 16.308413, + 19.988907, + 20.251242, + 9.714547, + 16.635504, + 13.890906, + 16.608139, + 17.471615, + 15.046515, + 17.39975, + 25.871479, + 19.227293, + 12.394986, + 17.33046, + 18.642508, + 35.221233, + 15.619496, + 40.129433, + 40.794601, + 15.376877, + 20.857782, + 16.339769, + 17.793709, + 14.628002, + 20.941704 + ], + "mean_perplexity": 18.018087 } } diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index 019e4b92d..1a2db68bb 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -51,6 +51,7 @@ def test_llama3_8B_f16(self): f"--tensor-parallelism-size=1", f"--attention-kernel=torch", f"--num-prompts={self.batch_size}", + f"--use-attention-mask", ] ) @@ -69,53 +70,16 @@ def test_llama3_8B_f16(self): ) @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error") - def test_llama3_8B_fp8_decomposed(self): - - # Llama 3.1 8B decomposed - - model_name = "llama3_8B_fp8_iree" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_iree.main( - [ - f"--irpa-file={self.llama3_8b_fp8_model}", - f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - f"--iree-device={self.iree_device}", - f"--iree-hal-target-device={self.iree_hal_target_device}", - f"--iree-hip-target={self.iree_hip_target}", - f"--tensor-parallelism-size=1", - f"--attention-kernel=decomposed", - f"--num-prompts={self.batch_size}", - ] - ) - - baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 - ) - current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - - perplexity_difference = current_mean_perplexity - baseline_mean_perplexity - - self.assertAlmostEqual( - baseline_mean_perplexity, - current_mean_perplexity, - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error") - def test_llama3_8B_fp8(self): + def test_llama3_8B_f8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8_iree" + model_name = "llama3_8B_f8_iree" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_iree.main( [ - f"--irpa-file={self.llama3_8b_fp8_model}", + f"--irpa-file={self.llama3_8b_f8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", f"--iree-device={self.iree_device}", f"--iree-hal-target-device={self.iree_hal_target_device}", @@ -123,6 +87,10 @@ def test_llama3_8B_fp8(self): f"--tensor-parallelism-size=1", f"--attention-kernel=torch", f"--num-prompts={self.batch_size}", + f"--attention-dtype=bfloat16", + f"--activation-dtype=bfloat16", + f"--kv-cache-dtype=float8_e4m3fnuz", + "--use-hf", ] ) @@ -178,52 +146,16 @@ def test_llama3_405B_f16(self): @skipif_run_quick_llama_test 
@pytest.mark.xfail(reason="Compile Error") - def test_llama3_405B_fp8_decomposed(self): - - # Llama 3.1 405B decomposed - - model_name = "llama3_405B_fp8_iree" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_iree.main( - [ - f"--irpa-file={self.llama3_405b_fp8_model}", - f"--tokenizer-config-json={self.llama3_405b_tokenizer}", - f"--iree-device={self.iree_device}", - f"--iree-hal-target-device={self.iree_hal_target_device}", - f"--iree-hip-target={self.iree_hip_target}", - f"--tensor-parallelism-size={self.tensor_parallelism_size}", - f"--attention-kernel=decomposed", - f"--num-prompts={self.batch_size}", - ] - ) - - baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 - ) - current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - - perplexity_difference = current_mean_perplexity - baseline_mean_perplexity - - self.assertAlmostEqual( - baseline_mean_perplexity, - current_mean_perplexity, - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error") - def test_llama3_405B_fp8(self): + def test_llama3_405B_f8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8_iree" + model_name = "llama3_405B_f8_iree" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_iree.main( [ - f"--irpa-file={self.llama3_405b_fp8_model}", + f"--irpa-file={self.llama3_405b_f8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--iree-device={self.iree_device}", f"--iree-hal-target-device={self.iree_hal_target_device}", diff --git a/sharktank/tests/evaluate/perplexity_torch_test.py b/sharktank/tests/evaluate/perplexity_torch_test.py index 3229ff840..bf55912db 100644 --- a/sharktank/tests/evaluate/perplexity_torch_test.py +++ b/sharktank/tests/evaluate/perplexity_torch_test.py @@ -7,6 +7,7 @@ import unittest import pytest import json +import numpy as np from sharktank.evaluate import perplexity_torch @@ -17,7 +18,10 @@ @pytest.mark.usefixtures( - "get_model_artifacts", "tensor_parallelism_size", "baseline_perplexity_scores" + "get_model_artifacts", + "tensor_parallelism_size", + "baseline_perplexity_scores", + "batch_size", ) class PerplexityTest(unittest.TestCase): def setUp(self): @@ -40,78 +44,57 @@ def test_llama3_8B_f16(self): f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", f"--attention-kernel=torch", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @skipif_run_quick_llama_test - def test_llama3_8B_fp8_decomposed(self): - - # Llama 3.1 8B decomposed - - model_name = "llama3_8B_fp8_torch" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_torch.main( - [ - f"--irpa-file={self.llama3_8b_fp8_model}", - f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - ] - ) - - 
perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] - ) + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) @skipif_run_quick_llama_test - def test_llama3_8B_fp8(self): + def test_llama3_8B_f8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8_torch" + model_name = "llama3_8B_f8_torch" baseline_perplexity = self.baseline_perplexity[model_name] + batch_size = 8 + current_perplexity = perplexity_torch.main( [ - f"--irpa-file={self.llama3_8b_fp8_model}", + f"--irpa-file={self.llama3_8b_f8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", f"--attention-kernel=torch", + f"--num-prompts={batch_size}", + f"--attention-dtype=bfloat16", + f"--activation-dtype=bfloat16", + "--use-hf", + "--fake-quant", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0:batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -133,48 +116,20 @@ def test_llama3_405B_f16(self): f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] - ) - - self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @pytest.mark.xfail( - reason="FP8 model is unsupported", - ) - @skipif_run_quick_llama_test - def test_llama3_405B_fp8_decomposed(self): - - # Llama 3.1 405B decomposed - - model_name = "llama3_405B_fp8_torch" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_torch.main( - [ - f"--irpa-file={self.llama3_405b_fp8_model}", - f"--tokenizer-config-json={self.llama3_405b_tokenizer}", - f"--tensor-parallelism-size={self.tensor_parallelism_size}", - ] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] - ) + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) @@ -183,30 +138,33 @@ def test_llama3_405B_fp8_decomposed(self): 
reason="Non-decomposed attention is not supported yet", ) @skipif_run_quick_llama_test - def test_llama3_405B_fp8(self): + def test_llama3_405B_f8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8_torch" + model_name = "llama3_405B_f8_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( [ - f"--irpa-file={self.llama3_405b_fp8_model}", + f"--irpa-file={self.llama3_405b_f8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", f"--attention-kernel=torch", + f"--num-prompts={self.batch_size}", ] ) - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] + baseline_mean_perplexity = round( + np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 ) + current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) + + perplexity_difference = current_mean_perplexity - baseline_mean_perplexity self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], + baseline_mean_perplexity, + current_mean_perplexity, delta=self.delta, msg=f"Current perplexity deviates baseline by {perplexity_difference}", )
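
Usage example: a local run of the newly enabled f8 IREE perplexity test, mirroring the flags the nightly CI job in ci_eval.yaml passes to pytest. This is a sketch; the /shark-dev model and tokenizer paths are the CI machine's and are assumptions for any other environment, and the CI-only options (-n 8, --html, --log-cli-level) are omitted here.

    pytest -v -s sharktank/tests/evaluate/perplexity_iree_test.py \
        --run-nightly-llama-tests --bs=100 \
        --iree-device=hip://0 --iree-hip-target=gfx942 --iree-hal-target-device=hip \
        --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa \
        --llama3-8b-f8-model-path=/shark-dev/8b/fp8/native_fp8_e4m3fnuz_llama3_8b.irpa \
        --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json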