
diffusers <> torchao: model.compile() with tensor subclasses issue #3632

@jerryzh168

Description

Repro:

pip install torchao
pip install diffusers

1. model.compile()

from diffusers import PipelineQuantizationConfig, DiffusionPipeline, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
import torch


pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(version=2)),
    },
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
# pipe.transformer.compile_repeated_blocks()
pipe.transformer.compile()

_ = pipe("a dog", num_inference_steps=4)

Error: https://gist.github.com/jerryzh168/0fc8b4284b69ab4d9b3e49d378eaf166

2. model.compile_repeated_blocks()

from diffusers import PipelineQuantizationConfig, DiffusionPipeline, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
import torch


pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(version=2)),
    },
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
pipe.transformer.compile_repeated_blocks()
# pipe.transformer.compile()

_ = pipe("a dog", num_inference_steps=4)

Error:

Traceback (most recent call last):
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 986, in _apply
    torch.utils.swap_tensors(param, param_applied)
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/__init__.py", line 45, in swap_tensors
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/jerryzh/ao/diffusers_repro.py", line 23, in <module>
    _ = pipe("a dog", num_inference_steps=4)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/pipelines/flux/pipeline_flux.py", line 1011, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/utils/accelerate_utils.py", line 45, in wrapper
    self._hf_hook.pre_forward(self)
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 723, in pre_forward
    self.prev_module_hook.offload()
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 746, in offload
    self.hook.init_hook(self.model)
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 714, in init_hook
    return module.to("cpu")
           ^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/models/modeling_utils.py", line 1432, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1383, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
    module._apply(fn)
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
    module._apply(fn)
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
    module._apply(fn)
  [Previous line repeated 1 more time]
  File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 990, in _apply
    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight

related issue: pytorch/pytorch#141548
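
For reference, a minimal sketch of the low-level check that fires in the traceback above. The assumption (not verified here) is that torch.compile leaves weakrefs attached to the parameters, so when the cpu-offload hook later calls module.to("cpu"), nn.Module._apply() routes the tensor-subclass weights through torch.utils.swap_tensors(), which refuses to swap a tensor that has a live weakref:

import weakref

import torch
from torch.nn import Parameter

# swap_tensors() rejects tensors that have live weakrefs. The weakref below is
# only a stand-in for whatever torch.compile keeps around (an assumption), not
# the actual guard object.
w = Parameter(torch.randn(4, 4))
_ref = weakref.ref(w)

try:
    torch.utils.swap_tensors(w, Parameter(torch.zeros(4, 4)))
except RuntimeError as e:
    print(e)  # Cannot swap t1 because it has weakref associated with it
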
