Reduce StableDiffusion memory usage #147
More on attention: https://pytorch.org/blog/flash-decoding/

I'd also suggest FlashAttention-2 and Medusa

Alternative to DPM Solver: https://arxiv.org/abs/2311.05556
I tested SD v1-4 on a GPU using the new lower-precision options. Note that the reported memory is just the final memory after each run (first row).

# Stable Diffusion testing
```elixir
Mix.install([
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
{:axon, github: "elixir-nx/axon", override: true},
{:kino, "~> 0.11.3"},
{:bumblebee, github: "elixir-nx/bumblebee"}
])
Application.put_env(:exla, :clients,
host: [platform: :host],
cuda: [platform: :cuda, preallocate: false]
# cuda: [platform: :cuda, memory_fraction: 0.3]
# cuda: [platform: :cuda]
)
Application.put_env(:exla, :preferred_clients, [:cuda, :host])
Nx.global_default_backend({EXLA.Backend, client: :host})
```
## init
```elixir
with {output, 0} <- System.shell("nvidia-smi --query-gpu=memory.total,memory.used --format=csv") do
IO.puts(output)
end
```
<!-- livebook:{"branch_parent_index":0} -->
## Stable Diffusion fp16
```elixir
repository_id = "CompVis/stable-diffusion-v1-4"
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "openai/clip-vit-large-patch14"})
{:ok, clip} =
Bumblebee.load_model({:hf, repository_id, subdir: "text_encoder"},
params_variant: "fp16",
type: :bf16
)
{:ok, unet} =
Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
params_variant: "fp16",
type: :bf16
)
{:ok, vae} =
Bumblebee.load_model({:hf, repository_id, subdir: "vae"},
architecture: :decoder,
params_variant: "fp16",
type: :bf16
)
{:ok, scheduler} = Bumblebee.load_scheduler({:hf, repository_id, subdir: "scheduler"})
clip = update_in(clip.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))
unet = update_in(unet.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))
vae = update_in(vae.params, &Nx.backend_copy(&1, {EXLA.Backend, client: :cuda}))
serving =
Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
num_steps: 20,
num_images_per_prompt: 1,
compile: [batch_size: 1, sequence_length: 60],
defn_options: [compiler: EXLA]
)
Kino.start_child({Nx.Serving, name: SD, serving: serving})
```
```elixir
prompt = "numbat, forest, high quality, detailed, digital art"
output = Nx.Serving.batched_run(SD, prompt)
for result <- output.results do
Kino.Image.new(result.image)
end
|> Kino.Layout.grid(columns: 2)
```
I experimented with different values of

So lazy transfers do help a bit, but imply a significant slowdown. What's interesting though is that
preallocate/jit will transfer the data twice, once as arguments and once as the return value. So we probably need a new callback/abstraction to make this easier :D
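As a point of comparison for the lazy-transfer discussion, here is a minimal sketch of running the serving with params kept on the host and transferred lazily per computation, trading speed for lower peak GPU memory. This assumes the models are loaded as in the fp16 section, but without the `backend_copy` calls that move params to the `:cuda` client; `SDLazy` is a hypothetical serving name.

```elixir
# Sketch: keep params on the host and let EXLA transfer them lazily.
# Assumes clip/unet/vae/tokenizer/scheduler are loaded as in the fp16
# section, WITHOUT copying their params to the :cuda client first.
serving_lazy =
  Bumblebee.Diffusion.StableDiffusion.text_to_image(clip, unet, vae, tokenizer, scheduler,
    num_steps: 20,
    num_images_per_prompt: 1,
    compile: [batch_size: 1, sequence_length: 60],
    defn_options: [compiler: EXLA, lazy_transfers: :always]
  )

Kino.start_child({Nx.Serving, name: SDLazy, serving: serving_lazy})
```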
FTR fixed in #317, now
I think we should update Axon to better support LoRA, I have a draft in place right now but I have to revisit it to make it work as I intend :)
LCM just adapts these nodes in the unet model: https://github.com/wtedw/lorax/blob/main/lib/lorax/lcm.ex#L121-L139 For Bumblebee (if trying to make it compatible with most LoRA files on Hugging Face)

If you guys need any PRs, let me know!
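As a rough illustration of what LoRA support in Axon could look like (this is not the draft mentioned above; the function and layer names are hypothetical), the low-rank update y = Wx + (alpha / r) * B(Ax) can be expressed as plain graph composition:

```elixir
# Hypothetical sketch of a LoRA adapter around a dense layer in Axon:
# the frozen base projection plus a scaled low-rank bypass,
# A: d -> r and B: r -> units.
defmodule LoraSketch do
  def lora_dense(input, units, r, alpha) do
    base = Axon.dense(input, units, name: "base")
    down = Axon.dense(input, r, use_bias: false, name: "lora_a")
    up = Axon.dense(down, units, use_bias: false, name: "lora_b")
    # scale the low-rank branch by alpha / r, then add it to the base output
    scaled = Axon.nx(up, &Nx.multiply(&1, alpha / r))
    Axon.add(base, scaled)
  end
end
```

Loading LoRA files from Hugging Face would then amount to mapping their A/B matrices onto the `"lora_a"`/`"lora_b"` parameters of the adapted nodes.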
Just a heads up that Stability AI just announced Stable Diffusion 3, so that makes us wonder how much effort we should pour into SD vs SDXL vs SD3. It still probably makes sense to support LoRA on Stable Diffusion, because that will require improvements in Axon and elsewhere that we could use for other models, but custom schedulers and token merging are open to debate at the moment.
Checking off attention slicing: it has actually been removed from the diffusers docs (huggingface/diffusers#4487) because of flash attention. Either way, the trick is about slicing a dimension and using a while loop, which is similar to flash attention at the defn level (as opposed to a custom CUDA kernel), and that didn't turn out to be beneficial.
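For reference, the slicing trick can be sketched in plain Nx: compute attention one query-chunk at a time so only a (chunk x kv_len) score matrix is materialized at once. This is illustrative only; it assumes the query length is divisible by `chunk`, and a real defn version would use a while loop over static-size slices.

```elixir
# Sketch of sliced attention over the query axis (plain Nx, not defn).
defmodule SlicedAttentionSketch do
  def attention(q, k, v, chunk) do
    {q_len, _d} = Nx.shape(q)

    0..(div(q_len, chunk) - 1)
    |> Enum.map(fn i ->
      qc = Nx.slice_along_axis(q, i * chunk, chunk, axis: 0)
      # scores for this chunk only: {chunk, kv_len}
      scores = Nx.dot(qc, [1], k, [1])
      # numerically stable softmax along the kv axis
      weights = Nx.exp(scores - Nx.reduce_max(scores, axes: [1], keep_axes: true))
      weights = Nx.divide(weights, Nx.sum(weights, axes: [1], keep_axes: true))
      Nx.dot(weights, v)
    end)
    |> Nx.concatenate(axis: 0)
  end
end
```

Peak memory for the scores drops from q_len * kv_len to chunk * kv_len, at the cost of the loop overhead, which is why it stopped paying off once fused attention landed.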
The main part of Stable Diffusion is the iterative U-Net model pass, which happens for a specified number of timesteps. DeepCache is about reusing some of the intermediate layer outputs across diffusion iterations, that is, outputs expected to change slowly over time. This technique is not going to reduce memory usage, because we still need to periodically do an uncached model pass. Given that we need to keep the cached intermediate results, it can increase the usage if anything. It can give a significant speedup, assuming we do a fair number of steps. For SD Turbo or LCM, where we do 1 or at most a few steps, the caching is not applicable. So this may be something we want to explore in the future, depending on SD3 and other research going forward, but I don't think it's immediately relevant for us now.
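The caching pattern can be sketched in plain Elixir, with `deep_fun` and `shallow_fun` standing in for a hypothetical split of the U-Net into its expensive deep blocks and the cheap remainder:

```elixir
# Sketch of the DeepCache idea: recompute the expensive deep features
# only every `interval` steps, and reuse the cached value otherwise.
defmodule DeepCacheSketch do
  def run(latent, steps, interval, deep_fun, shallow_fun) do
    Enum.reduce(0..(steps - 1), {latent, nil}, fn step, {lat, cache} ->
      # refresh the cache on step 0, interval, 2*interval, ...
      cache = if rem(step, interval) == 0, do: deep_fun.(lat), else: cache
      {shallow_fun.(lat, cache), cache}
    end)
    |> elem(0)
  end
end
```

With 20 steps and an interval of 5, the deep half runs only 4 times, which is where the speedup comes from; the cache itself is extra memory held for the whole loop.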
A list of ideas to explore:

- [x] Attention slicing (no longer applicable, see huggingface/diffusers#4487)
- [ ] Flash attention (JAX version) (see notes in #300)
- [x] DeepCache (not applicable, see #147 (comment))