
Loading lora weights for FLUX pipeline is extremely slow #2055

Open

nachoal opened this issue Sep 8, 2024 · 11 comments

Comments

nachoal commented Sep 8, 2024

System Info

Installed packages:

accelerate==0.34.0
asttokens==2.4.1
certifi==2024.8.30
charset-normalizer==3.3.2
comm==0.2.2
compel==2.0.3
debugpy==1.8.5
decorator==5.1.1
diffusers==0.30.2
exceptiongroup==1.2.2
executing==2.1.0
filelock==3.15.4
fsspec==2024.6.1
huggingface-hub==0.24.6
idna==3.8
importlib_metadata==8.4.0
ipykernel==6.29.5
ipython==8.27.0
ipywidgets==8.1.5
jedi==0.19.1
Jinja2==3.1.4
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyterlab_widgets==3.0.13
lark==1.2.2
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mpmath==1.3.0
nest-asyncio==1.6.0
networkx==3.3
ninja==1.11.1.1
numpy==2.1.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
optimum-quanto==0.2.4
packaging==24.1
parso==0.8.4
peft==0.12.0
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
prompt_toolkit==3.0.47
protobuf==5.28.0
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
Pygments==2.18.0
pyparsing==3.1.4
python-dateutil==2.9.0.post0
PyYAML==6.0.2
pyzmq==26.2.0
regex==2024.7.24
requests==2.32.3
safetensors==0.4.4
sd-embed==1.240829.1
sentencepiece==0.2.0
six==1.16.0
stack-data==0.6.3
sympy==1.13.2
tokenizers==0.19.1
torch==2.4.0
torchaudio==2.4.0
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.44.2
triton==3.0.0
typing_extensions==4.12.2
urllib3==2.2.2
wcwidth==0.2.13
widgetsnbextension==4.0.13
zipp==3.20.1

Python version: 3.10.12
System: Linux 1936c0a77ae2 5.15.0-102-generic #112-Ubuntu SMP Tue Mar 5 16:50:32 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

nvidia-smi
[screenshot: nvidia-smi output]

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import torch
from diffusers import FluxPipeline
import time
import os
import requests
import zipfile
import tempfile
import shutil

# Variables for easy switching
lora_zip_url = "https://pub-67101706c9e843179b611abc87947bd8.r2.dev/training_results_c2a28eb3-1797-419e-ac89-4ba1de9d229b/training_results.zip"
num_images = 1  # Number of images to generate

def print_benchmark(name, duration):
    print(f"{name}: {duration:.2f} seconds")

# Benchmark: Create pipeline
start_time = time.time()
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True
)
create_pipeline_time = time.time() - start_time
print_benchmark("Create pipeline", create_pipeline_time)

# Benchmark: Move to CUDA
start_time = time.time()
pipe.to("cuda")
move_to_cuda_time = time.time() - start_time
print_benchmark("Move to CUDA", move_to_cuda_time)

# Benchmark: Download, extract, and load LoRA
start_time = time.time()
print("Downloading LoRA zip file...")
response = requests.get(lora_zip_url)
zip_path = tempfile.mktemp(suffix=".zip")
with open(zip_path, 'wb') as f:
    f.write(response.content)
print("LoRA zip file downloaded")

print("Extracting LoRA zip file...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(tempfile.gettempdir())
    # Resolve the top-level folder of the archive while it is still open
    extracted_folder = os.path.join(tempfile.gettempdir(), zip_ref.namelist()[0].split('/')[0])
print("LoRA zip file extracted")
print("Loading LoRA weights...")
pipe.load_lora_weights(extracted_folder, weight_name="flux_train_latentgen.safetensors", adapter_name="main_lora")
print("LoRA weights loaded")

# Clean up
os.remove(zip_path)
shutil.rmtree(extracted_folder)

load_lora_time = time.time() - start_time
print_benchmark("Download, extract, and load LoRA", load_lora_time)

# Benchmark: Generate images
start_time = time.time()
prompt = "sks man"
images = pipe(
    prompt=prompt,
    num_images_per_prompt=num_images,
    height=1024,
    width=768,
).images
generate_images_time = time.time() - start_time
print_benchmark("Generate images", generate_images_time)

# Print total time
total_time = create_pipeline_time + move_to_cuda_time + load_lora_time + generate_images_time
print_benchmark("Total time", total_time)

# Return the final image (or first image if multiple were generated)
final_image = images[0]
print("Image generation complete. Final image returned.")

# Save the image
final_image.save("output.png")
print("Image saved as output.png")

# Unload LoRA weights
pipe.unload_lora_weights()
print("Unloaded LoRA weights")

Expected behavior

Issue

LoRA loading takes more than 200 seconds in the load_lora_weights step for custom LoRAs. The script above includes a working URL for a LoRA trained on a random person which you can use for testing; you will see that load times are considerably long (more than 200 seconds on average in my runs) even with 100% GPU usage while the script is running.

Expected behavior

Lower load times

Reproduction steps

I'm running the script with the following command:

(.env) root@machine# HF_TOKEN=hf_token python test.py 

Then I get the following benchmarks:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.70it/s]
Loading pipeline components...:  86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                     | 6/7 [00:01<00:00,  3.93it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  4.06it/s]
Create pipeline: 1.95 seconds
Move to CUDA: 9.97 seconds
Downloading LoRA zip file...
LoRA zip file downloaded
Extracting LoRA zip file...
LoRA zip file extracted
Loading LoRA weights...
LoRA weights loaded
Download, extract, and load LoRA: 294.35 seconds
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.37it/s]
Generate images: 12.68 seconds
Total time: 318.95 seconds
Image generation complete. Final image returned.
Image saved as output.png
Unloaded LoRA weights

Notice the time it takes to load the LoRA weights: the pipe.load_lora_weights(extracted_folder, weight_name="flux_train_latentgen.safetensors", adapter_name="main_lora") step is taking more than 290 seconds in total. I have tried:

  • pipe.enable_model_cpu_offload() instead of pipe.to("cuda")
  • pipe.to("cuda") as in the current script
  • Different-size LoRAs, but every single time it takes a long time to load

Help

I would appreciate any pointers or optimizations, apart from keeping the LoRA in VRAM. (This is part of a larger process that receives a LoRA URL and processes images based on it, so at any given time I need to be able to load and unload Flux LoRAs as fast as possible.) Thanks!

@BenjaminBossan
Member

Thanks for the detailed description. I cannot reproduce the issue of slow LoRA loading; for me it takes 0.77 sec to load the LoRA weights (excluding download and extraction time).

However, I do have a suspicion about what's going on. Based on your screenshot, your GPU does not have enough memory to hold the transformer part of the Flux model, which is by far the largest component. Thus it's loaded on the CPU, which slows down LoRA loading considerably. Could this be the reason? Please check the device of pipe.transformer.
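
For example, a minimal check (assuming pipe is the FluxPipeline created in your script) could be:

# Quick sanity check: where do the transformer and its parameters actually live?
print(pipe.transformer.device)
print({p.device for p in pipe.transformer.parameters()})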

Just for reference, I'm using 2 GPUs with 24 GB of VRAM and need to load the model like so:

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    device_map="balanced",
    max_memory={0: "24GB", 1: "20GB"},
)

When I check the devices after loading the LoRA adapter, I get:

TEXT_ENCODER
obj.device=device(type='cuda', index=1)
param devices: {device(type='cuda', index=1)}
param dtypes:  {torch.bfloat16}
num params:    123,060,480
TEXT_ENCODER_2
obj.device=device(type='cuda', index=1)
param devices: {device(type='cuda', index=1)}
param dtypes:  {torch.bfloat16}
num params:    4,762,310,656
TRANSFORMER
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    11,977,299,008
VAE
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    83,819,683

(Using this code:

attrs = ["text_encoder", "text_encoder_2", "transformer", "vae"]
for attr in attrs:
    obj = getattr(pipe, attr)
    if not obj:
        continue
    print(attr.upper())
    print(f"{obj.device=}")
    print("param devices:", {p.device for p in obj.parameters()})
    print("param dtypes: ", {p.dtype for p in obj.parameters()})
    print(f"num params:    {sum(p.numel() for p in obj.parameters()):,}")

)

The PR you referenced should still help with loading times on CPU, but unfortunately it's not yet in a state where it can be tested. Inference would still be very slow.

@nachoal
Author

nachoal commented Sep 9, 2024

Hi @BenjaminBossan,

I added your device check before loading the LoRA and after it, and got CUDA, so I don't think the transformer is being loaded on the CPU; note that times are still in the 200+ second range.

I'm curious, did you use my example LoRA URL? Could not having a LoRA parameter file affect the loading part?

I'm using an A100 SXM GPU, which has 80 GB of VRAM, so I don't think the issue is on that side.

Pipeline components after creation:
TEXT_ENCODER
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    123,060,480
TEXT_ENCODER_2
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    4,762,310,656
TRANSFORMER
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    11,901,408,320
VAE
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    83,819,683
Downloading LoRA zip file...
LoRA zip file downloaded
Extracting LoRA zip file...
LoRA zip file extracted
Loading LoRA weights...
LoRA weights loaded
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.35it/s]
Generate images: 13.82 seconds

Pipeline components before unloading LoRA:
TEXT_ENCODER
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    123,060,480
TEXT_ENCODER_2
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    4,762,310,656
TRANSFORMER
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    11,987,326,016
VAE
obj.device=device(type='cuda', index=0)
param devices: {device(type='cuda', index=0)}
param dtypes:  {torch.bfloat16}
num params:    83,819,683

Final Benchmarks:
{'Create pipeline': 47.32027888298035,
 'Download LoRA': 6.121661424636841,
 'Extract LoRA': 2.003812789916992,
 'Generate images': 13.817049741744995,
 'Load LoRA': 274.5764060020447,
 'Move to CUDA': 8.013010263442993,
 'Total time': 351.85221910476685}
Total time: 343.73 seconds
Image generation complete. Final image returned.
Image saved as output.png
Unloaded LoRA weights

@BenjaminBossan
Member

Thanks for the extra info; so my suspicion about the device proved to be incorrect. It is indeed very strange that it takes so long for you, whereas it takes less than 1 sec for me. Could you please take separate timings for the download and extraction (which I skipped in my test) vs the actual loading of the weights? Alternatively, could you just use a local file instead of downloading and deleting it each time? If you do the latter, please run the benchmark at least twice to check if that makes a difference.
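
Something like this would separate the timings (just a sketch; the local folder path below is a placeholder for wherever you keep the already-extracted LoRA):

import time

# Time only the actual weight loading, with the LoRA already on disk
lora_dir = "/path/to/extracted_lora"  # placeholder path
start = time.time()
pipe.load_lora_weights(lora_dir, weight_name="flux_train_latentgen.safetensors", adapter_name="main_lora")
print(f"Load LoRA only: {time.time() - start:.2f} seconds")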


the-dream-machine commented Sep 12, 2024

Just want to chime in here: I'm stumped by a similar slowdown using StableDiffusionXLPipeline. The LoRA (217.89 MB) takes ~3 seconds to load, while the base model (6.5 GB) also takes ~3 seconds to load. I'm loading everything from disk.

System information:

OS = linux (debian-slim:12.5)
GPU = A10G, 24GB VRAM
Python=3.12.5

Here are my package versions:

"diffusers==0.30.2",
"transformers==4.44.2",
"accelerate==0.34.2",
"safetensors==0.4.4",
"torch==2.4.1",
"peft==0.12.0"

Here is my code:

from time import perf_counter

import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

MODEL_URL = "https://civitai.com/api/download/models/290640?type=Model&format=SafeTensor&size=pruned&fp=fp16"
MODEL_FILE_PATH = "/models/checkpoints/pony_diffusion_v6_xl.safetensors"

VAE_URL="https://civitai.com/api/download/models/290640?type=VAE&format=SafeTensor"
VAE_FILE_PATH="/models/vae/pony_diffusion_v6_xl_vae.safetensors"

STYLE_LORA_URL="https://civitai.com/api/download/models/820564?type=Model&format=SafeTensor&token=<YOUR_CIVITAI_TOKEN>"
STYLE_LORA_FILE_PATH="/models/loras/pony/pony_style_lora.safetensors"


# Load VAE
vae_load_start = perf_counter()
vae = AutoencoderKL.from_single_file(VAE_FILE_PATH, torch_dtype=torch.float16)
vae_load_end = perf_counter()
print(f"VAE loaded in {vae_load_end - vae_load_start} seconds")

# Load Model
model_load_start = perf_counter()
self.pipeline = StableDiffusionXLPipeline.from_single_file(
    MODEL_FILE_PATH,
    vae=vae,
    safety_checker=None,
    torch_dtype=torch.float16,
)
model_load_end = perf_counter()
print(f"Model loaded in {model_load_end - model_load_start} seconds")

self.pipeline.to("cuda")

# Set scheduler to Euler Ancestral Discrete Scheduler
self.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipeline.scheduler.config)

# Load LoRAs
lora_load_start = perf_counter()
self.pipeline.load_lora_weights(STYLE_LORA_FILE_PATH, adapter_name="style")
lora_load_end = perf_counter()
print(f"LoRA loaded in {lora_load_end - lora_load_start} seconds")

 # Add LoRAs to the pipeline
self.pipeline.set_adapters(["style"], adapter_weights=[0.8])

# Generate the image
image_generation_start = perf_counter()
full_prompt = f"score_9, score_8_up, score_7_up, score_6_up,source_anime, {prompt}"
image = self.pipeline(
    prompt=full_prompt,    
    negative_prompt="bad quality, score_3, score_2, score_1",
    height=1024,
    width=1024,
    num_inference_steps=25,
    guidance_scale=8.5,
).images[0]
image_generation_end = perf_counter()
print(f"Image generated in {image_generation_end - image_generation_start} seconds")

# Unload LoRAs from the pipeline
unload_lora_start = perf_counter()
self.pipeline.unload_lora_weights()
unload_lora_end = perf_counter()
print(f"LoRA unloaded in {unload_lora_end - unload_lora_start} seconds")

Sample benchmarks:

Model loaded in 3.0992493440000004 seconds
LoRA loaded in 2.928138479000001 seconds
...

These issues seem related. Please let me know if this is the right place to bring this up or if I should open a new issue.

@BenjaminBossan
Member

These issues seem related. Please let me know if this is the right place to bring this up or if I should open a new issue.

Judging by the numbers you get, I don't think the issues are related. But I do think that your issue could be the same as huggingface/diffusers#8953 and will be addressed by the work in #1961.

@BasmaElhoseny01

@nachoal Hello, I have a small question. I fine-tuned FLUX dev on Replicate and got 2 files, config.yaml and lora.safetensors. I want to use this fine-tuned model on Colab with the FluxPipeline offered by Hugging Face, but I don't know how. Can you help?

[screenshot: the training output files]

These are the files I got from Replicate.


cshowley commented Sep 28, 2024

Was there a resolution to this? I'm having similar difficulties to @nachoal: flux-dev loads relatively quickly, but loading any LoRA, regardless of size, takes multiple minutes. This slow loading persists when I'm loading local files as well. I'm also using an 80GB A100.

Example code:

from diffusers import FluxPipeline
import torch

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to('cuda')

pipeline.load_lora_weights(
    "XLabs-AI/flux-lora-collection", weight_name="anime_lora.safetensors"
)      

@BenjaminBossan
Member

Unfortunately, I also can't reproduce this issue. With the adapter downloaded, loading it took 2.5 sec for me on the first run and 0.5 sec on subsequent runs (probably due to disk caching). Some questions:

  1. Are you on the latest versions of diffusers and PEFT?
  2. If you run this script multiple times in a row with the same LoRA adapter, does it still take the same amount of time for each?
  3. How long do you spend loading the Flux pipeline vs the LoRA adapter?


cshowley commented Oct 1, 2024

  1. I'm using a developer release of diffusers: diffusers @ git+https://github.com/huggingface/diffusers@665c6b47a23bc841ad1440c4fe9cbb1782258656
    and peft is version 0.12.0

  2. Yes, each time I load a LoRA it's slow. If you need exact benchmarks I can provide those.

  3. The adapter takes significantly longer than Flux to load. Flux loads in seconds while the LoRA takes minutes.

@BasmaElhoseny01

Unfortunately, I also can't reproduce this issue. With the adapter downloaded, loading it took 2.5 sec for me on the first run and 0.5 sec on subsequent runs (probably due to disk caching). Some questions:

  1. Are you on the latest versions of diffusers and PEFT?
  2. If you run this script multiple times in a row with the same LoRA adapter, does it still take the same amount of time for each?
  3. How long do you spend loading the Flux pipeline vs the LoRA adapter?

@BenjaminBossan how do you load the LoRA weights, please? Do you have any hands-on material?

@BenjaminBossan
Member

@cshowley

  1. I'm using a developer release of diffusers: diffusers @ git+https://github.com/huggingface/diffusers@665c6b47a23bc841ad1440c4fe9cbb1782258656
    and peft is version 0.12.0

I tried using these exact versions but the results still don't change: loading the LoRA takes ~0.5 sec, and the whole loading script finishes in 9 sec.

2. If you need exact benchmarks I can provide those.

Maybe you could check a couple of other LoRAs to see if it's just this one or if all are slow. But I think the next logical step would be to profile the loading step to figure out what exactly is so slow. Do you have some experience with profiling in Python? If I could reproduce this, I'd do it myself, but as it stands I can't.
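
As a starting point, a minimal profiling sketch (assuming the pipeline is already created and reusing the LoRA from your snippet) could look like this:

import cProfile
import pstats

# Profile only the LoRA loading step and print the 30 slowest calls by cumulative time
with cProfile.Profile() as prof:
    pipeline.load_lora_weights(
        "XLabs-AI/flux-lora-collection", weight_name="anime_lora.safetensors"
    )
pstats.Stats(prof).sort_stats("cumulative").print_stats(30)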

@BasmaElhoseny01 I'm not sure what you mean; in this issue you can find code snippets that show how LoRA is loaded. If you need more general info on how to use LoRA in diffusers, check the diffusers docs.
