Commit c3e76e1

Adding svdquant required model type conversion script and fix bugs (#222)
* fix little typo in sana-sprint app.
* update README.md
* add convert sana to svdquant quantization required model type scripts;
* fix model path bugs;
* remove unused args;
* FSDP does not support resume from ddp checkpoint;
* pre-commit;
* pre-commit;
* update workflow;
* fix timestep dtype;

Signed-off-by: lawrence-cj <[email protected]>
1 parent b6af6e6 commit c3e76e1

9 files changed, +575 -36 lines changed


.github/workflows/bot-autolint.yaml

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ on:
       - unlabeled
 # run only one unit test for a branch / tag.
 concurrency:
-  group: ci-lint-${{ github.ref }}
+  group: ci-lint-${{ github.head_ref || github.ref }}
   cancel-in-progress: true
 jobs:
   lint-by-label:
@@ -21,7 +21,7 @@ jobs:
       - name: Check out Git repository
         uses: actions/checkout@v4
         with:
-          token: ${{ secrets.PAT }}
+          token: ${{ secrets.GITHUB_TOKEN }}
           ref: ${{ github.event.pull_request.head.ref }}
       - name: Set up Python
         uses: actions/setup-python@v5
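
The concurrency change above relies on the `||` operator in GitHub Actions expressions returning its first truthy operand, and on `github.head_ref` being empty outside pull-request events. A minimal Python sketch of that fallback follows; the function name and the example ref values are hypothetical and chosen only for illustration.

```python
def concurrency_group(head_ref: str, ref: str) -> str:
    """Approximate `ci-lint-${{ github.head_ref || github.ref }}`.

    Python's `or` returns the first truthy operand, which mirrors the
    expression semantics for these string inputs (empty string is falsy).
    """
    return f"ci-lint-{head_ref or ref}"


# Pull-request event: head_ref holds the source branch name (hypothetical values).
pr_group = concurrency_group("my-feature", "refs/pull/222/merge")

# Push event: head_ref is empty, so the full ref is used instead.
push_group = concurrency_group("", "refs/heads/main")

print(pr_group)
print(push_group)
```

Under these assumptions, runs triggered for the same pull-request branch share one group, so `cancel-in-progress: true` cancels the older run.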

README.md

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,8 @@
 <a href="https://nvlabs.github.io/Sana/"><img src="https://img.shields.io/static/v1?label=Project&message=Github&color=blue&logo=github-pages"></a> &ensp;
 <a href="https://hanlab.mit.edu/projects/sana/"><img src="https://img.shields.io/static/v1?label=Page&message=MIT&color=darkred&logo=github-pages"></a> &ensp;
 <a href="https://arxiv.org/abs/2410.10629"><img src="https://img.shields.io/static/v1?label=Arxiv&message=Sana&color=red&logo=arxiv"></a> &ensp;
-<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo:6x3090&message=MIT&color=yellow"></a> &ensp;
+<a href="https://nv-sana.mit.edu/"><img src="https://img.shields.io/static/v1?label=Demo:5x3090&message=SANA&color=yellow"></a> &ensp;
+<a href="https://nv-sana.mit.edu/sprint/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=SANA-Sprint&color=yellow"></a> &ensp;
 <a href="https://nv-sana.mit.edu/4bit/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=4bit&color=yellow"></a> &ensp;
 <a href="https://nv-sana.mit.edu/ctrlnet/"><img src="https://img.shields.io/static/v1?label=Demo:1x3090&message=ControlNet&color=yellow"></a> &ensp;
 <a href="https://replicate.com/chenxwh/sana"><img src="https://img.shields.io/static/v1?label=API:H100&message=Replicate&color=pink"></a> &ensp;
@@ -25,7 +26,7 @@
 
 ### 🚶 Basic:
 
-**Demo**: [SANA-1.5](https://nv-sana.mit.edu/) | [SANA-ControlNet](https://nv-sana.mit.edu/ctrlnet/) | [SANA-4bit](https://nv-sana.mit.edu/4bit/) | [SANA-Sprint (Coming)](<>) <br>
+**Demo**: [SANA-1.5](https://nv-sana.mit.edu/) | [SANA-ControlNet](https://nv-sana.mit.edu/ctrlnet/) | [SANA-4bit](https://nv-sana.mit.edu/4bit/) | [SANA-Sprint](https://nv-sana.mit.edu/sprint/) <br>
 **ComfyUI**: [ComfyUI Guidance](asset/docs/ComfyUI/comfyui.md) <br>
 **Model Zoo:** [Model Card Collects All Models](asset/docs/model_zoo.md) <br>
 **Env Preparation:** [One-Click Env Install](#-1-dependencies-and-installation) <br>

app/app_sana_sprint.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,10 @@ def get_args():
159159
parser.add_argument(
160160
"--model_path",
161161
nargs="?",
162-
default="hf://Efficient-Large-Model/SANA_Sprint_1.6B_1024px/checkpoints/SANA_Sprint_1.6B_1024px.pth",
162+
default="hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth",
163163
type=str,
164164
help="Path to the model file (positional)",
165165
)
166-
parser.add_argument("--output", default="./", type=str)
167-
parser.add_argument("--bs", default=1, type=int)
168-
parser.add_argument("--image_size", default=1024, type=int)
169-
parser.add_argument("--cfg_scale", default=3.0, type=float)
170-
parser.add_argument("--seed", default=42, type=int)
171-
parser.add_argument("--step", default=-1, type=int)
172-
parser.add_argument("--custom_image_size", default=None, type=int)
173166
parser.add_argument("--share", action="store_true")
174167
parser.add_argument(
175168
"--shield_model_path",
@@ -184,7 +177,6 @@ def get_args():
184177
args = get_args()
185178

186179
if torch.cuda.is_available():
187-
weight_dtype = torch.float16
188180
model_path = args.model_path
189181
pipe = SanaSprintPipeline(args.config)
190182
pipe.from_pretrained(model_path)
@@ -300,7 +292,7 @@ def generate(
300292
)
301293

302294

303-
model_size = "1.6" if "1600M" in args.model_path else "0.6"
295+
model_size = "1.6" if "1.6B" in args.model_path else "0.6"
304296
title = f"""
305297
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
306298
<img src="https://nvlabs.github.io/Sana/Sprint/asset/SANA-Sprint.png" width="50%" alt="logo"/>
@@ -309,7 +301,7 @@ def generate(
309301
DESCRIPTION = f"""
310302
<p><span style="font-size: 36px; font-weight: bold;">SANA-Sprint-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
311303
<p style="font-size: 16px; font-weight: bold;">SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation</p>
312-
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2503.09641">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github(coming soon)]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
304+
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2503.09641">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span</p>
313305
<p style="font-size: 16px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
314306
<p style="font-size: 16px; font-weight: bold;">Unsafe word will give you a 'Red Heart' in the image instead.</p>
315307
"""
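
The `model_size` change above is a plain substring test on the checkpoint path. A small sketch of why the marker had to change with the new default path follows; `model_size_label` is a hypothetical helper that mirrors the one-line expression in the diff, not a function in the app.

```python
def model_size_label(model_path: str, marker: str) -> str:
    """Mirror the app's check: "1.6" if the marker occurs in the path, else "0.6"."""
    return "1.6" if marker in model_path else "0.6"


# Default --model_path after this commit (copied from the diff above).
NEW_DEFAULT_PATH = (
    "hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px/"
    "checkpoints/Sana_Sprint_1.6B_1024px.pth"
)

old_label = model_size_label(NEW_DEFAULT_PATH, "1600M")  # old marker never matches
new_label = model_size_label(NEW_DEFAULT_PATH, "1.6B")   # new marker matches

print(old_label, new_label)
```

With the old `"1600M"` marker, the default 1.6B checkpoint path would have been titled as the 0.6B model.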

asset/docs/4bit_sana.md

Lines changed: 16 additions & 0 deletions
@@ -19,6 +19,22 @@
 
 Follow the official [SVDQuant-Nunchaku](https://github.com/mit-han-lab/nunchaku) repository to set up the environment. The guidance can be found [here](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation).
 
+### 1-1. Quantize Sana with SVDQuant-4bit (Optional)
+
+1. Convert pth to SVDQuant required safetensor
+
+```
+python tools/convert_sana_to_svdquant.py \
+    --orig_ckpt_path Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth \
+    --model_type SanaMS1.5_1600M_P1_D20 \
+    --dtype bf16 \
+    --dump_path output/SANA1.5_1.6B_1024px_svdquant_diffusers \
+    --save_full_pipeline
+```
+
+2. follow the guidance to compress model
+[Quantization guidance](https://github.com/mit-han-lab/deepcompressor/tree/main/examples/diffusion)
+
 ### 2. Code snap for inference
 
 Here we show the code snippet for SanaPipeline. For SanaPAGPipeline, please refer to the [SanaPAGPipeline](https://github.com/mit-han-lab/nunchaku/blob/main/examples/sana_1600m_pag.py) section.

configs/sana1-5_config/1024ms/Sana_1600M_1024px_AdamW_fsdp.yaml

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ model:
   image_size: 1024
   mixed_precision: bf16
   fp32_attention: true
-  load_from: hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth
+  # load_from: hf://Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth
+  load_from:
   aspect_ratio_type: ASPECT_RATIO_1024
   multi_scale: true
   attn_type: linear

configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ model:
   mixed_precision: bf16
   fp32_attention: true
   teacher_model: hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px_teacher/checkpoints/Sana_Sprint_1.6B_1024px_teacher.pth
-  load_from: hf://Efficient-Large-Model/SANA_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth
+  load_from: hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth
   resume_from:
   aspect_ratio_type: ASPECT_RATIO_1024
   multi_scale: true

diffusion/model/nets/sana_multi_scale.py

Lines changed: 19 additions & 18 deletions
@@ -22,7 +22,7 @@
 from timm.models.layers import DropPath
 
 from diffusion.model.builder import MODELS
-from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp
+from diffusion.model.nets.basic_modules import DWMlp, GLUMBConv, Mlp
 from diffusion.model.nets.sana import Sana, get_2d_sincos_pos_embed
 from diffusion.model.nets.sana_blocks import (
     Attention,
@@ -279,9 +279,9 @@ def forward(self, x, timestep, y, mask=None, data_info=None, return_logvar=False
         bs = x.shape[0]
         x = x.to(self.dtype)
         if self.timestep_norm_scale_factor != 1.0:
-            timestep = (timestep.float() / self.timestep_norm_scale_factor).to(self.dtype)
+            timestep = (timestep.float() / self.timestep_norm_scale_factor).to(torch.float32)
         else:
-            timestep = timestep.long().to(self.dtype)
+            timestep = timestep.long().to(torch.float32)
         y = y.to(self.dtype)
         self.h, self.w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
         x = self.x_embedder(x)
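
The commit message describes the hunk above only as "fix timestep dtype". One plausible motivation, stated here as an assumption rather than something the commit confirms, is that when `self.dtype` is a 16-bit type such as bfloat16 (the configs in this commit set `mixed_precision: bf16`), neighbouring integer timesteps can no longer be told apart. The sketch below emulates bfloat16 rounding with the standard library only; `to_bfloat16` is a hypothetical helper, not part of the repository.

```python
import struct


def to_bfloat16(value: float) -> float:
    """Round a finite float to the nearest bfloat16 value (round-to-nearest-even).

    bfloat16 keeps the top 16 bits of the float32 bit pattern. NaN and
    overflow handling are intentionally omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Add half of the dropped range, plus one more if the kept LSB is odd,
    # so that exact ties round towards an even mantissa.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Five consecutive integer timesteps near the top of a 0..1000 schedule.
timesteps = [996, 997, 998, 999, 1000]
rounded = [to_bfloat16(float(t)) for t in timesteps]

for t, r in zip(timesteps, rounded):
    print(t, "->", r)
```

In the range 512 to 1024 representable bfloat16 values are 4 apart, so these five timesteps collapse to two values, whereas float32 represents each of them exactly.
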
@@ -322,6 +322,7 @@ def forward(self, x, timestep, y, mask=None, data_info=None, return_logvar=False
             y = self.attention_y_norm(y)
 
         if mask is not None:
+            mask = mask.to(torch.int16)
             mask = mask.repeat(y.shape[0] // mask.shape[0], 1) if mask.shape[0] != y.shape[0] else mask
             mask = mask.squeeze(1).squeeze(1)
             if _xformers_available:
@@ -478,34 +479,34 @@ def SanaMS_1600M_P2_D20(**kwargs):
     return SanaMS(depth=20, hidden_size=2240, patch_size=2, num_heads=20, **kwargs)
 
 
-# TrigFlow/sCM model
 @MODELS.register_module()
-def SanaMSCM_600M_P1_D28(**kwargs):
-    return SanaMSCM(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
+def SanaMS_2400M_P1_D30(**kwargs):
+    return SanaMS(depth=30, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSCM_1600M_P1_D20(**kwargs):
-    return SanaMSCM(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+def SanaMS_3200M_P1_D40(**kwargs):
+    return SanaMS(depth=40, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMSCM_2400M_P1_D30(**kwargs):
-    # 30 layers, 2400M
-    return SanaMSCM(depth=30, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+def SanaMS_4800M_P1_D60(**kwargs):
+    # 60 layers, 4800M
+    return SanaMS(depth=60, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
+# TrigFlow/sCM model
 @MODELS.register_module()
-def SanaMS_2400M_P1_D30(**kwargs):
-    return SanaMS(depth=30, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+def SanaMSCM_600M_P1_D28(**kwargs):
+    return SanaMSCM(depth=28, hidden_size=1152, patch_size=1, num_heads=16, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMS_3200M_P1_D40(**kwargs):
-    return SanaMS(depth=40, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+def SanaMSCM_1600M_P1_D20(**kwargs):
+    return SanaMSCM(depth=20, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
 
 
 @MODELS.register_module()
-def SanaMS_4800M_P1_D60(**kwargs):
-    # 60 layers, 4800M
-    return SanaMS(depth=60, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
+def SanaMSCM_2400M_P1_D30(**kwargs):
+    # 30 layers, 2400M
+    return SanaMSCM(depth=30, hidden_size=2240, patch_size=1, num_heads=20, **kwargs)
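
The last hunk above only reorders the registered factory functions so that the `SanaMS_*` variants come before the `# TrigFlow/sCM model` group. The real `MODELS` object is imported from `diffusion.model.builder` and its implementation is not part of this diff. The sketch below assumes a simple name-keyed registry to show why such a reorder would leave lookups unchanged. `Registry`, `ToyNet`, and the `Toy*` factory names are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ToyNet:
    """Hypothetical stand-in for a model class such as SanaMS."""
    depth: int
    hidden_size: int


class Registry:
    """Minimal name-keyed registry (assumed behaviour, not the real MODELS)."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[..., ToyNet]] = {}

    def register_module(self) -> Callable:
        def decorator(factory: Callable[..., ToyNet]) -> Callable[..., ToyNet]:
            name = factory.__name__
            if name in self._factories:
                raise KeyError(f"{name} is already registered")
            self._factories[name] = factory
            return factory  # the decorated function stays usable directly

        return decorator

    def build(self, name: str, **kwargs) -> ToyNet:
        return self._factories[name](**kwargs)


MODELS = Registry()


@MODELS.register_module()
def ToyMS_P1_D30(**kwargs):
    return ToyNet(depth=30, hidden_size=2240, **kwargs)


@MODELS.register_module()
def ToyMSCM_P1_D20(**kwargs):
    return ToyNet(depth=20, hidden_size=2240, **kwargs)


# Lookup is by function name, so definition order does not affect the result.
print(MODELS.build("ToyMSCM_P1_D20"))
print(MODELS.build("ToyMS_P1_D30"))
```

Under this assumption, a config that names a model type resolves to the same factory before and after the reorder.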

scripts/inference_sana_sprint.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ class SanaInference(SanaConfig):
     ] = "configs/sana_sprint_config/1024ms/SanaSprint_1600M_1024px_allqknorm_bf16_scm_ladd.yaml"  # config
     model_path: Optional[
         str
-    ] = "hf://Efficient-Large-Model/SANA_Sprint_1.6B_1024px/checkpoints/SANA_Sprint_1.6B_1024px.pth"
+    ] = "hf://Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
     work_dir: Optional[str] = None
     txt_file: str = "asset/samples/samples_mini.txt"
     json_file: Optional[str] = None
