From 3273c83fac21ba9d789c1c7f3769b98dca3ec4e0 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:16:40 -0800 Subject: [PATCH] [sharktank] Update perplexity README and enable torch attention-kernel (#1002) - Update perplexity instructions in README - Remove decomposed tests and enable non-decomposed perplexity CIs - Allow passing custom mlir/vmfb to perplexity script --- .github/workflows/ci_eval.yaml | 2 +- docs/model_cookbook.md | 13 +- sharktank/conftest.py | 10 +- sharktank/sharktank/evaluate/README.md | 70 +++++- .../sharktank/evaluate/perplexity_iree.py | 67 ++++-- sharktank/sharktank/utils/cli.py | 4 +- sharktank/sharktank/utils/export_artifacts.py | 49 +++-- .../evaluate/baseline_perplexity_scores.json | 208 +++++++++--------- .../tests/evaluate/perplexity_iree_test.py | 78 +------ .../tests/evaluate/perplexity_torch_test.py | 98 ++------- 10 files changed, 269 insertions(+), 330 deletions(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index cc49e660b..6ccc8c15d 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -121,7 +121,7 @@ jobs: - name: Run perplexity test with Torch run: | source ${VENV_DIR}/bin/activate - pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html + pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --run-nightly-llama-tests --llama3-8b-f16-model-path=/shark-dev/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa --llama3-8b-tokenizer-path=/shark-dev/data/llama3.1/weights/8b/fp16/tokenizer_config.json --html=out/llm/llama/perplexity/torch_perplexity/index.html - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index 03e625b96..1bc541d70 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -256,18 +256,7 @@ iree-run-module \ ## Evaluation pipeline -Run perplexity test: - -```bash -pytest sharktank/tests/evaluate/perplexity_test.py --longrun -``` - -Run perplexity for a new model: -```bash -python -m sharktank.evaluate.perplexity \ - --gguf-file=llama8b_f16.gguf \ - --tokenizer-config-json=tokenizer_config.json -``` +[Instructions](../sharktank/sharktank/evaluate/README.md) to run perplexity test ## Generating data for llama models diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 5aae72d41..dfbc18f08 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -64,20 +64,12 @@ def pytest_addoption(parser): help="Load cached results if present instead of recomputing.", ) - parser.addoption( - "--longrun", - action="store_true", - dest="longrun", - default=False, - help="Enable long tests", - ) - parser.addoption( "--run-quick-llama-test", action="store_true", dest="run-quick-llama-test", default=False, - help="Enable llama 8b f16 decomposed benchmarking test", + help="Run large llama tests if passed", ) parser.addoption( diff --git a/sharktank/sharktank/evaluate/README.md b/sharktank/sharktank/evaluate/README.md index beb0281cd..90ab6f235 100644 --- a/sharktank/sharktank/evaluate/README.md +++ b/sharktank/sharktank/evaluate/README.md @@ -13,28 +13,74 @@ Perplexity score 
measures the ability of a language model to predict the next to In SHARK-Platform, we use perplexity to track code regressions and quality loss across quantized models (with FP16 as baseline). We use 100 prompts randomly selected from the Wikitext-2 test set and calculate the mean perplexities shown below. These numbers are neither comparable between models with different tokenizers nor with other projects due to varying implementations. -* Test perplexity for Llama3.1 8B (FP16) model: +Perplexity script takes a given `--irpa-file` or `--gguf-file`, exports and compiles it in order to calculate the perplexity. There are options to pass a custom `--mlir-path` or `--vmfb-path` too. +#### Run perplexity +For Llama3.1 8B (FP16) model on a MI300 server: +##### Torch mode ```bash -pytest sharktank/tests/evaluate/perplexity_test.py --longrun +pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f16 \ + --llama3-8b-f16-model-path=llama3.1_8b_instruct_fp16.irpa \ + --llama3-8b-tokenizer-path=tokenizer_config.json \ + --bs=4 \ + --run-nightly-llama-tests ``` -* Calculate perplexity for a new model: +##### IREE mode +```bash +pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py -k test_llama3_8B_f16 \ + --llama3-8b-f16-model-path=llama3.1_8b_instruct_fp16.irpa \ + --llama3-8b-tokenizer-path=tokenizer_config.json \ + --bs=4 \ + --iree-device=hip://1 \ + --iree-hip-target=gfx942 \ + --iree-hal-target-device=hip +``` + +For a new model: + +Replace `--irpa-file` with `--gguf-file` flag if necessary (eg: `--gguf-file=llama3_70b_instruct_fp16.gguf`) + +##### Torch mode +```bash +python -m sharktank.evaluate.perplexity_torch \ + --irpa-file=llama3_70b_instruct_fp16.irpa \ + --tokenizer-config-json=tokenizer_config.json \ + --num-prompts=4 +``` + +##### IREE mode + +To run on MI300: +```bash +python -m sharktank.evaluate.perplexity_iree \ + --irpa-file=llama3_70b_instruct_fp16.irpa \ + --tokenizer-config-json=tokenizer_config.json \ + --num-prompts=4 \ + --iree-device='hip://0' \ + --iree-hal-target-device=hip \ + --iree-hip-target=gfx942 +``` + +To run on CPU, replace the above --iree-* flags with: +```bash + --iree-device='local-task' --iree-hal-target-device=llvm-cpu +``` +For additional options: ```bash -python -m sharktank.evaluate.perplexity \ - --gguf-file=llama3_70b_f16.gguf \ - --tokenizer-config-json=tokenizer_config.json +python -m sharktank.evaluate.perplexity_torch -h +python -m sharktank.evaluate.perplexity_iree -h ``` ### Perplexity Scoreboard -| CPU | GPU | -|:-------------: |:----------:| -| AMD EPYC 9554 | MI300X | +| CPU | GPU | Num of prompts | +|:-------------: |:----------:|:----------------:| +| AMD EPYC 9554 | MI300X | 100 | #### LLaMA 3.1 -|Models |Model size (GB) |Torch score |IREE score | -|:----------------------|:---------------|:-------------|:-------------| -|8B FP16 TP1 decomposed |16.07 |14.930181 |14.991893 | +|Models |Torch score |IREE score | Model size (GB) | +|:-------------------------------|:-------------|:-------------|:----------------| +|8B FP16 Instruct TP1 |20.303255 |19.786807 |16.07 | diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py index 05f34b5ff..fe620fd30 100644 --- a/sharktank/sharktank/evaluate/perplexity_iree.py +++ b/sharktank/sharktank/evaluate/perplexity_iree.py @@ -128,26 +128,31 @@ def print_token_comparison(self, i): logger.debug(f"{expected_token_id}") @timeit - def compile_model(self, weight_path_str): + def compile_model(self, 
weight_path_str, mlir_path, json_path, vmfb_path): self.weight_path_str = weight_path_str - logger.info(f" Compiling: {self.weight_path_str}") + logger.info(f" Model: {self.weight_path_str}") - export_artifacts = ExportArtifacts( - irpa_path=self.weight_path_str, - batch_size=self.bs, - iree_hip_target=self.iree_hip_target, - iree_hal_target_device=self.iree_hal_target_device, - attention_kernel=self.attention_kernel, - tensor_parallelism_size=self.tensor_parallelism_size, - block_seq_stride=self.block_seq_stride, - use_attention_mask=self.use_attention_mask, - ) - vmfb_path = export_artifacts.get_artifacts() - return vmfb_path + if vmfb_path: + self.vmfb_path = vmfb_path + logger.info(f" Using pre-compiled vmfb: {self.vmfb_path}") + else: + export_artifacts = ExportArtifacts( + irpa_path=self.weight_path_str, + batch_size=self.bs, + iree_hip_target=self.iree_hip_target, + iree_hal_target_device=self.iree_hal_target_device, + attention_kernel=self.attention_kernel, + tensor_parallelism_size=self.tensor_parallelism_size, + block_seq_stride=self.block_seq_stride, + use_attention_mask=self.use_attention_mask, + mlir_path=mlir_path, + json_path=json_path, + ) + self.vmfb_path = export_artifacts.get_artifacts() @timeit - def load_model(self, weight_path, tokenizer, vmfb_path): + def load_model(self, weight_path, tokenizer): self.config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(weight_path.properties), @@ -175,7 +180,7 @@ def load_model(self, weight_path, tokenizer, vmfb_path): self.runner = vmfbRunner( device=self.iree_device, - vmfb_path=vmfb_path, + vmfb_path=self.vmfb_path, external_weight_path=self.weight_path_str, ) @@ -400,6 +405,9 @@ def run_perplexity( num_prompts, block_seq_stride, use_attention_mask, + mlir_path, + json_path, + vmfb_path, ): start = time.time() perplexity = Perplexity( @@ -415,8 +423,8 @@ def run_perplexity( perplexity.get_prompts(num_prompts=num_prompts) - vmfb_path = perplexity.compile_model(weight_path_str) - perplexity.load_model(weight_path, tokenizer, vmfb_path) + perplexity.compile_model(weight_path_str, mlir_path, json_path, vmfb_path) + perplexity.load_model(weight_path, tokenizer) ppl = perplexity.get_perplexity() end = time.time() @@ -451,6 +459,21 @@ def main(argv): default=100, help="Number of prompts for perplexity test (1 to 100)", ) + parser.add_argument( + "--mlir-path", + type=str, + help="Path to exported mlir file", + ) + parser.add_argument( + "--json-path", + type=str, + help="Path to exported config json file", + ) + parser.add_argument( + "--vmfb-path", + type=str, + help="Path to compiled vmfb file", + ) cli.add_model_options(parser) cli.add_tokenizer_options(parser) @@ -463,6 +486,11 @@ def main(argv): use_attention_mask = True + if args.mlir_path or args.json_path: + assert ( + args.json_path is not None and args.mlir_path is not None + ), "If using pre-exported mlir, both --mlir-path and --json-path must be passed" + # Override flag if dataset disagrees tensor_parallelism_size = ( weight_path.properties["tensor_parallelism_size"] @@ -483,6 +511,9 @@ def main(argv): num_prompts=args.num_prompts, block_seq_stride=args.block_seq_stride, use_attention_mask=use_attention_mask, + mlir_path=args.mlir_path, + json_path=args.json_path, + vmfb_path=args.vmfb_path, ) logger.info(f"\n{json.dumps(ppl, indent=2)}") diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 36ae89cc6..12e27273a 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -84,7 +84,7 @@ def 
add_model_options(parser: argparse.ArgumentParser): ) parser.add_argument( "--skip-decode", - help="Skips export decode", + help="Skips exporting decode", action="store_true", ) parser.add_argument( @@ -99,7 +99,7 @@ def add_model_options(parser: argparse.ArgumentParser): ) parser.add_argument( "--attention-dtype", - help="DType to use for activations in the model", + help="DType to use for attention in the model", default="float16", ) parser.add_argument( diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 608f65c48..2fbdf035e 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -96,11 +96,15 @@ def __init__( activation_dtype: str = "float16", attention_dtype: str = "float16", kv_cache_dtype: Optional[str] = None, + mlir_path: Optional[str] = None, + json_path: Optional[str] = None, ): self.sharktank_dir = str( Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent ) self.irpa_path = irpa_path + self.mlir_path = mlir_path + self.json_path = json_path self.batch_size = batch_size self.iree_hip_target = iree_hip_target self.iree_hal_target_device = iree_hal_target_device @@ -318,6 +322,11 @@ def create_file(self, *, suffix, prefix): def get_artifacts(self): + assert self.attention_kernel in [ + "decomposed", + "torch", + ], "Only torch or decomposed attention_kernel types are supported" + self.dir_path = self.sharktank_dir + "/" + "perplexity_ci_artifacts/" temp_dir = Path(self.dir_path) temp_dir.mkdir(parents=True, exist_ok=True) @@ -327,27 +336,31 @@ def get_artifacts(self): + "_" + self.attention_kernel ) - mlir_path = str( - self.create_file(suffix=".mlir", prefix=self.dir_path + model_name) - ) - json_path = str( - self.create_file(suffix=".json", prefix=self.dir_path + model_name) - ) + + if self.mlir_path is None: + self.mlir_path = str( + self.create_file(suffix=".mlir", prefix=self.dir_path + model_name) + ) + self.json_path = str( + self.create_file(suffix=".json", prefix=self.dir_path + model_name) + ) + + self.export_to_mlir( + mlir_path=self.mlir_path, + json_path=self.json_path, + ) + else: + logger.info(f" Using pre-exported mlir: {self.mlir_path}") + logger.info(f" Using pre-exported config json: {self.json_path}") + vmfb_path = str( self.create_file(suffix=".vmfb", prefix=self.dir_path + model_name) ) - if self.attention_kernel == "decomposed": - returncode = self.export_to_mlir( - mlir_path=mlir_path, - json_path=json_path, - ) - - if returncode == 0: - self.compile_to_vmfb( - mlir_path=mlir_path, - vmfb_path=vmfb_path, - cwd=self.sharktank_dir, - ) + self.compile_to_vmfb( + mlir_path=self.mlir_path, + vmfb_path=vmfb_path, + cwd=self.sharktank_dir, + ) return vmfb_path diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index 89ef8a28c..4686c29fe 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -1,5 +1,5 @@ { - "llama3_8B_f16_decomposed": { + "llama3_8B_f16_torch": { "perplexities": [ 8.497354, 32.688416, @@ -105,7 +105,7 @@ "mean_perplexity": 20.303255 }, - "llama3_405B_f16_decomposed": { + "llama3_405B_f16_torch": { "perplexities": [ 2.170036, 8.014498, @@ -210,109 +210,109 @@ ], "mean_perplexity": 6.060831 }, - "llama3_8B_f16_decomposed_iree": { + "llama3_8B_f16_iree": { "perplexities": [ - 8.440756, - 30.652054, - 17.888412, - 19.536772, - 16.796043, - 8.346771, - 
11.861192, - 26.763889, - 9.173795, - 23.033909, - 11.352187, - 16.573378, - 12.112949, - 8.010364, - 18.730316, - 8.6333, - 9.501037, - 8.262012, - 11.97891, - 10.264008, - 11.920672, - 69.091133, - 12.336814, - 86.867386, - 64.182724, - 23.46203, - 24.490368, - 14.806305, - 11.657981, - 7.551426, - 12.687276, - 27.908455, - 17.644726, - 11.72216, - 7.362862, - 16.815594, - 7.300734, - 21.7402, - 10.908464, - 11.369816, - 9.685975, - 21.589924, - 14.286399, - 4.391925, - 14.578301, - 11.402468, - 13.292189, - 29.865273, - 69.573578, - 32.313053, - 16.852655, - 15.690125, - 9.070885, - 12.365053, - 19.031122, - 13.50634, - 6.177163, - 17.558884, - 11.126381, - 16.31493, - 12.074834, - 8.143363, - 4.383546, - 22.30547, - 67.379478, - 9.771169, - 14.976262, - 13.190949, - 12.115523, - 13.713801, - 12.285597, - 27.620756, - 10.077079, - 16.959599, - 38.700069, - 18.701574, - 21.370266, - 19.843102, - 9.626931, - 16.948587, - 13.99235, - 18.002541, - 19.955822, - 14.684701, - 18.199661, - 54.364548, - 17.486914, - 11.966838, - 20.195158, - 18.739677, - 37.123013, - 16.243725, - 42.555786, - 79.502098, - 15.497804, - 21.469637, - 17.508774, - 18.646059, - 14.222944, - 20.153019 + 8.441894, + 30.664005, + 17.892792, + 19.534321, + 16.799404, + 8.350323, + 11.85428, + 26.780806, + 9.178082, + 23.016304, + 11.346938, + 16.58647, + 12.119105, + 8.009021, + 18.730669, + 8.626894, + 9.514493, + 8.262836, + 11.978479, + 10.257318, + 11.904112, + 69.113541, + 12.32754, + 86.707275, + 64.143013, + 23.460867, + 24.493572, + 14.806771, + 11.655938, + 7.556813, + 12.692316, + 27.887575, + 17.64237, + 11.717109, + 7.367557, + 16.81706, + 7.309159, + 21.744083, + 10.904769, + 11.362426, + 9.688876, + 21.607227, + 14.289747, + 4.390551, + 14.578684, + 11.401663, + 13.330022, + 29.583973, + 69.485657, + 32.276409, + 16.859266, + 15.69192, + 9.069295, + 12.36391, + 19.012669, + 13.517773, + 6.173307, + 17.568123, + 11.119997, + 16.323587, + 12.059111, + 8.142889, + 4.383047, + 22.327551, + 67.429008, + 9.765663, + 14.97523, + 13.186365, + 12.104312, + 13.710983, + 12.288424, + 27.65626, + 10.087663, + 16.962837, + 38.674923, + 18.699448, + 21.353647, + 19.847359, + 9.62309, + 16.944445, + 13.994191, + 18.002008, + 19.970192, + 14.678355, + 18.187948, + 54.508171, + 17.515779, + 11.978756, + 20.18362, + 18.731052, + 37.074276, + 16.22677, + 42.56427, + 79.458481, + 15.499112, + 21.473051, + 17.519047, + 18.651396, + 14.204421, + 20.146591 ], - "mean_perplexity": 19.791108 + "mean_perplexity": 19.786807 } } diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py index f529ac0fa..019e4b92d 100644 --- a/sharktank/tests/evaluate/perplexity_iree_test.py +++ b/sharktank/tests/evaluate/perplexity_iree_test.py @@ -34,42 +34,6 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) - def test_llama3_8B_f16_decomposed(self): - - # Llama 3.1 8B decomposed - - model_name = "llama3_8B_f16_decomposed_iree" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_iree.main( - [ - f"--irpa-file={self.llama3_8b_f16_model}", - f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - f"--iree-device={self.iree_device}", - f"--iree-hal-target-device={self.iree_hal_target_device}", - f"--iree-hip-target={self.iree_hip_target}", - f"--tensor-parallelism-size=1", - f"--attention-kernel=decomposed", - f"--num-prompts={self.batch_size}", - ] - ) - - baseline_mean_perplexity = round( - 
np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 - ) - current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - - perplexity_difference = current_mean_perplexity - baseline_mean_perplexity - - self.assertAlmostEqual( - baseline_mean_perplexity, - current_mean_perplexity, - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @skipif_run_quick_llama_test - @pytest.mark.xfail(reason="Compile Error") def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed @@ -110,7 +74,7 @@ def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_fp8_decomposed_iree" + model_name = "llama3_8B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_iree.main( @@ -176,44 +140,6 @@ def test_llama3_8B_fp8(self): msg=f"Current perplexity deviates baseline by {perplexity_difference}", ) - @skipif_run_quick_llama_test - @pytest.mark.xfail( - reason="Sharding is unsupported", - ) - def test_llama3_405B_f16_decomposed(self): - - # Llama 3.1 405B decomposed - - model_name = "llama3_405B_f16_decomposed_iree" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_iree.main( - [ - f"--irpa-file={self.llama3_405b_f16_model}", - f"--tokenizer-config-json={self.llama3_405b_tokenizer}", - f"--iree-device={self.iree_device}", - f"--iree-hal-target-device={self.iree_hal_target_device}", - f"--iree-hip-target={self.iree_hip_target}", - f"--tensor-parallelism-size={self.tensor_parallelism_size}", - f"--attention-kernel=decomposed", - f"--num-prompts={self.batch_size}", - ] - ) - - baseline_mean_perplexity = round( - np.mean(baseline_perplexity["perplexities"][0 : self.batch_size]), 6 - ) - current_mean_perplexity = round(current_perplexity["mean_perplexity"], 6) - - perplexity_difference = current_mean_perplexity - baseline_mean_perplexity - - self.assertAlmostEqual( - baseline_mean_perplexity, - current_mean_perplexity, - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - @skipif_run_quick_llama_test @pytest.mark.xfail(reason="Compile Error") def test_llama3_405B_f16(self): @@ -256,7 +182,7 @@ def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed - model_name = "llama3_405B_fp8_decomposed_iree" + model_name = "llama3_405B_fp8_iree" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_iree.main( diff --git a/sharktank/tests/evaluate/perplexity_torch_test.py b/sharktank/tests/evaluate/perplexity_torch_test.py index 042132f20..3229ff840 100644 --- a/sharktank/tests/evaluate/perplexity_torch_test.py +++ b/sharktank/tests/evaluate/perplexity_torch_test.py @@ -10,7 +10,10 @@ from sharktank.evaluate import perplexity_torch -longrun = pytest.mark.skipif("not config.getoption('longrun')") +skipif_run_quick_llama_test = pytest.mark.skipif( + 'not config.getoption("run-nightly-llama-tests")', + reason="Run large tests if --run-nightly-llama-tests is passed", +) @pytest.mark.usefixtures( @@ -24,49 +27,19 @@ def setUp(self): with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f) - @longrun - def test_llama3_8B_f16_decomposed(self): - - # Llama 3.1 8B decomposed - - model_name = "llama3_8B_f16_decomposed" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_torch.main( - [ - f"--irpa-file={self.llama3_8b_f16_model}", - 
f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - ] - ) - - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] - ) - - self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @pytest.mark.xfail( - reason="Non-decomposed attention is not supported yet", - ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_f16(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_f16" + model_name = "llama3_8B_f16_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( [ f"--irpa-file={self.llama3_8b_f16_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - f"--attention-kernel=torch_sdpa", + f"--attention-kernel=torch", ] ) @@ -85,12 +58,12 @@ def test_llama3_8B_f16(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_fp8_decomposed(self): # Llama 3.1 8B decomposed - model_name = "llama3_8B_fp8_decomposed" + model_name = "llama3_8B_fp8_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( @@ -115,50 +88,19 @@ def test_llama3_8B_fp8_decomposed(self): @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_8B_fp8(self): # Llama 3.1 8B non-decomposed - model_name = "llama3_8B_fp8" + model_name = "llama3_8B_fp8_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( [ f"--irpa-file={self.llama3_8b_fp8_model}", f"--tokenizer-config-json={self.llama3_8b_tokenizer}", - f"--attention-kernel=torch_sdpa", - ] - ) - - perplexity_difference = ( - current_perplexity["mean_perplexity"] - - baseline_perplexity["mean_perplexity"] - ) - - self.assertAlmostEqual( - baseline_perplexity["mean_perplexity"], - current_perplexity["mean_perplexity"], - delta=self.delta, - msg=f"Current perplexity deviates baseline by {perplexity_difference}", - ) - - @pytest.mark.xfail( - reason="Sharding needs to be fixed", - ) - @longrun - def test_llama3_405B_f16_decomposed(self): - - # Llama 3.1 405B decomposed - - model_name = "llama3_405B_f16_decomposed" - baseline_perplexity = self.baseline_perplexity[model_name] - - current_perplexity = perplexity_torch.main( - [ - f"--irpa-file={self.llama3_405b_f16_model}", - f"--tokenizer-config-json={self.llama3_405b_tokenizer}", - f"--tensor-parallelism-size={self.tensor_parallelism_size}", + f"--attention-kernel=torch", ] ) @@ -177,12 +119,12 @@ def test_llama3_405B_f16_decomposed(self): @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_f16(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_f16" + model_name = "llama3_405B_f16_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( @@ -190,7 +132,7 @@ def test_llama3_405B_f16(self): f"--irpa-file={self.llama3_405b_f16_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", - f"--attention-kernel=torch_sdpa", + f"--attention-kernel=torch", ] ) @@ -209,12 +151,12 @@ def test_llama3_405B_f16(self): @pytest.mark.xfail( reason="FP8 model is unsupported", ) - @longrun + 
@skipif_run_quick_llama_test def test_llama3_405B_fp8_decomposed(self): # Llama 3.1 405B decomposed - model_name = "llama3_405B_fp8_decomposed" + model_name = "llama3_405B_fp8_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( @@ -240,12 +182,12 @@ def test_llama3_405B_fp8_decomposed(self): @pytest.mark.xfail( reason="Non-decomposed attention is not supported yet", ) - @longrun + @skipif_run_quick_llama_test def test_llama3_405B_fp8(self): # Llama 3.1 405B non-decomposed - model_name = "llama3_405B_fp8" + model_name = "llama3_405B_fp8_torch" baseline_perplexity = self.baseline_perplexity[model_name] current_perplexity = perplexity_torch.main( @@ -253,7 +195,7 @@ def test_llama3_405B_fp8(self): f"--irpa-file={self.llama3_405b_fp8_model}", f"--tokenizer-config-json={self.llama3_405b_tokenizer}", f"--tensor-parallelism-size={self.tensor_parallelism_size}", - f"--attention-kernel=torch_sdpa", + f"--attention-kernel=torch", ] )
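
The tests in this patch compare a freshly computed `mean_perplexity` against the values stored in `baseline_perplexity_scores.json`, allowing a small `delta`. As a rough illustration of what that number represents — this is a minimal sketch, not the sharktank implementation, and every name and value below is hypothetical — per-prompt perplexity is the exponential of the average negative log-likelihood of the reference tokens, and the reported score is the mean over prompts:

```python
# Illustrative sketch only (not sharktank code): derive a mean perplexity
# from per-token log-probabilities and check it against a baseline with an
# allowed delta, analogous to the assertAlmostEqual(..., delta=...) checks
# in the tests above. All numbers here are made up.
import math


def sequence_perplexity(token_logprobs: list[float]) -> float:
    """Perplexity of one prompt: exp of the mean negative log-likelihood."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)


def mean_perplexity(per_prompt_logprobs: list[list[float]]) -> float:
    """Mean of per-prompt perplexities, the quantity reported per model."""
    ppls = [sequence_perplexity(lp) for lp in per_prompt_logprobs]
    return sum(ppls) / len(ppls)


if __name__ == "__main__":
    # Hypothetical log-probabilities for two short prompts.
    prompts = [
        [-2.1, -0.7, -1.3, -0.9],
        [-1.8, -2.4, -0.5],
    ]
    current = mean_perplexity(prompts)   # ~4.14 for these toy inputs
    baseline = 4.2                       # hypothetical stored baseline
    delta = 0.5                          # allowed deviation from baseline
    assert abs(current - baseline) <= delta, (
        f"Current perplexity deviates from baseline by {current - baseline}"
    )
    print(f"mean perplexity: {current:.6f}")
```

Lower scores mean the model assigns higher probability to the reference text, which is why a rise in mean perplexity beyond the configured delta is treated as a regression.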