Commit

Merge branch 'main' into integrates/iree
renxida authored Mar 3, 2025
2 parents 21424ef + 868afc3 commit b709b49
Showing 7 changed files with 67 additions and 63 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_packages.yml
@@ -133,7 +133,7 @@ jobs:
submodules: false

- name: Download version_local.json files
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
+ uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806 # v4.1.9
with:
name: version_local_files
path: ./c/
@@ -148,7 +148,7 @@ jobs:
echo "SHORTFIN_ENABLE_TRACING=OFF" >> $GITHUB_ENV
- name: Setup cache
if: ${{ inputs.build_type == 'dev' }}
- uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: ${{ env.CACHE_DIR }}
key: build-packages-${{ matrix.package }}-${{ matrix.platform }}-${{ matrix.python-version }}-v1-${{ github.sha }}
4 changes: 2 additions & 2 deletions .github/workflows/ci-sglang-benchmark.yml
@@ -213,14 +213,14 @@ jobs:
run: pip install pytest-html-merger

- name: Download shortfin report
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16
+ uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806
with:
name: shortfin_benchmark
path: reports
continue-on-error: true

- name: Download sglang report
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16
+ uses: actions/download-artifact@cc203385981b70ca67e1cc392babf9cc229d5806
with:
name: sglang_benchmark
path: reports
11 changes: 6 additions & 5 deletions .github/workflows/ci-sharktank.yml
@@ -57,7 +57,7 @@ jobs:
python-version: ${{matrix.python-version}}

- name: Cache Pip Packages
- uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -93,15 +93,16 @@ jobs:
strategy:
matrix:
python-version: [3.11]
- runs-on: [llama-mi300x-3]
+ runs-on: [linux-mi300-1gpu-ossci]
fail-fast: false
runs-on: ${{matrix.runs-on}}
defaults:
run:
shell: bash
env:
VENV_DIR: ${{ github.workspace }}/.venv
HF_HOME: "/data/huggingface"
HF_HOME: "/shark-cache/data/huggingface"
HF_TOKEN: ${{ secrets.HF_FLUX_TOKEN }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

@@ -169,7 +170,7 @@ jobs:
python-version: 3.11

- name: Cache Pip Packages
- uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -193,7 +194,7 @@ jobs:
run: |
pytest -v sharktank/ -m punet_quick \
--durations=0 \
- --timeout=600
+ --timeout=900
# Depends on other jobs to provide an aggregate job status.
# TODO(#584): move test_with_data and test_integration to a pkgci integration test workflow?
8 changes: 4 additions & 4 deletions .github/workflows/ci_linux_x64_asan-libshortfin.yml
@@ -50,7 +50,7 @@ jobs:
steps:
- name: Cache Python ASan
id: cache-python-asan
- uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: ${{ env.PYENV_ROOT }}
key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }}
@@ -101,15 +101,15 @@ jobs:
- name: Restore Python dependencies cache
id: cache-python-deps-restore
- uses: actions/cache/restore@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache/restore@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: ${{ env.PYENV_ROOT }}
key: ${{ runner.os }}-python-deps-${{ hashFiles('shortfin/requirements-tests.txt', 'requirements-iree-pinned.txt') }}-v${{ env.CACHE_DEPS_VER }}

- name: Restore Python ASan cache
id: cache-python-asan
if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
- uses: actions/cache/restore@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache/restore@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: ${{ env.PYENV_ROOT }}
key: ${{ runner.os }}-python-asan-${{ env.PYENV_REF }}-${{ env.PYTHON_VER }}-v${{ env.CACHE_ASAN_VER }}
@@ -130,7 +130,7 @@ jobs:
- name: Save Python dependencies cache
if: steps.cache-python-deps-restore.outputs.cache-hit != 'true'
id: cache-python-deps-save
- uses: actions/cache/save@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
+ uses: actions/cache/save@d4323d4df104b026a6aa633fdb11d772146be0bf # v4.2.2
with:
path: ${{ env.PYENV_ROOT }}
key: ${{ steps.cache-python-deps-restore.outputs.cache-primary-key }}
2 changes: 1 addition & 1 deletion .github/workflows/update_iree_requirement_pins.yml
@@ -78,7 +78,7 @@ jobs:
- name: Update IREE requirement pins
run: build_tools/update_iree_requirement_pins.py

- - uses: actions/create-github-app-token@0d564482f06ca65fa9e77e2510873638c82206f2 # v1.11.5
+ - uses: actions/create-github-app-token@21cfef2b496dd8ef5b904c159339626a10ad380e # v1.11.6
if: ${{ env.CREATE_PULL_REQUEST_TOKEN_APP_ID != '' && env.CREATE_PULL_REQUEST_TOKEN_APP_PRIVATE_KEY != '' }}
id: generate-token
with:
4 changes: 2 additions & 2 deletions sharktank/tests/models/llama/quark_parity_test.py
@@ -19,7 +19,7 @@
class QuarkParityTest(TempDirTestBase):
def setUp(self):
super().setUp()
self.path_prefix = Path("/shark-dev/quark_test")
self.path_prefix = Path("/shark-cache/quark_test")

@with_quark_data
def test_compare_against_quark(self):
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_compare_against_quark(self):
"sharktank.examples.paged_llm_v1",
"The capitol of Texas is",
f"--irpa-file={self.path_prefix}/fp8_bf16_weight.irpa",
f"--tokenizer-config-json=/data/llama3.1/8b/tokenizer.json",
f"--tokenizer-config-json=/shark-dev/data/llama3.1/8b/tokenizer.json",
"--fake-quant",
"--attention-kernel=torch",
"--activation-dtype=bfloat16",
97 changes: 50 additions & 47 deletions sharktank/tests/models/vae/vae_test.py
@@ -49,42 +49,43 @@ def setUp(self):
hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
hf_hub_download(
repo_id=hf_model_id,
local_dir="{self._temp_dir}",
local_dir=f"{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/config.json",
)
hf_hub_download(
repo_id=hf_model_id,
local_dir="{self._temp_dir}",
local_dir=f"{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/diffusion_pytorch_model.safetensors",
)
hf_hub_download(
repo_id="amd-shark/sdxl-quant-models",
local_dir="{self._temp_dir}",
local_dir=f"{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/vae.safetensors",
)
torch.manual_seed(12345)
f32_dataset = import_hf_dataset(
"{self._temp_dir}/vae/config.json",
["{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"],
f"{self._temp_dir}/vae/config.json",
[f"{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"],
)
f32_dataset.save("{self._temp_dir}/vae_f32.irpa", io_report_callback=print)
f32_dataset.save(f"{self._temp_dir}/vae_f32.irpa", io_report_callback=print)
f16_dataset = import_hf_dataset(
"{self._temp_dir}/vae/config.json", ["{self._temp_dir}/vae/vae.safetensors"]
f"{self._temp_dir}/vae/config.json",
[f"{self._temp_dir}/vae/vae.safetensors"],
)
f16_dataset.save("{self._temp_dir}/vae_f16.irpa", io_report_callback=print)
f16_dataset.save(f"{self._temp_dir}/vae_f16.irpa", io_report_callback=print)

def testCompareF32EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("{self._temp_dir}", inputs)
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)

ds = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")
ds = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs)
@@ -95,9 +96,9 @@ def testCompareF32EagerVsHuggingface(self):
def testCompareF16EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("{self._temp_dir}", inputs)
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)

ds = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
ds = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs.to(torch.float16))
@@ -107,10 +108,10 @@ def testCompareF16EagerVsHuggingface(self):
def testVaeIreeVsHuggingFace(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("{self._temp_dir}", inputs)
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)

ds_f16 = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
ds_f32 = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")
ds_f16 = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa")
ds_f32 = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa")

model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu")
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -119,8 +120,8 @@ def testVaeIreeVsHuggingFace(self):
module_f16 = export_vae(model_f16, inputs.to(torch.float16), True)
module_f32 = export_vae(model_f32, inputs, True)

module_f16.save_mlir("{self._temp_dir}/vae_f16.mlir")
module_f32.save_mlir("{self._temp_dir}/vae_f32.mlir")
module_f16.save_mlir(f"{self._temp_dir}/vae_f16.mlir")
module_f32.save_mlir(f"{self._temp_dir}/vae_f32.mlir")
extra_args = [
"--iree-hal-target-backends=rocm",
"--iree-hip-target=gfx942",
@@ -137,22 +138,22 @@
]

iree.compiler.compile_file(
"{self._temp_dir}/vae_f16.mlir",
output_file="{self._temp_dir}/vae_f16.vmfb",
f"{self._temp_dir}/vae_f16.mlir",
output_file=f"{self._temp_dir}/vae_f16.vmfb",
extra_args=extra_args,
)
iree.compiler.compile_file(
"{self._temp_dir}/vae_f32.mlir",
output_file="{self._temp_dir}/vae_f32.vmfb",
f"{self._temp_dir}/vae_f32.mlir",
output_file=f"{self._temp_dir}/vae_f32.vmfb",
extra_args=extra_args,
)

iree_devices = get_iree_devices(driver="hip", device_count=1)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="{self._temp_dir}/vae_f16.vmfb",
module_path=f"{self._temp_dir}/vae_f16.vmfb",
devices=iree_devices,
parameters_path="{self._temp_dir}/vae_f16.irpa",
parameters_path=f"{self._temp_dir}/vae_f16.irpa",
)

input_args = OrderedDict([("inputs", inputs.to(torch.float16))])
@@ -178,9 +179,9 @@ def testVaeIreeVsHuggingFace(self):
)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="{self._temp_dir}/vae_f32.vmfb",
module_path=f"{self._temp_dir}/vae_f32.vmfb",
devices=iree_devices,
parameters_path="{self._temp_dir}/vae_f32.irpa",
parameters_path=f"{self._temp_dir}/vae_f32.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
@@ -209,30 +210,32 @@ def setUp(self):
hf_model_id = "black-forest-labs/FLUX.1-dev"
hf_hub_download(
repo_id=hf_model_id,
local_dir="{self._temp_dir}/flux_vae/",
local_dir=f"{self._temp_dir}/flux_vae/",
local_dir_use_symlinks=False,
revision="main",
filename="vae/config.json",
)
hf_hub_download(
repo_id=hf_model_id,
local_dir="{self._temp_dir}/flux_vae/",
local_dir=f"{self._temp_dir}/flux_vae/",
local_dir_use_symlinks=False,
revision="main",
filename="vae/diffusion_pytorch_model.safetensors",
)
torch.manual_seed(12345)
dataset = import_hf_dataset(
"{self._temp_dir}/flux_vae/vae/config.json",
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
f"{self._temp_dir}/flux_vae/vae/config.json",
[f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
)
dataset.save("{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print)
dataset.save(f"{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print)
dataset_f32 = import_hf_dataset(
"{self._temp_dir}/flux_vae/vae/config.json",
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
f"{self._temp_dir}/flux_vae/vae/config.json",
[f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
target_dtype=torch.float32,
)
dataset_f32.save("{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print)
dataset_f32.save(
f"{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print
)

def testCompareBF16EagerVsHuggingface(self):
dtype = torch.bfloat16
Expand All @@ -241,7 +244,7 @@ def testCompareBF16EagerVsHuggingface(self):
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
)

ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs)
@@ -255,7 +258,7 @@ def testCompareF32EagerVsHuggingface(self):
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
)

ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
ds = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype)

results = model.forward(inputs)
@@ -270,8 +273,8 @@ def testVaeIreeVsHuggingFace(self):
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, torch.float32
)

ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
ds_f32 = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")

model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -280,8 +283,8 @@ def testVaeIreeVsHuggingFace(self):
module = export_vae(model, inputs.to(dtype=dtype), True)
module_f32 = export_vae(model_f32, inputs, True)

module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir")
module_f32.save_mlir("{self._temp_dir}/flux_vae_f32.mlir")
module.save_mlir(f"{self._temp_dir}/flux_vae_bf16.mlir")
module_f32.save_mlir(f"{self._temp_dir}/flux_vae_f32.mlir")

extra_args = [
"--iree-hal-target-backends=rocm",
@@ -299,22 +302,22 @@ def testVaeIreeVsHuggingFace(self):
]

iree.compiler.compile_file(
"{self._temp_dir}/flux_vae_bf16.mlir",
output_file="{self._temp_dir}/flux_vae_bf16.vmfb",
f"{self._temp_dir}/flux_vae_bf16.mlir",
output_file=f"{self._temp_dir}/flux_vae_bf16.vmfb",
extra_args=extra_args,
)
iree.compiler.compile_file(
"{self._temp_dir}/flux_vae_f32.mlir",
output_file="{self._temp_dir}/flux_vae_f32.vmfb",
f"{self._temp_dir}/flux_vae_f32.mlir",
output_file=f"{self._temp_dir}/flux_vae_f32.vmfb",
extra_args=extra_args,
)

iree_devices = get_iree_devices(driver="hip", device_count=1)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="{self._temp_dir}/flux_vae_bf16.vmfb",
module_path=f"{self._temp_dir}/flux_vae_bf16.vmfb",
devices=iree_devices,
parameters_path="{self._temp_dir}/flux_vae_bf16.irpa",
parameters_path=f"{self._temp_dir}/flux_vae_bf16.irpa",
)

input_args = OrderedDict([("inputs", inputs.to(dtype=dtype))])
Expand All @@ -339,9 +342,9 @@ def testVaeIreeVsHuggingFace(self):
)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="{self._temp_dir}/flux_vae_f32.vmfb",
module_path=f"{self._temp_dir}/flux_vae_f32.vmfb",
devices=iree_devices,
parameters_path="{self._temp_dir}/flux_vae_f32.irpa",
parameters_path=f"{self._temp_dir}/flux_vae_f32.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
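
Most of the vae_test.py changes above are one repeated bug fix: paths such as "{self._temp_dir}/vae_f32.irpa" were written as plain string literals, so the "{self._temp_dir}" placeholder was never interpolated and the tests read and wrote files under a directory literally named "{self._temp_dir}". Adding the f prefix turns each of them into an f-string that substitutes the attribute. A minimal standalone sketch of the difference (the temp_dir value here is hypothetical, standing in for self._temp_dir):

    temp_dir = "/tmp/pytest-session0"  # hypothetical stand-in for self._temp_dir

    broken = "{temp_dir}/vae_f32.irpa"   # plain literal: braces are kept verbatim
    fixed = f"{temp_dir}/vae_f32.irpa"   # f-string: braces interpolate the variable

    print(broken)  # {temp_dir}/vae_f32.irpa
    print(fixed)   # /tmp/pytest-session0/vae_f32.irpa

The same pattern applies to every hf_hub_download, Dataset.load, save_mlir, and compile_file call touched in this diff.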
