[punet] Add integration tests.
* Imports/validates the FP16 model running eagerly.
* Imports/validates the int8 model running eagerly.
* Exports the models.
stellaraccident committed Jun 29, 2024
1 parent f577d8b commit 471473b
Showing 8 changed files with 414 additions and 13 deletions.
26 changes: 26 additions & 0 deletions sharktank/conftest.py
@@ -0,0 +1,26 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path


# Tests under each top-level directory will get a mark.
TLD_MARKS = {
    "tests": "unit",
    "integration": "integration",
}


def pytest_collection_modifyitems(items, config):
    # Add marks to all tests based on their top-level directory component.
    root_path = Path(__file__).resolve().parent
    for item in items:
        item_path = Path(item.path)
        rel_path = item_path.relative_to(root_path)
        tld = rel_path.parts[0]
        mark = TLD_MARKS.get(tld)
        if mark:
            item.add_marker(mark)
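
Note: the marks added here are plain strings, which pytest resolves against the
marker registry declared in `sharktank/pyproject.toml` below. The effect is the
same as decorating every collected test by directory; an illustrative
(hypothetical test name) example of what the hook does for a test under
`integration/`:

```python
# Illustrative only: the hook effectively applies this decorator to every
# test collected under integration/ (test name below is hypothetical).
import pytest


@pytest.mark.integration
def test_punet_import_smoke():
    ...
```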
248 changes: 248 additions & 0 deletions sharktank/integration/models/punet/integration_test.py
@@ -0,0 +1,248 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import pytest

from sharktank.utils import testing


@pytest.fixture
def temp_dir():
    with testing.temporary_directory(__name__) as td:
        yield Path(td)


@pytest.fixture
def punet_goldens():
    from huggingface_hub import hf_hub_download

    REPO_ID = "amd-shark/sharktank-goldens"
    REVISION = "230dad4d85fbcb8759a331dcf1d45f0562875abe"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="punet", filename=filename, revision=REVISION
        )

    return {
        "inputs.safetensors": download("classifier_free_guidance_inputs.safetensors"),
        "outputs.safetensors": download(
            "classifier_free_guidance_fp16_outputs.safetensors"
        ),
        "outputs_int8.safetensors": download(
            "classifier_free_guidance_int8_outputs.safetensors"
        ),
    }
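
For reference, a golden comparison over safetensors files can be sketched as
follows. This is only a plausible sketch of the kind of check
`testing.assert_golden_safetensors` performs, with assumed tolerances; the real
helper may differ:

```python
# Hypothetical sketch; not sharktank's actual assert_golden_safetensors.
import torch
from safetensors.torch import load_file


def assert_safetensors_close(actual_path, golden_path, rtol=1e-3, atol=1e-3):
    actual = load_file(str(actual_path))
    golden = load_file(str(golden_path))
    # Both files must contain exactly the same named tensors.
    assert actual.keys() == golden.keys()
    for name, golden_tensor in golden.items():
        torch.testing.assert_close(actual[name], golden_tensor, rtol=rtol, atol=atol)
```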


################################################################################
# FP16 Dataset
################################################################################


@pytest.fixture
def sdxl_fp16_base_files():
    from huggingface_hub import hf_hub_download

    REPO_ID = "stabilityai/stable-diffusion-xl-base-1.0"
    REVISION = "76d28af79639c28a79fa5c6c6468febd3490a37e"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="unet", filename=filename, revision=REVISION
        )

    return {
        "config.json": download("config.json"),
        "params.safetensors": download("diffusion_pytorch_model.fp16.safetensors"),
    }


@pytest.fixture
def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
    from sharktank.models.punet.tools import import_hf_dataset

    dataset = temp_dir / "dataset.irpa"
    import_hf_dataset.main(
        [
            f"--config-json={sdxl_fp16_base_files['config.json']}",
            f"--params={sdxl_fp16_base_files['params.safetensors']}",
            f"--output-irpa-file={dataset}",
        ]
    )
    yield dataset


################################################################################
# INT8 Dataset
################################################################################


@pytest.fixture
def sdxl_int8_base_files():
    from huggingface_hub import hf_hub_download

    REPO_ID = "amd-shark/sdxl-quant-models"
    REVISION = "82e06d6ea22ac78102a9aded69e8ddfb9fa4ae37"

    def download(filename):
        return hf_hub_download(
            repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
        )

    return {
        "config.json": download("config.json"),
        "params.safetensors": download("params.safetensors"),
        "quant_params.json": download("quant_params.json"),
    }


@pytest.fixture
def sdxl_int8_dataset(sdxl_int8_base_files, temp_dir):
    from sharktank.models.punet.tools import import_brevitas_dataset

    dataset = temp_dir / "dataset.irpa"
    import_brevitas_dataset.main(
        [
            f"--config-json={sdxl_int8_base_files['config.json']}",
            f"--params={sdxl_int8_base_files['params.safetensors']}",
            f"--quant-params={sdxl_int8_base_files['quant_params.json']}",
            f"--output-irpa-file={dataset}",
        ]
    )
    yield dataset


################################################################################
# Export fixtures
################################################################################


@pytest.fixture
def sdxl_fp16_export_mlir(sdxl_fp16_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    output_path = temp_dir / "sdxl_fp16_export_mlir.mlir"
    print(f"Exporting to {output_path}")
    run_punet.main(
        [
            f"--irpa-file={sdxl_fp16_dataset}",
            "--dtype=float16",
            "--device=cpu",
            f"--export={output_path}",
        ]
    )
    return output_path


@pytest.mark.model_punet
@pytest.mark.export
def test_sdxl_export_fp16_mlir(sdxl_fp16_export_mlir):
    print(f"Exported: {sdxl_fp16_export_mlir}")


@pytest.fixture
def sdxl_int8_export_mlir(sdxl_int8_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    output_path = temp_dir / "sdxl_int8_export_mlir.mlir"
    print(f"Exporting to {output_path}")
    run_punet.main(
        [
            f"--irpa-file={sdxl_int8_dataset}",
            "--dtype=float16",
            "--device=cpu",
            f"--export={output_path}",
        ]
    )
    return output_path


@pytest.mark.model_punet
@pytest.mark.export
def test_sdxl_export_int8_mlir(sdxl_int8_export_mlir):
    print(f"Exported: {sdxl_int8_export_mlir}")


################################################################################
# Eager tests
################################################################################


@pytest.mark.model_punet
@pytest.mark.golden
def test_punet_eager_fp16_validation(punet_goldens, sdxl_fp16_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    run_punet.main(
        [
            f"--irpa-file={sdxl_fp16_dataset}",
            "--dtype=float16",
            f"--device={device}",
            f"--inputs={punet_goldens['inputs.safetensors']}",
            f"--outputs={output_path}",
        ]
    )
    testing.assert_golden_safetensors(output_path, punet_goldens["outputs.safetensors"])


# Executes eagerly using custom integer kernels.
@pytest.mark.model_punet
@pytest.mark.expensive
@pytest.mark.golden
@pytest.mark.skip("Not yet working")
def test_punet_eager_int8_validation(punet_goldens, sdxl_int8_dataset, temp_dir):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    run_punet.main(
        [
            f"--irpa-file={sdxl_int8_dataset}",
            "--dtype=float16",
            f"--device={device}",
            f"--inputs={punet_goldens['inputs.safetensors']}",
            f"--outputs={output_path}",
        ]
    )
    testing.assert_golden_safetensors(output_path, punet_goldens["outputs.safetensors"])


# Executes using emulated fp kernels for key integer operations.
# Useful for speed/comparison.
@pytest.mark.model_punet
@pytest.mark.golden
def test_punet_eager_int8_emulated_validation(
    punet_goldens, sdxl_int8_dataset, temp_dir
):
    from sharktank.models.punet.tools import run_punet

    device = testing.get_best_torch_device()
    output_path = temp_dir / "actual_outputs.safetensors"
    print("Using torch device:", device)
    with testing.override_debug_flags(
        {
            "use_custom_int_conv_kernel": False,
            "use_custom_int_mm_kernel": False,
        }
    ):
        run_punet.main(
            [
                f"--irpa-file={sdxl_int8_dataset}",
                "--dtype=float16",
                f"--device={device}",
                f"--inputs={punet_goldens['inputs.safetensors']}",
                f"--outputs={output_path}",
            ]
        )
    testing.assert_golden_safetensors(
        output_path, punet_goldens["outputs_int8.safetensors"]
    )
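
The emulated test above swaps the custom integer kernels for floating point
emulation via `testing.override_debug_flags`. A minimal sketch of how such a
flag-override context manager can be implemented (the actual sharktank helper
may differ):

```python
# Hypothetical sketch of a flag-override helper; not sharktank's implementation.
import contextlib

_FLAGS = {"use_custom_int_conv_kernel": True, "use_custom_int_mm_kernel": True}


@contextlib.contextmanager
def override_debug_flags(overrides):
    saved = {name: _FLAGS[name] for name in overrides}
    _FLAGS.update(overrides)
    try:
        yield
    finally:
        # Restore prior values even if the body raises.
        _FLAGS.update(saved)
```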
15 changes: 14 additions & 1 deletion sharktank/pyproject.toml
@@ -3,9 +3,22 @@ requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
addopts = "-ra"
addopts = [
"-ra",
"--import-mode=importlib",
"-m=unit",
]
markers = [
    "expensive: tests that are very expensive",
    "export: tests that require export from torch",
    "golden: tests that compare to some golden values",
    "integration: marks tests as integration, requiring access to network/models",
    "model_punet: tests specific to the punet model",
    "unit: unit tests requiring no out of repo resources",
]
testpaths = [
    "tests",
    "integration",
]
pythonpath = [
    "."
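
Because `addopts` pins `-m=unit`, plain `pytest` runs only unit tests; the
other markers compose with boolean expressions for selecting subsets.
Illustrative programmatic invocations (the equivalent `pytest -m "..."` CLI
behaves the same way):

```python
import pytest

# Opt in to the integration suite (overrides the "-m=unit" default).
pytest.main(["-m", "integration"])

# Run punet tests but skip the ones marked expensive.
pytest.main(["-m", "model_punet and not expensive"])
```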
16 changes: 16 additions & 0 deletions sharktank/sharktank/models/punet/README.md
@@ -40,6 +40,22 @@ python -m sharktank.models.punet.tools.run_diffuser_ref
python -m sharktank.models.punet.tools.run_punet --irpa-file ~/models/punet_fp16.irpa
```

## Integration Testing

Integration testing is set up via pytest:

```
pytest -v -m model_punet
```

These tests are expensive and download live model data of considerable size.
It is often helpful to run specific tests with the `-s` option (stream output)
and with `SHARKTANK_TEST_ASSETS_DIR` set to an explicit temp directory; in
this mode, the temp directory is not cleared, allowing you to inspect assets
and intermediates (but delete it manually, as every run accumulates more
files). Filtering by test name with `-k test_some_name` is also useful; test
names have been chosen to facilitate this.

## License

Significant portions of this implementation were derived from diffusers,
sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py
@@ -24,6 +24,7 @@
import json
from pathlib import Path
import safetensors
+import sys
import torch

from ....types import *
@@ -138,7 +139,6 @@ def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:
    # Spot check that things look sane.
    weight_dequant = weight_quant.unpack().dequant()
    weight_diff = weight.as_torch() - weight_dequant
-    print("WEIGHT_DIFF:", weight_diff)

    # Bias/output scaling.
    bias = layer_theta.optional_tensor("bias")
@@ -157,7 +157,6 @@ def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:
    # Spot check that things look sane.
    bias_dequant = bias_quant.unpack().dequant()
    bias_diff = bias.as_torch() - bias_dequant
-    print("BIAS_DIFF:", bias_diff)

    # Input scaling.
    # Assume per tensor scaling of input.
@@ -171,7 +170,7 @@ def _get_json_tensor(name: str, dtype: torch.dtype) -> Optional[torch.Tensor]:
    updated_tensors[input_quantizer.name] = input_quantizer


-def main():
+def main(argv):
    from ....utils import cli

    parser = cli.create_parser()
@@ -196,7 +195,7 @@ def main():
        type=Path,
        help="Base parameters to initialize from (will be augmented with quantized)",
    )
-    args = cli.parse(parser)
+    args = cli.parse(parser, args=argv)

    config_json_path: Path = args.config_json
    params_path: Path = args.params
@@ -241,4 +240,4 @@


if __name__ == "__main__":
-    main()
+    main(sys.argv[1:])
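
Threading `argv` through `main` is what lets the integration tests above
invoke this tool in-process (`import_brevitas_dataset.main([...])`) while
keeping CLI behavior intact. The general pattern, sketched as a standalone
illustration with plain argparse and one of the flags used above:

```python
# Generic sketch of the argv-threading pattern (standalone illustration).
import argparse
import sys


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-irpa-file", type=str)
    # parse_args(None) falls back to sys.argv[1:], so the CLI still works.
    args = parser.parse_args(argv)
    print(args.output_irpa_file)


if __name__ == "__main__":
    main(sys.argv[1:])
```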
