It appears the focus of the community has largely shifted to Flux.1-dev. So the main purpose of this PR is to demonstrate the capability of Candle and to serve as a smoke test for the MMDiT implementation (#2397).
As such, I intend to minimize intrusive changes to the existing stable-diffusion codebase, for example by using a renaming function to adapt the VAE var-builder to the official safetensors weights of the SD3 VAE. Still, there are some changes I have to make to `candle_nn::stable_diffusion` to support the CLIP and VAE of SD3 (minimal sketches of each follow the list below), including:

- Adding `forward_until_encoder_layer` to `ClipTextTransformer`. The Comfy implementation for SD3 uses the penultimate hidden layer of CLIP-l and CLIP-g instead of the final layer (see sd3_clip.py and sdxl_clip.py). This practice, although not mentioned in the SD3 tech report, is specified in Chapter 2.1 of the SDXL tech report.
- Adding `use_quant_conv` and `use_post_quant_conv` options to the `AutoEncoderKL`, as SD3's VAE does not have those layers. These changes might be considered not specific to SD3, as `diffusers` already supports these options.
- Adding `get_qkv_linear` to load the attention block in `candle_nn::stable_diffusion::attention`, as some linear-layer weights of the VAE in the official SD3 Medium safetensors follow the dimension convention of `(channel, channel, 1, 1)` instead of the regular `(channel, channel)` that is naturally supported by the `nn::linear` constructor.

These changes allow reusing the existing CLIP and VAE implementations, but inevitably add complexity to the existing codebase. @LaurentMazare Let me know if these intrusive changes are justified. We may consider alternatives like re-implementing the VAE and CLIP from scratch.
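First, the renaming approach for the VAE var-builder. This is a minimal sketch assuming `candle_nn`'s `VarBuilder::rename_f` hook; the concrete replacement rules below are illustrative placeholders, not the actual SD3 key mapping:

```rust
use candle_nn::VarBuilder;

// Hypothetical rename rule mapping the diffusers-style keys the existing
// AutoEncoderKL expects onto the key layout of the official SD3 safetensors.
// The string replacements are placeholders for illustration only.
fn sd3_vae_vb_rename(name: &str) -> String {
    name.replace("mid_block", "mid")
        .replace("down_blocks", "down")
}

fn adapt_vae_vb(vb: VarBuilder) -> VarBuilder {
    // Wrap the var-builder so every weight lookup is renamed before it
    // reaches the safetensors backend; the VAE code itself is untouched.
    vb.rename_f(sd3_vae_vb_rename)
}
```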
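Next, a rough sketch of the idea behind `forward_until_encoder_layer`, with simplified stand-in types rather than the actual `ClipTextTransformer` internals:

```rust
use candle::{Result, Tensor};

// Simplified stand-in for a CLIP encoder block (attention + MLP elided).
struct EncoderLayer;
impl EncoderLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone()) // placeholder for the real block
    }
}

struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    // Run the encoder but also capture the hidden state after `until_layer`
    // blocks. With `until_layer = self.layers.len() - 1` this yields the
    // penultimate hidden layer that SD3 takes from CLIP-l and CLIP-g.
    fn forward_until(&self, xs: &Tensor, until_layer: usize) -> Result<(Tensor, Option<Tensor>)> {
        let mut xs = xs.clone();
        let mut intermediate = None;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = layer.forward(&xs)?;
            if i + 1 == until_layer {
                intermediate = Some(xs.clone());
            }
        }
        Ok((xs, intermediate))
    }
}
```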
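The VAE option change amounts to making the (post-)quant convolutions optional. A sketch of the pattern, with field and config names that are illustrative rather than the exact `AutoEncoderKL` layout:

```rust
use candle::{Result, Tensor};
use candle_nn::{conv2d, Conv2d, Conv2dConfig, Module, VarBuilder};

struct VaeConfig {
    use_quant_conv: bool, // false for SD3, whose VAE lacks this layer
    latent_channels: usize,
}

struct Vae {
    quant_conv: Option<Conv2d>,
}

impl Vae {
    fn new(cfg: &VaeConfig, vb: VarBuilder) -> Result<Self> {
        let quant_conv = if cfg.use_quant_conv {
            let c = 2 * cfg.latent_channels;
            Some(conv2d(c, c, 1, Conv2dConfig::default(), vb.pp("quant_conv"))?)
        } else {
            None // skip the weight lookup entirely for SD3
        };
        Ok(Self { quant_conv })
    }

    fn quantize(&self, xs: &Tensor) -> Result<Tensor> {
        match &self.quant_conv {
            Some(conv) => conv.forward(xs),
            None => Ok(xs.clone()), // identity when the layer is absent
        }
    }
}
```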
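Finally, the weight-layout workaround is essentially to fall back to the 4-D layout and flatten it before constructing the layer. The helper name follows this PR; the fallback logic is my simplified assumption of how it can work:

```rust
use candle::Result;
use candle_nn::{Linear, VarBuilder};

// Load a linear layer whose weight may be stored either as a regular
// (channel, channel) matrix or in the conv-style (channel, channel, 1, 1)
// layout found for some VAE attention weights in the SD3 safetensors.
fn get_qkv_linear(channel: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = match vb.get((channel, channel), "weight") {
        Ok(w) => w,
        // Shape mismatch: retry with the 4-D layout and flatten it to 2-D.
        Err(_) => vb
            .get((channel, channel, 1, 1), "weight")?
            .reshape((channel, channel))?,
    };
    let bias = vb.get(channel, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}
```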
On top of these changes, I added flash-attention support for the MMDiT, gated on whether the `flash-attn` feature is enabled (see the sketch below). I also ran a simple performance benchmark on GPUs like the 3090 Ti and 4090.

A side note: the T5 implementation on the current main branch does not support FP16 yet. I attempted to insert simple clamping within the FP16 dynamic range, but it didn't work well on my GPUs. It looks like I need to wait for a more sophisticated implementation such as #2481. For now, I use two different VarBuilders: one maps the safetensors weights to FP32 specifically for T5, and the other handles the remaining components (also sketched below).
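The flash-attention dispatch follows the compile-time gating pattern used elsewhere in candle; the naive fallback below is a simplified sketch that elides head reshaping and masking:

```rust
use candle::{Result, Tensor};

// With the `flash-attn` feature, dispatch to the fused CUDA kernel.
#[cfg(feature = "flash-attn")]
fn attn(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // candle_flash_attn expects (batch, seq_len, num_heads, head_dim);
    // `false` disables the causal mask, as MMDiT attention is bidirectional.
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, false)
}

// Without the feature, fall back to a naive softmax(QK^T * scale) V.
#[cfg(not(feature = "flash-attn"))]
fn attn(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
    candle_nn::ops::softmax_last_dim(&att)?.matmul(v)
}
```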
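And the dtype workaround looks roughly like the following; the file name and weight prefixes are placeholders, not the exact ones used in the example:

```rust
use candle::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    // Placeholder file name for the official SD3 Medium checkpoint.
    let files = ["sd3_medium.safetensors"];

    // T5 weights are upcast to F32 at load time for numerical stability.
    let vb_t5 = unsafe { VarBuilder::from_mmaped_safetensors(&files, DType::F32, &device)? };
    // The remaining components (CLIP, MMDiT, VAE) stay in F16.
    let vb_f16 = unsafe { VarBuilder::from_mmaped_safetensors(&files, DType::F16, &device)? };

    // Hypothetical prefixes; each component is built from its own builder:
    // let t5 = T5EncoderModel::new(..., vb_t5.pp("text_encoders.t5xxl"))?;
    // let mmdit = MMDiT::new(..., vb_f16.pp("model.diffusion_model"))?;
    let _ = (vb_t5, vb_f16);
    Ok(())
}
```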