-
Notifications
You must be signed in to change notification settings - Fork 400
Open
Description
Repro:
pip install torchao
pip install diffusers
- With model.compile():
# Repro variant: compile the transformer via ``.compile()``.
from diffusers import (
    PipelineQuantizationConfig,
    DiffusionPipeline,
    TorchAoConfig,
)
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
import torch

# The quantization config is mapped onto the pipeline's "transformer"
# component only; no other component is listed in the mapping.
float8_config = Float8DynamicActivationFloat8WeightConfig(version=2)
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(float8_config)},
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

# Alternative compile entry point, deliberately left disabled in this variant:
# pipe.transformer.compile_repeated_blocks()
pipe.transformer.compile()

# Final step of the repro: run the pipeline for a few inference steps.
_ = pipe("a dog", num_inference_steps=4)
Error: https://gist.github.com/jerryzh168/0fc8b4284b69ab4d9b3e49d378eaf166
- With model.compile_repeated_blocks():
# Repro variant: compile the transformer via ``.compile_repeated_blocks()``.
from diffusers import (
    PipelineQuantizationConfig,
    DiffusionPipeline,
    TorchAoConfig,
)
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
import torch

# The quantization config is mapped onto the pipeline's "transformer"
# component only; no other component is listed in the mapping.
float8_config = Float8DynamicActivationFloat8WeightConfig(version=2)
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(float8_config)},
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

pipe.transformer.compile_repeated_blocks()
# Alternative compile entry point, deliberately left disabled in this variant:
# pipe.transformer.compile()

# Per the traceback in this report, the failure surfaces during this call,
# when the offload hook moves a module back to CPU via ``module.to("cpu")``.
_ = pipe("a dog", num_inference_steps=4)
Error:
Traceback (most recent call last):
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 986, in _apply
torch.utils.swap_tensors(param, param_applied)
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/__init__.py", line 45, in swap_tensors
raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/jerryzh/ao/diffusers_repro.py", line 23, in <module>
_ = pipe("a dog", num_inference_steps=4)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/pipelines/flux/pipeline_flux.py", line 1011, in __call__
image = self.vae.decode(latents, return_dict=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/utils/accelerate_utils.py", line 45, in wrapper
self._hf_hook.pre_forward(self)
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 723, in pre_forward
self.prev_module_hook.offload()
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 746, in offload
self.hook.init_hook(self.model)
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/accelerate/hooks.py", line 714, in init_hook
return module.to("cpu")
^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/diffusers-0.37.0.dev0-py3.12.egg/diffusers/models/modeling_utils.py", line 1432, in to
return super().to(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1383, in to
return self._apply(convert)
^^^^^^^^^^^^^^^^^^^^
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
module._apply(fn)
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
module._apply(fn)
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 933, in _apply
module._apply(fn)
[Previous line repeated 1 more time]
File "/home/jerryzh/.conda/envs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 990, in _apply
raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
Related issue: pytorch/pytorch#141548
Metadata
Metadata
Assignees
Labels
No labels