
Commit 32c94fe

lawrence-cj and xieenze authored
update code for better quality (#105)
* change bs for 2K

* 1. code update for better quality;
  2. fix app bug for CFG+PAG inference;
  3. change default inference setting to CFG only;

* update README.md

* change config name

* update README.md

---------

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Enze Xie <[email protected]>
1 parent 374447b commit 32c94fe

File tree: 11 files changed, +104 −215 lines


README.md

Lines changed: 3 additions & 2 deletions

@@ -36,7 +36,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
 
 ## 🔥🔥 News
 
-- (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples on the [Sana page](https://nvlabs.github.io/Sana/).
+- (🔥 New) \[2024/12/20\] 1.6B 2K resolution [Sana models](asset/docs/model_zoo.md) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). 🚀 Get your 2K resolution images within 4 seconds! Find more samples on the [Sana page](https://nvlabs.github.io/Sana/).
 - (🔥 New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
 - (🔥 New) \[2024/12/13\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and the diffusers pipelines `SanaPipeline`, `SanaPAGPipeline`, and `DPMSolverMultistepScheduler` (with FlowMatching) are all supported now. We prepared a [Model Card](asset/docs/model_zoo.md) for you to choose from.
 - (🔥 New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning.

@@ -126,7 +126,8 @@ DEMO_PORT=15432 \
 python app/app_sana.py \
     --share \
     --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
-    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
+    --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
+    --image_size=1024
 ```
 
 ### 1. How to use `SanaPipeline` with `🧨diffusers`

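The README section referenced by the trailing context line above documents `SanaPipeline` usage in `🧨diffusers`. As a cross-check against this commit's new defaults, here is a minimal sketch; the checkpoint id is an assumption based on the BF16 naming in the README, and `guidance_scale=4.5` matches the new CFG-only default introduced in `app/app_sana.py` below:

```python
import torch
from diffusers import SanaPipeline

# Checkpoint id is an assumption based on the README's BF16 naming.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe(
    prompt="a cyberpunk cat with a neon sign that says Sana",
    height=1024,
    width=1024,
    guidance_scale=4.5,  # matches the new CFG-only default in app_sana.py
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("sana_1024.png")
```
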
app/app_sana.py

Lines changed: 2 additions & 2 deletions

@@ -408,14 +408,14 @@ def generate(
             minimum=1,
             maximum=10,
             step=0.1,
-            value=5.0,
+            value=4.5,
         )
         flow_dpms_pag_guidance_scale = gr.Slider(
             label="PAG Guidance scale",
             minimum=1,
             maximum=4,
             step=0.5,
-            value=2.0,
+            value=1.0,
         )
         with gr.Row():
             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)

app/sana_pipeline.py

Lines changed: 14 additions & 37 deletions

@@ -28,8 +28,8 @@
 from diffusion import DPMS, FlowEuler
 from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar, resize_and_crop_tensor
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar, resize_and_crop_tensor
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 
 # from diffusion.utils.misc import read_config

@@ -40,6 +40,8 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     guidance_type = default_guidance_type
     if not (pag_scale > 1.0 and attn_type == "linear"):
         guidance_type = "classifier-free"
+    elif pag_scale > 1.0 and attn_type == "linear":
+        guidance_type = "classifier-free_PAG"
     return guidance_type
 
 

@@ -93,15 +95,9 @@ def __init__(
         self.flow_shift = config.scheduler.flow_shift
         guidance_type = "classifier-free_PAG"
 
-        if config.model.mixed_precision == "fp16":
-            weight_dtype = torch.float16
-        elif config.model.mixed_precision == "bf16":
-            weight_dtype = torch.bfloat16
-        elif config.model.mixed_precision == "fp32":
-            weight_dtype = torch.float32
-        else:
-            raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+        weight_dtype = get_weight_dtype(config.model.mixed_precision)
         self.weight_dtype = weight_dtype
+        self.vae_dtype = get_weight_dtype(config.vae.weight_dtype)
 
         self.base_ratios = eval(f"ASPECT_RATIO_{self.image_size}_TEST")
         self.vis_sampler = self.config.scheduler.vis_sampler

@@ -126,7 +122,7 @@ def __init__(
         ]
 
     def build_vae(self, config):
-        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.weight_dtype)
+        vae = get_vae(config.vae_type, config.vae_pretrained, self.device).to(self.vae_dtype)
         return vae
 
     def build_text_encoder(self, config):

@@ -135,31 +131,12 @@ def build_text_encoder(self, config):
 
     def build_sana_model(self, config):
         # model setting
-        pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-        learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-        model_kwargs = {
-            "input_size": self.latent_size,
-            "pe_interpolation": config.model.pe_interpolation,
-            "config": config,
-            "model_max_length": config.text_encoder.model_max_length,
-            "qk_norm": config.model.qk_norm,
-            "micro_condition": config.model.micro_condition,
-            "caption_channels": self.text_encoder.config.hidden_size,
-            "y_norm": config.text_encoder.y_norm,
-            "attn_type": config.model.attn_type,
-            "ffn_type": config.model.ffn_type,
-            "mlp_ratio": config.model.mlp_ratio,
-            "mlp_acts": list(config.model.mlp_acts),
-            "in_channels": config.vae.vae_latent_dim,
-            "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-            "use_pe": config.model.use_pe,
-            "pred_sigma": pred_sigma,
-            "learn_sigma": learn_sigma,
-            "use_fp32_attention": config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
-        }
-        model = build_model(config.model.model, **model_kwargs)
-        model = model.to(self.weight_dtype)
-
+        model_kwargs = model_init_config(config, latent_size=self.latent_size)
+        model = build_model(
+            config.model.model,
+            use_fp32_attention=config.model.get("fp32_attention", False) and config.model.mixed_precision != "bf16",
+            **model_kwargs,
+        )
         self.logger.info(f"use_fp32_attention: {model.fp32_attention}")
         self.logger.info(
             f"{model.__class__.__name__}:{config.model.model},"

@@ -310,7 +287,7 @@ def forward(
                 flow_shift=self.flow_shift,
             )
 
-        sample = sample.to(self.weight_dtype)
+        sample = sample.to(self.vae_dtype)
         with torch.no_grad():
             sample = vae_decode(self.config.vae.vae_type, self.vae, sample)
 

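The `guidance_type_select` change above is the CFG+PAG app fix named in the commit message: with the new `elif`, PAG guidance is selected explicitly whenever `pag_scale > 1.0` and linear attention is in use, instead of relying on whatever default the caller passed in. A self-contained sketch (function body copied from the hunk; the assertions are illustrative):

```python
def guidance_type_select(default_guidance_type, pag_scale, attn_type):
    # Fall back to plain CFG unless PAG is actually active
    # (pag_scale > 1.0) and the attention backend supports it.
    guidance_type = default_guidance_type
    if not (pag_scale > 1.0 and attn_type == "linear"):
        guidance_type = "classifier-free"
    elif pag_scale > 1.0 and attn_type == "linear":
        guidance_type = "classifier-free_PAG"
    return guidance_type

# PAG inactive -> plain CFG, regardless of the default:
assert guidance_type_select("classifier-free_PAG", pag_scale=1.0, attn_type="linear") == "classifier-free"
# PAG active on linear attention -> CFG+PAG:
assert guidance_type_select("classifier-free_PAG", pag_scale=2.0, attn_type="linear") == "classifier-free_PAG"
# PAG requested but unsupported backend -> plain CFG:
assert guidance_type_select("classifier-free_PAG", pag_scale=2.0, attn_type="flash") == "classifier-free"
```
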
configs/sana_config/2048ms/Sana_1600M_img2048.yaml renamed to configs/sana_config/2048ms/Sana_1600M_img2048_bf16.yaml

Lines changed: 1 addition & 1 deletion

@@ -78,7 +78,7 @@ scheduler:
 train:
   num_workers: 10
   seed: 1
-  train_batch_size: 64
+  train_batch_size: 4
   num_epochs: 100
   gradient_accumulation_steps: 1
   grad_checkpointing: true

diffusion/model/utils.py

Lines changed: 11 additions & 0 deletions

@@ -589,3 +589,14 @@ def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
     else:
         assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
         return kernel_size // 2
+
+
+def get_weight_dtype(mixed_precision):
+    if mixed_precision in ["fp16", "float16"]:
+        return torch.float16
+    elif mixed_precision in ["bf16", "bfloat16"]:
+        return torch.bfloat16
+    elif mixed_precision in ["fp32", "float32"]:
+        return torch.float32
+    else:
+        raise ValueError(f"weight precision {mixed_precision} is not defined")

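A quick illustration of the new helper: it accepts both the short names used by `model.mixed_precision` and the long forms used by the new `vae.weight_dtype` field (these calls are illustrative, not part of the commit):

```python
import torch
from diffusion.model.utils import get_weight_dtype

# Short and long precision names map to the same torch dtypes.
assert get_weight_dtype("bf16") is torch.bfloat16
assert get_weight_dtype("bfloat16") is torch.bfloat16  # long form, as in AEConfig.weight_dtype
assert get_weight_dtype("fp16") is torch.float16
assert get_weight_dtype("float32") is torch.float32

try:
    get_weight_dtype("int8")  # unsupported precision -> ValueError
except ValueError as err:
    print(err)
```
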
diffusion/utils/config.py

Lines changed: 28 additions & 0 deletions

@@ -79,6 +79,7 @@ class ModelConfig(BaseConfig):
 class AEConfig(BaseConfig):
     vae_type: str = "dc-ae"
     vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"
+    weight_dtype: str = "bfloat16"
     scale_factor: float = 0.41407
     vae_latent_dim: int = 32
     vae_downsample_rate: int = 32

@@ -191,3 +192,30 @@ class SanaConfig(BaseConfig):
     tracker_project_name: str = "t2i-evit-baseline"
     name: str = "baseline"
     loss_report_name: str = "loss"
+
+
+def model_init_config(config: SanaConfig, latent_size: int = 32):
+
+    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
+    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
+    return {
+        "input_size": latent_size,
+        "pe_interpolation": config.model.pe_interpolation,
+        "config": config,
+        "model_max_length": config.text_encoder.model_max_length,
+        "qk_norm": config.model.qk_norm,
+        "micro_condition": config.model.micro_condition,
+        "caption_channels": config.text_encoder.caption_channels,
+        "y_norm": config.text_encoder.y_norm,
+        "attn_type": config.model.attn_type,
+        "ffn_type": config.model.ffn_type,
+        "mlp_ratio": config.model.mlp_ratio,
+        "mlp_acts": list(config.model.mlp_acts),
+        "in_channels": config.vae.vae_latent_dim,
+        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
+        "use_pe": config.model.use_pe,
+        "linear_head_dim": config.model.linear_head_dim,
+        "pred_sigma": pred_sigma,
+        "learn_sigma": learn_sigma,
+        "cross_norm": config.model.cross_norm,
+    }

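All of the inference entry points touched by this commit consume the new helper the same way. A condensed sketch of the shared pattern, assuming `config` (a parsed `SanaConfig`), `latent_size`, and `device` are already in scope as they are in the scripts below:

```python
# Shared model-construction pattern after this commit.
from diffusion.model.builder import build_model
from diffusion.utils.config import model_init_config

model_kwargs = model_init_config(config, latent_size=latent_size)
model = build_model(
    config.model.model,
    use_fp32_attention=config.model.get("fp32_attention", False),
    **model_kwargs,
).to(device)
```

Note that `use_fp32_attention` stays a call-site argument rather than moving into the helper, because `app/sana_pipeline.py` additionally gates it on `mixed_precision != "bf16"`.
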
scripts/inference.py

Lines changed: 10 additions & 34 deletions

@@ -37,8 +37,8 @@
 from diffusion import DPMS, FlowEuler, SASolverSampler
 from diffusion.data.datasets.utils import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST, ASPECT_RATIO_2048_TEST
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 from tools.download import find_model
 

@@ -209,15 +209,14 @@ def visualize(config, args, model, items, bs, sample_steps, cfg_scale, pag_scale
     else:
         raise ValueError(f"{args.sampling_algo} is not defined")
 
-    samples = samples.to(weight_dtype)
+    samples = samples.to(vae_dtype)
     samples = vae_decode(config.vae.vae_type, vae, samples)
     torch.cuda.empty_cache()
 
     os.umask(0o000)
     for i, sample in enumerate(samples):
         save_file_name = f"{chunk[i]}.jpg" if dict_prompt else f"{prompts[i][:100]}.jpg"
         save_path = os.path.join(save_root, save_file_name)
-        # logger.info(f"Saving path: {save_path}")
         save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))
 

@@ -287,17 +286,12 @@ class SanaInference(SanaConfig):
     args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
     sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
     sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
-    if config.model.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif config.model.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-    elif config.model.mixed_precision == "fp32":
-        weight_dtype = torch.float32
-    else:
-        raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+
+    weight_dtype = get_weight_dtype(config.model.mixed_precision)
     logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")
 
-    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+    vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
     tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)
 
     null_caption_token = tokenizer(

@@ -306,27 +300,7 @@ class SanaInference(SanaConfig):
     null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
 
     # model setting
-    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-    model_kwargs = {
-        "pe_interpolation": config.model.pe_interpolation,
-        "config": config,
-        "model_max_length": config.text_encoder.model_max_length,
-        "qk_norm": config.model.qk_norm,
-        "micro_condition": config.model.micro_condition,
-        "caption_channels": text_encoder.config.hidden_size,
-        "y_norm": config.text_encoder.y_norm,
-        "attn_type": config.model.attn_type,
-        "ffn_type": config.model.ffn_type,
-        "mlp_ratio": config.model.mlp_ratio,
-        "mlp_acts": list(config.model.mlp_acts),
-        "in_channels": config.vae.vae_latent_dim,
-        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-        "use_pe": config.model.use_pe,
-        "linear_head_dim": config.model.linear_head_dim,
-        "pred_sigma": pred_sigma,
-        "learn_sigma": learn_sigma,
-    }
+    model_kwargs = model_init_config(config, latent_size=latent_size)
     model = build_model(
         config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
     ).to(device)

@@ -418,6 +392,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")

@@ -441,6 +416,7 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
     save_root = create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type)
     os.makedirs(save_root, exist_ok=True)
     if args.if_save_dirname and args.gpu_id == 0:
+        os.makedirs(f"{work_dir}/metrics", exist_ok=True)
         # save at work_dir/metrics/tmp_xxx.txt for metrics testing
         with open(f"{work_dir}/metrics/tmp_{dataset}_{time.time()}.txt", "w") as f:
             print(f"save tmp file at {work_dir}/metrics/tmp_{dataset}_{time.time()}.txt")

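`scripts/inference_dpg.py` below applies the same dtype split as `scripts/inference.py`: the transformer runs in `model.mixed_precision` while the VAE runs in its own dtype (default `bfloat16`, from the new `AEConfig.weight_dtype`). A minimal sketch of the decoupled decode path, assuming `config`, `device`, and the sampler output `samples` are in scope:

```python
# VAE dtype is now independent of the diffusion model's mixed precision.
from diffusion.model.builder import get_vae, vae_decode
from diffusion.model.utils import get_weight_dtype

weight_dtype = get_weight_dtype(config.model.mixed_precision)  # e.g. fp16 transformer
vae_dtype = get_weight_dtype(config.vae.weight_dtype)          # e.g. bf16 VAE
vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)

samples = samples.to(vae_dtype)  # cast latents to the VAE's dtype before decoding
samples = vae_decode(config.vae.vae_type, vae, samples)
```
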
scripts/inference_dpg.py

Lines changed: 7 additions & 34 deletions

@@ -41,8 +41,8 @@
     get_chunks,
 )
 from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode
-from diffusion.model.utils import prepare_prompt_ar
-from diffusion.utils.config import SanaConfig
+from diffusion.model.utils import get_weight_dtype, prepare_prompt_ar
+from diffusion.utils.config import SanaConfig, model_init_config
 from diffusion.utils.logger import get_root_logger
 
 # from diffusion.utils.misc import read_config

@@ -195,7 +195,7 @@ def visualize(items, bs, sample_steps, cfg_scale, pag_scale=1.0):
     else:
         raise ValueError(f"{args.sampling_algo} is not defined")
 
-    samples = samples.to(weight_dtype)
+    samples = samples.to(vae_dtype)
     samples = vae_decode(config.vae.vae_type, vae, samples)
     torch.cuda.empty_cache()
 

@@ -298,17 +298,11 @@ class SanaInference(SanaConfig):
     args.interval_guidance = [max(0, args.interval_guidance[0]), min(1, args.interval_guidance[1])]
     sample_steps_dict = {"dpm-solver": 20, "sa-solver": 25, "flow_dpm-solver": 20, "flow_euler": 28}
     sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
-    if config.model.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif config.model.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-    elif config.model.mixed_precision == "fp32":
-        weight_dtype = torch.float32
-    else:
-        raise ValueError(f"weigh precision {config.model.mixed_precision} is not defined")
+    weight_dtype = get_weight_dtype(config.model.mixed_precision)
     logger.info(f"Inference with {weight_dtype}, default guidance_type: {guidance_type}, flow_shift: {flow_shift}")
 
-    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(weight_dtype)
+    vae_dtype = get_weight_dtype(config.vae.weight_dtype)
+    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, device).to(vae_dtype)
     tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name, device=device)
 
     null_caption_token = tokenizer(

@@ -317,28 +311,7 @@ class SanaInference(SanaConfig):
     null_caption_embs = text_encoder(null_caption_token.input_ids, null_caption_token.attention_mask)[0]
 
     # model setting
-    pred_sigma = getattr(config.scheduler, "pred_sigma", True)
-    learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma
-    model_kwargs = {
-        "input_size": latent_size,
-        "pe_interpolation": config.model.pe_interpolation,
-        "config": config,
-        "model_max_length": config.text_encoder.model_max_length,
-        "qk_norm": config.model.qk_norm,
-        "micro_condition": config.model.micro_condition,
-        "caption_channels": text_encoder.config.hidden_size,
-        "y_norm": config.text_encoder.y_norm,
-        "attn_type": config.model.attn_type,
-        "ffn_type": config.model.ffn_type,
-        "mlp_ratio": config.model.mlp_ratio,
-        "mlp_acts": list(config.model.mlp_acts),
-        "in_channels": config.vae.vae_latent_dim,
-        "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor,
-        "use_pe": config.model.use_pe,
-        "linear_head_dim": config.model.linear_head_dim,
-        "pred_sigma": pred_sigma,
-        "learn_sigma": learn_sigma,
-    }
+    model_kwargs = model_init_config(config, latent_size=latent_size)
     model = build_model(
         config.model.model, use_fp32_attention=config.model.get("fp32_attention", False), **model_kwargs
     ).to(device)
