
Commit

Use correct paths for inputs
Signed-off-by: aviator19941 <avinash.sharma@amd.com>
aviator19941 committed Feb 27, 2025
1 parent 9a4b1c3 commit 7911393
Showing 1 changed file with 17 additions and 19 deletions.
sharktank/tests/models/llama/benchmark_amdgpu_test.py (17 additions, 19 deletions)
@@ -289,10 +289,10 @@ class BenchmarkLlama3_1_70B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        self.artifacts_dir = Path("/shark-dev/70b")
-        self.weights_dir = self.artifacts_dir / "instruct/weights"
-        self.irpa_path = self.weights_dir / "llama3.1_70b_instruct_fp16.irpa"
-        self.irpa_path_fp8 = self.artifacts_dir / "fp8/llama70b_fp8.irpa"
+        self.artifacts_dir = Path("/shark-dev/data/llama3.1/weights/70b")
+        self.artifacts_dir_2048 = Path("/shark-dev/70b")
+        self.irpa_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.irpa"
+        self.irpa_path_fp8 = self.artifacts_dir / "f8/llama70b_fp8.irpa"
         self.tensor_parallelism_size = 8
         self.dir_path_70b = self.dir_path / "llama-70b"
         self.temp_dir_70b = Path(self.dir_path_70b)
@@ -334,10 +334,10 @@ def setUp(self):
             self.artifacts_dir / "decode_args_bs4_128_stride_32"
         )
         self.prefill_args_bs4_2048_stride_32_tp1_f16 = (
-            self.artifacts_dir / "prefill_args_bs4_2048_stride_32"
+            self.artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32"
         )
         self.decode_args_bs4_2048_stride_32_tp1_f16 = (
-            self.artifacts_dir / "decode_args_bs4_2048_stride_32"
+            self.artifacts_dir_2048 / "decode_args_bs4_2048_stride_32"
         )
         self.prefill_args_bs4_128_stride_32_tp8_f16 = (
             self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp8"
@@ -346,10 +346,10 @@ def setUp(self):
             self.artifacts_dir / "decode_args_bs4_128_stride_32_tp8"
         )
         self.prefill_args_bs4_2048_stride_32_tp8_f16 = (
-            self.artifacts_dir / "prefill_args_bs4_2048_stride_32_tp8"
+            self.artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32_tp8"
         )
         self.decode_args_bs4_2048_stride_32_tp8_f16 = (
-            self.artifacts_dir / "decode_args_bs4_2048_stride_32_tp8"
+            self.artifacts_dir_2048 / "decode_args_bs4_2048_stride_32_tp8"
         )
         self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
         self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
@@ -560,8 +560,8 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
             suffix=".vmfb", prefix=output_file_name
         )
         output_shard_file_name = (
-            self.weights_dir
-            / f"tp8/llama3_70b_instruct_fp16_tp{self.tensor_parallelism_size}.irpa"
+            self.artifacts_dir
+            / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
         )
         if output_shard_file_name.exists():
             self.llama70b_f16_torch_sdpa_artifacts_tp8.irpa_path = (
@@ -610,8 +610,8 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
             suffix=".vmfb", prefix=output_file_name
         )
         output_shard_file_name = (
-            self.weights_dir
-            / f"tp8/llama3_70b_instruct_fp16_tp{self.tensor_parallelism_size}.irpa"
+            self.artifacts_dir
+            / f"fp16/tp8/llama3.1_70b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
         )
         if output_shard_file_name.exists():
             self.llama70b_f16_torch_sdpa_artifacts_tp8.irpa_path = (
@@ -694,11 +694,9 @@ class BenchmarkLlama3_1_405B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        self.artifacts_dir = Path("/shark-dev/405b")
-        self.weights_dir = self.artifacts_dir / "instruct/weights"
-        self.irpa_path = Path(
-            "/shark-dev/data/llama3.1/weights/405b/fp16/llama3.1_405b_fp16.irpa"
-        )
+        self.artifacts_dir = Path("/shark-dev/data/llama3.1/weights/405b")
+        self.artifacts_dir_2048 = Path("/shark-dev/405b")
+        self.irpa_path = self.artifacts_dir / "fp16/llama3.1_405b_fp16.irpa"
         self.irpa_path_fp8 = self.artifacts_dir / "f8/llama3.1_405b_fp8.irpa"
         self.tensor_parallelism_size = 8
         self.dir_path_405b = self.dir_path / "llama-405b"
@@ -835,7 +833,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_128(self):
         )
         output_shard_file_name = (
             self.artifacts_dir
-            / f"tp8/llama3_405b_instruct_fp16_tp{self.tensor_parallelism_size}.irpa"
+            / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
         )
         if output_shard_file_name.exists():
             self.llama405b_f16_torch_sdpa_artifacts.irpa_path = output_shard_file_name
@@ -876,7 +874,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
         )
         output_shard_file_name = (
             self.artifacts_dir
-            / f"tp8/llama3_405b_instruct_fp16_tp{self.tensor_parallelism_size}.irpa"
+            / f"fp16/tp8/llama3.1_405b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
         )
         if output_shard_file_name.exists():
             self.llama405b_f16_torch_sdpa_artifacts.irpa_path = output_shard_file_name
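The commit points the fp16 and f8 weight paths at the shared /shark-dev/data/llama3.1/weights tree, while the bs4, 2048-token prefill/decode argument dumps stay under the original per-size directories, and the TP8 tests keep their existing fallback to the unsharded irpa. A minimal sketch of how the updated 70B paths resolve (plain variables stand in for the test's self attributes; the directory contents are assumed to exist on the benchmark machine, and benchmark_irpa is an illustrative name, not part of the test):

    from pathlib import Path

    tensor_parallelism_size = 8

    # Weight locations after this change: fp16/f8 irpa files under the shared
    # data tree, 2048-token prefill/decode inputs under /shark-dev/70b.
    artifacts_dir = Path("/shark-dev/data/llama3.1/weights/70b")
    artifacts_dir_2048 = Path("/shark-dev/70b")

    irpa_path = artifacts_dir / "fp16/llama3.1_70b_f16.irpa"  # unsharded fp16 weights
    prefill_args_2048 = artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32_tp8"

    # Same selection pattern as the tests: use the pre-sharded TP8 parameter
    # file when it exists, otherwise fall back to the unsharded irpa.
    output_shard_file_name = (
        artifacts_dir
        / f"fp16/tp8/llama3.1_70b_fp16_tp{tensor_parallelism_size}_parameters.irpa"
    )
    benchmark_irpa = (
        output_shard_file_name if output_shard_file_name.exists() else irpa_path
    )
    print(benchmark_irpa)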
