Commit 4ec1e65 (parent: 8f3f93d)

* Imports/validates the FP16 model running eagerly.
* Imports/validates the int8 model running eagerly.
* Exports the models.

Showing 8 changed files with 414 additions and 13 deletions.
@@ -0,0 +1,26 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path


# Tests under each top-level directory will get a mark.
TLD_MARKS = {
    "tests": "unit",
    "integration": "integration",
}


def pytest_collection_modifyitems(items, config):
    # Add marks to all tests based on their top-level directory component.
    root_path = Path(__file__).resolve().parent
    for item in items:
        item_path = Path(item.path)
        rel_path = item_path.relative_to(root_path)
        tld = rel_path.parts[0]
        mark = TLD_MARKS.get(tld)
        if mark:
            item.add_marker(mark)
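With these marks applied at collection time, subsets of the suite can be selected via pytest's -m option, e.g. "pytest -m unit" for just the unit tests. A minimal runner sketch (the helper file name is hypothetical; it assumes only that pytest is installed and is run from the repository root):

    # run_unit.py -- hypothetical helper; equivalent to running
    # "pytest -m unit" from the shell.
    import sys
    import pytest

    if __name__ == "__main__":
        # Select only items that the hook above marked "unit".
        sys.exit(pytest.main(["-m", "unit"]))

Custom marks like these ("unit", "integration", and the "model_punet", "export", "golden", and "expensive" marks used below) are typically registered in the pytest configuration to silence unknown-mark warnings.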
@@ -0,0 +1,248 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import pytest

from sharktank.utils import testing


@pytest.fixture
def temp_dir():
    with testing.temporary_directory(__name__) as td:
        yield Path(td)


@pytest.fixture
def punet_goldens():
    from huggingface_hub import hf_hub_download

    REPO_ID = "amd-shark/sharktank-goldens"
    REVISION = "230dad4d85fbcb8759a331dcf1d45f0562875abe"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="punet", filename=filename, revision=REVISION
        )

    return {
        "inputs.safetensors": download("classifier_free_guidance_inputs.safetensors"),
        "outputs.safetensors": download(
            "classifier_free_guidance_fp16_outputs.safetensors"
        ),
        "outputs_int8.safetensors": download(
            "classifier_free_guidance_int8_outputs.safetensors"
        ),
    }


################################################################################
# FP16 Dataset
################################################################################


@pytest.fixture
def sdxl_fp16_base_files():
    from huggingface_hub import hf_hub_download

    REPO_ID = "stabilityai/stable-diffusion-xl-base-1.0"
    REVISION = "76d28af79639c28a79fa5c6c6468febd3490a37e"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="unet", filename=filename, revision=REVISION
        )

    return {
        "config.json": download("config.json"),
        "params.safetensors": download("diffusion_pytorch_model.fp16.safetensors"),
    }


@pytest.fixture
def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
    from sharktank.models.punet.tools import import_hf_dataset

    dataset = temp_dir / "dataset.irpa"
    import_hf_dataset.main(
        [
            f"--config-json={sdxl_fp16_base_files['config.json']}",
            f"--params={sdxl_fp16_base_files['params.safetensors']}",
            f"--output-irpa-file={dataset}",
        ]
    )
    yield dataset


################################################################################
# INT8 Dataset
################################################################################


@pytest.fixture
def sdxl_int8_base_files():
    from huggingface_hub import hf_hub_download

    REPO_ID = "amd-shark/sdxl-quant-models"
    REVISION = "82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
        )

    return {
        "config.json": download("config.json"),
        "params.safetensors": download("params.safetensors"),
        "quant_params.json": download("quant_params.json"),
    }


@pytest.fixture
def sdxl_int8_dataset(sdxl_int8_base_files, temp_dir):
    from sharktank.models.punet.tools import import_brevitas_dataset

    dataset = temp_dir / "dataset.irpa"
    import_brevitas_dataset.main(
        [
            f"--config-json={sdxl_int8_base_files['config.json']}",
            f"--params={sdxl_int8_base_files['params.safetensors']}",
            f"--quant-params={sdxl_int8_base_files['quant_params.json']}",
            f"--output-irpa-file={dataset}",
        ]
    )
    yield dataset


################################################################################
# Export fixtures
################################################################################


@pytest.fixture
def sdxl_fp16_export_mlir(sdxl_fp16_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    output_path = temp_dir / "sdxl_fp16_export_mlir.mlir"
    print(f"Exporting to {output_path}")
    run_punet.main(
        [
            f"--irpa-file={sdxl_fp16_dataset}",
            "--dtype=float16",
            "--device=cpu",
            f"--export={output_path}",
        ]
    )
    return output_path
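

# The exported MLIR files are artifacts for downstream compilation (e.g.
# with IREE); the export tests below only check that export itself succeeds.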
@pytest.mark.model_punet
@pytest.mark.export
def test_sdxl_export_fp16_mlir(sdxl_fp16_export_mlir):
    print(f"Exported: {sdxl_fp16_export_mlir}")


@pytest.fixture
def sdxl_int8_export_mlir(sdxl_int8_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    output_path = temp_dir / "sdxl_int8_export_mlir.mlir"
    print(f"Exporting to {output_path}")
    run_punet.main(
        [
            f"--irpa-file={sdxl_int8_dataset}",
            "--dtype=float16",
            "--device=cpu",
            f"--export={output_path}",
        ]
    )
    return output_path


@pytest.mark.model_punet
@pytest.mark.export
def test_sdxl_export_int8_mlir(sdxl_int8_export_mlir):
    print(f"Exported: {sdxl_int8_export_mlir}")


################################################################################
# Eager tests
################################################################################


@pytest.mark.model_punet
@pytest.mark.golden
def test_punet_eager_fp16_validation(punet_goldens, sdxl_fp16_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
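    # Assumption based on the helper's name: get_best_torch_device prefers an
    # available accelerator and falls back to CPU.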
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    run_punet.main(
        [
            f"--irpa-file={sdxl_fp16_dataset}",
            "--dtype=float16",
            f"--device={device}",
            f"--inputs={punet_goldens['inputs.safetensors']}",
            f"--outputs={output_path}",
        ]
    )
    testing.assert_golden_safetensors(output_path, punet_goldens["outputs.safetensors"])


# Executes eagerly using custom integer kernels.
@pytest.mark.model_punet
@pytest.mark.expensive
@pytest.mark.golden
@pytest.mark.skip("Not yet working")
def test_punet_eager_int8_validation(punet_goldens, sdxl_int8_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    run_punet.main(
        [
            f"--irpa-file={sdxl_int8_dataset}",
            "--dtype=float16",
            f"--device={device}",
            f"--inputs={punet_goldens['inputs.safetensors']}",
            f"--outputs={output_path}",
        ]
    )
    testing.assert_golden_safetensors(output_path, punet_goldens["outputs.safetensors"])


# Executes using emulated fp kernels for key integer operations.
# Useful for speed/comparison.
@pytest.mark.model_punet
@pytest.mark.golden
def test_punet_eager_int8_emulated_validation(
    punet_goldens, sdxl_int8_dataset, temp_dir
):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    with testing.override_debug_flags(
        {
            "use_custom_int_conv_kernel": False,
            "use_custom_int_mm_kernel": False,
        }
    ):
        run_punet.main(
            [
                f"--irpa-file={sdxl_int8_dataset}",
                "--dtype=float16",
                f"--device={device}",
                f"--inputs={punet_goldens['inputs.safetensors']}",
                f"--outputs={output_path}",
            ]
        )
    testing.assert_golden_safetensors(
        output_path, punet_goldens["outputs_int8.safetensors"]
    )
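For reference, the golden comparisons above can be reproduced outside the test suite. This is a rough standalone sketch, not sharktank's implementation: it assumes the safetensors and torch packages are installed, the helper name compare_safetensors is hypothetical, and the tolerances are illustrative.

    # Hypothetical standalone golden check, approximating what
    # testing.assert_golden_safetensors does (an assumption; the real
    # helper's tolerances and reporting may differ).
    import torch
    from safetensors.torch import load_file

    def compare_safetensors(actual_path, expected_path, rtol=1e-3, atol=1e-3):
        actual = load_file(str(actual_path))
        expected = load_file(str(expected_path))
        assert actual.keys() == expected.keys(), "tensor name mismatch"
        for name in expected:
            # Illustrative tolerances, not the project's real thresholds.
            torch.testing.assert_close(
                actual[name], expected[name], rtol=rtol, atol=atol
            )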