Commit 168cf92

hills-code, Chengyue Wu, lawrence-cj authored
inference scaling update (#208)
* inference scaling update
* pre-commit; Signed-off-by: lawrence-cj <[email protected]>
* update scripts for geneval generation
* move inference scaling md into asset/docs/ Signed-off-by: lawrence-cj <[email protected]>
* fix number typo Signed-off-by: lawrence-cj <[email protected]>

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Chengyue Wu <[email protected]>
Co-authored-by: lawrence-cj <[email protected]>
1 parent 2082c75 commit 168cf92

11 files changed: 378 additions, 30 deletions


README.md

Lines changed: 14 additions & 0 deletions
@@ -40,6 +40,7 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion models (e
4040

4141
## 🔥🔥 News
4242

43+
- (🔥 New) \[2025/3/21\] 🚀Sana + Inference Scaling is released. [\[Guidance\]](asset/docs/inference_scaling/inference_scaling.md)
4344
- (🔥 New) \[2025/3/16\] 🔥**SANA-1.5 code & weights are released!** 🎉 Include: [DDP/FSDP](#3-train-with-tar-file) | [TAR file WebDataset](#3-train-with-tar-file) | [Multi-Scale](#3-train-with-tar-file) Training code and [Weights](asset/docs/model_zoo.md) | [HF](https://huggingface.co/collections/Efficient-Large-Model/sana-15-67d6803867cb21c230b780e4) are all released.
4445
- (🔥 New) \[2025/3/14\] 🏃SANA-Sprint is coming out!🎉 A new one/few-step generator of Sana. 0.1s per 1024px image on H100, 0.3s on RTX 4090. Find out more details: [\[Page\]](https://nvlabs.github.io/Sana/Sprint/) | [\[Arxiv\]](https://arxiv.org/abs/2503.09641). Code is coming very soon along with `diffusers`
4546
- (🔥 New) \[2025/2/10\] 🚀Sana + ControlNet is released. [\[Guidance\]](asset/docs/sana_controlnet.md) | [\[Model\]](asset/docs/model_zoo.md) | [\[Demo\]](https://nv-sana.mit.edu/ctrlnet/)
@@ -393,6 +394,19 @@ bash train_scripts/train.sh \
393394

394395
Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
395396

# 🚀 5. Inference Scaling

We trained a specialized [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) model to score images, which we named VISA (VIla as SAna verifier). By selecting the top 4 images from 2,048 candidates, we improved the GenEval scores of SD1.5 and SANA-1.5-4.8B v2 from 0.42 to 0.87 and from 0.81 to 0.96, respectively.

| Method | Overall | Single | Two | Counting | Colors | Position | Color Attribution |
|--------------------------------|---------|--------|------|----------|--------|----------|------------------|
| SD1.5 | 0.42 | 0.98 | 0.39 | 0.31 | 0.72 | 0.04 | 0.06 |
| **+ Inference Scaling** | **0.87** | **1.00** | **0.97** | **0.93** | **0.96** | **0.75** | **0.62** |
| SANA-1.5 4.8B v2 | 0.81 | 0.99 | 0.86 | 0.86 | 0.84 | 0.59 | 0.65 |
| **+ Inference Scaling** | **0.96** | **1.00** | **1.00** | **0.97** | **0.94** | **0.96** | **0.87** |

For details, refer to the [Inference Scaling Manual](asset/docs/inference_scaling/inference_scaling.md).

396410
# 💪To-Do List
397411

398412
We will try our best to release
asset/docs/inference_scaling/inference_scaling.md

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
## Inference Time Scaling for SANA-1.5

![results](results.jpg)

We trained a specialized [NVILA-2B](https://huggingface.co/Efficient-Large-Model/NVILA-Lite-2B-Verifier) model to score images, which we named VISA (VIla as SAna verifier). By selecting the top 4 images from 2,048 candidates, we improved the GenEval scores of SD1.5 and SANA-1.5-4.8B v2 from 0.42 to 0.87 and from 0.81 to 0.96, respectively.

![curve](scaling_curve.jpg)

Even with a small number of candidates, such as 32, SANA-1.5-4.8B v2 scores above 90% on GenEval.

### Environment Requirement

Install the dependencies:

```bash
# Other transformers versions may also work, but we have not tested them.
pip install transformers==4.46
pip install git+https://github.com/bfshi/scaling_on_scales.git
```

### 1. Generate N images with a .pth file for the following selection

```bash
# Download the checkpoint used for generation.
huggingface-cli download Efficient-Large-Model/Sana_600M_512px --repo-type model --local-dir output/Sana_600M_512px --local-dir-use-symlinks False
# 32 candidates is a small number for testing, but it already pushes GenEval above 90% when verifying the SANA-1.5-4.8B v2 model.
# Set it to a larger number, such as 2048, to approach the upper limit.
n_samples=32
pick_number=4

output_dir=output/geneval_generated_path
# example
bash scripts/infer_run_inference_geneval.sh \
    configs/sana_config/512ms/Sana_600M_img512.yaml \
    output/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth \
    --img_nums_per_sample=$n_samples \
    --output_dir=$output_dir
```

### 2. Use NVILA-Verifier to select from the generated images

```bash
bash tools/inference_scaling/nvila_sana_pick.sh \
    $output_dir \
    $n_samples \
    $pick_number
```

### 3. Calculate the GenEval metric

The final evaluation must run in the GenEval environment. Installation instructions can be found [here](../../../tools/metrics/geneval/geneval_env.md).

```bash
# activate geneval env
conda activate geneval

DIR_AFTER_PICK="output/nvila_pick/best_${pick_number}_of_${n_samples}/${output_dir}"

bash tools/metrics/compute_geneval.sh $(dirname "$DIR_AFTER_PICK") $(basename "$DIR_AFTER_PICK")
```
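
The selection step in this manual amounts to best-of-N sampling: score every candidate image for a prompt, then keep the `pick_number` highest-scoring ones. The Python sketch below illustrates only that idea. The function name `pick_top_k`, the file names, and the score values are hypothetical placeholders; the actual `nvila_sana_pick.sh` pipeline and the VISA scoring interface are not shown here, so this should not be read as their implementation.

```python
from typing import List, Sequence


def pick_top_k(paths: Sequence[str], scores: Sequence[float], k: int) -> List[str]:
    """Return the k candidate paths with the highest verifier scores."""
    if len(paths) != len(scores):
        raise ValueError("paths and scores must have the same length")
    if not 0 < k <= len(paths):
        raise ValueError("k must be between 1 and the number of candidates")
    # Sort by score only, highest first; ties keep their original order.
    ranked = sorted(zip(scores, paths), key=lambda pair: pair[0], reverse=True)
    return [path for _, path in ranked[:k]]


# Hypothetical verifier scores for six candidates of a single prompt.
candidates = [f"sample_{i:05d}.png" for i in range(6)]
scores = [0.12, 0.91, 0.45, 0.88, 0.30, 0.67]
best = pick_top_k(candidates, scores, k=4)
print(best)
```

Larger `n_samples` gives the verifier more candidates to choose from, which is where the reported gains come from, at the cost of proportionally more generation time.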

asset/docs/model_zoo.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@
1919
| Sana-1.6B-ControlNet | 1Kpx | [Sana_1600M_1024px_BF16_ControlNet_HED](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_ControlNet_HED) | Coming soon | **bf16**/fp32 | Multi-Language |
2020
| Sana-0.6B-ControlNet | 1Kpx | [Sana_600M_1024px_ControlNet_HED](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_ControlNet_HED) | Coming soon | fp16/fp32 | - |
2121

22-
---
22+
______________________________________________________________________
2323

2424
### SANA-1.5
2525

2626
| Model | Reso | pth link | diffusers | Precision | Description |
2727
|--------------|--------|-------------------------------------------------------------------------------------------|------------------------------------------------------------------------|-----------|----------------|
28-
| SANA1.5-4.8B | 1024px | [SANA1.5_4.8B_1024px](https://huggingface.co/Efficient-Large-Model/SANA1.5_4.8B_1024px) | [Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers]()(coming soon) | bf16 | Multi-Language |
28+
| SANA1.5-4.8B | 1024px | [SANA1.5_4.8B_1024px](https://huggingface.co/Efficient-Large-Model/SANA1.5_4.8B_1024px) | [Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers](<>)(coming soon) | bf16 | Multi-Language |
2929

30+
______________________________________________________________________
3031

31-
---
3232
## ❗ 2. Make sure to use correct precision(fp16/bf16/fp32) for training and inference.
3333

3434
### We provide two samples to use fp16 and bf16 weights, respectively.

scripts/infer_run_inference_geneval.sh

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ default_step=20 # 14
1010
default_sample_nums=553
1111
default_sampling_algo="flow_dpm-solver"
1212
default_add_label=''
13+
default_img_nums_per_sample=4
14+
default_batch_size=1
1315

1416
# parser
1517
config_file=$1
@@ -22,6 +24,18 @@ do
2224
step="${arg#*=}"
2325
shift
2426
;;
27+
--sample_nums=*)
28+
sample_nums="${arg#*=}"
29+
shift
30+
;;
31+
--img_nums_per_sample=*)
32+
img_nums_per_sample="${arg#*=}"
33+
shift
34+
;;
35+
--batch_size=*)
36+
batch_size="${arg#*=}"
37+
shift
38+
;;
2539
--sampling_algo=*)
2640
sampling_algo="${arg#*=}"
2741
shift
@@ -34,8 +48,8 @@ do
3448
model_paths="${arg#*=}"
3549
shift
3650
;;
37-
--sample_nums=*)
38-
sample_nums="${arg#*=}"
51+
--output_dir=*)
52+
output_dir="${arg#*=}"
3953
shift
4054
;;
4155
--cfg_scale=*)
@@ -63,21 +77,31 @@ samples_per_gpu=$((sample_nums / np))
6377
add_label=${add_label:-$default_add_label}
6478
ablation_key=${ablation_key:-''}
6579
ablation_selections=${ablation_selections:-''}
80+
img_nums_per_sample=${img_nums_per_sample:-$default_img_nums_per_sample}
81+
batch_size=${batch_size:-$default_batch_size}
82+
output_dir=${output_dir:-''}
6684

6785
echo "Step: $step"
6886
echo "Sample numbers: $sample_nums"
87+
echo "Image numbers per sample: $img_nums_per_sample"
88+
echo "Batch size: $batch_size"
6989
echo "Sampling Algo: $sampling_algo"
7090
echo "CFG scale: $cfg_scale"
7191
echo "Add label: $add_label"
7292
echo "Exist time prefix: $exist_time_prefix"
7393

7494
cmd_template="DPM_TQDM=True python scripts/inference_geneval.py --config={config_file} --model_path={model_path} \
75-
--sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale --sample_nums $sample_nums \
76-
--gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}"
95+
--sampling_algo $sampling_algo --step $step --cfg_scale $cfg_scale --sample_nums $sample_nums --n_samples $img_nums_per_sample \
96+
--batch_size $batch_size --gpu_id {gpu_id} --start_index {start_index} --end_index {end_index}"
7797
if [ -n "${add_label}" ]; then
7898
cmd_template="${cmd_template} --add_label ${add_label}"
7999
fi
80100

101+
if [ -n "${output_dir}" ]; then
102+
cmd_template="${cmd_template} --output_dir ${output_dir}"
103+
fi
104+
81105
if [ -n "${ablation_key}" ]; then
82106
cmd_template="${cmd_template} --ablation_key ${ablation_key} --ablation_selections \"${ablation_selections}\""
83107
echo "ablation_key: $ablation_key"
@@ -108,7 +132,7 @@ if [[ "$model_paths" == *.pth ]]; then
108132
cmd="${cmd//\{end_index\}/$end_index}"
109133

110134
echo "Running on GPU $gpu_id: samples $start_index to $end_index"
111-
echo $cmd
135+
echo "cmd: $cmd"
112136
eval CUDA_VISIBLE_DEVICES=$gpu_id $cmd &
113137
done
114138
wait
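
The script divides `sample_nums` prompts across the available GPUs (`samples_per_gpu=$((sample_nums / np))`) and launches one background process per GPU with its own `--start_index` and `--end_index`. The exact index arithmetic is not visible in this diff, so the Python sketch below makes two assumptions: ranges are contiguous and half-open, and the last GPU takes any remainder. The name `shard_ranges` is hypothetical.

```python
from typing import List, Tuple


def shard_ranges(sample_nums: int, num_gpus: int) -> List[Tuple[int, int, int]]:
    """Split [0, sample_nums) into one (gpu_id, start, end) range per GPU."""
    if num_gpus <= 0:
        raise ValueError("num_gpus must be positive")
    per_gpu = sample_nums // num_gpus  # integer division, like $((sample_nums / np))
    ranges = []
    for gpu_id in range(num_gpus):
        start = gpu_id * per_gpu
        # Assumption: the last GPU also processes the leftover samples.
        end = sample_nums if gpu_id == num_gpus - 1 else start + per_gpu
        ranges.append((gpu_id, start, end))
    return ranges


# 553 is the script's default_sample_nums; 8 GPUs is an example value.
for gpu_id, start, end in shard_ranges(553, 8):
    print(f"Running on GPU {gpu_id}: samples {start} to {end}")
```

With `--img_nums_per_sample=N`, each prompt in a shard produces N images, so per-GPU work grows linearly with N while the sharding itself is unchanged.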

scripts/inference_geneval.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import torch
3131
from einops import rearrange
3232
from PIL import Image
33+
from termcolor import colored
3334
from torchvision.utils import _log_api_usage_once, make_grid, save_image
3435
from tqdm import tqdm
3536

@@ -173,7 +174,6 @@ def visualize(sample_steps, cfg_scale, pag_scale):
173174
os.makedirs(sample_path, exist_ok=True)
174175

175176
prompt = metadata["prompt"]
176-
# print(f"Prompt ({index: >3}/{len(metadatas)}): '{prompt}'")
177177
with open(os.path.join(outpath, "metadata.jsonl"), "w") as fp:
178178
json.dump(metadata, fp)
179179

@@ -347,7 +347,7 @@ def parse_args():
347347
class SanaInference(SanaConfig):
348348
config: str = ""
349349
dataset: str = "GenEval"
350-
outdir: str = field(default="outputs", metadata={"help": "dir to write results to"})
350+
output_dir: str = field(default=None, metadata={"help": "dir to write results to"})
351351
n_samples: int = field(default=4, metadata={"help": "number of samples"})
352352
batch_size: int = field(default=1, metadata={"help": "how many samples can be produced simultaneously"})
353353
skip_grid: bool = field(default=False, metadata={"help": "skip saving grid"})
@@ -394,8 +394,9 @@ class SanaInference(SanaConfig):
394394
device = "cuda" if torch.cuda.is_available() else "cpu"
395395
logger = get_root_logger()
396396

397-
n_rows = batch_size = args.n_samples
398-
assert args.batch_size == 1, ValueError(f"{batch_size} > 1 is not available in GenEval")
397+
batch_size = args.batch_size
398+
n_rows = 4 if args.n_samples > 4 else args.n_samples
399+
assert args.n_samples % args.batch_size == 0, f"n_samples ({args.n_samples}) must be divisible by batch_size ({args.batch_size})"
399400

400401
# only support fixed latent size currently
401402
latent_size = args.image_size // config.vae.vae_downsample_rate
@@ -448,12 +449,25 @@ class SanaInference(SanaConfig):
448449
if ("flow" not in args.model_path or args.sampling_algo == "flow_dpm-solver")
449450
else "flow_euler"
450451
)
452+
logger.info(f"Sampler {args.sampling_algo}")
451453

452-
work_dir = (
453-
f"/{os.path.join(*args.model_path.split('/')[:-2])}"
454-
if args.model_path.startswith("/")
455-
else os.path.join(*args.model_path.split("/")[:-2])
456-
)
454+
# save path
455+
if args.output_dir is None:
456+
work_dir = (
457+
f"/{os.path.join(*args.model_path.split('/')[:-2])}"
458+
if args.model_path.startswith("/")
459+
else os.path.join(*args.model_path.split("/")[:-2])
460+
)
461+
img_save_dir = os.path.join(str(work_dir), "vis")
462+
463+
os.umask(0o000)
464+
os.makedirs(img_save_dir, exist_ok=True)
465+
logger.info(colored(f"Saving images at {img_save_dir}", "green"))
466+
else:
467+
work_dir = args.output_dir
468+
469+
os.umask(0o000)
470+
os.makedirs(work_dir, exist_ok=True)
457471

458472
# dataset
459473
metadatas = datasets.load_dataset(
@@ -465,15 +479,9 @@ class SanaInference(SanaConfig):
465479
match = re.search(r".*epoch_(\d+).*step_(\d+).*", args.model_path)
466480
epoch_name, step_name = match.groups() if match else ("unknown", "unknown")
467481

468-
img_save_dir = os.path.join(str(work_dir), "vis")
469-
os.umask(0o000)
470-
os.makedirs(img_save_dir, exist_ok=True)
471-
logger.info(f"Sampler {args.sampling_algo}")
472-
473482
def create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidance_type):
474483
save_root = os.path.join(
475484
img_save_dir,
476-
# f"{datetime.now().date() if args.exist_time_prefix == '' else args.exist_time_prefix}_"
477485
f"{dataset}_epoch{epoch_name}_step{step_name}_scale{args.cfg_scale}"
478486
f"_step{sample_steps}_size{args.image_size}_bs{batch_size}_samp{args.sampling_algo}"
479487
f"_seed{args.seed}_{str(weight_dtype).split('.')[-1]}",
@@ -487,6 +495,10 @@ def create_save_root(args, dataset, epoch_name, step_name, sample_steps, guidanc
487495
save_root += f"_{guidance_type}"
488496
if args.interval_guidance[0] != 0 and args.interval_guidance[1] != 1:
489497
save_root += f"_intervalguidance{args.interval_guidance[0]}{args.interval_guidance[1]}"
498+
if not DATA_URL.endswith("evaluation_metadata.jsonl"):
499+
save_root += f"_metadata{DATA_URL.split('/')[-1]}"
500+
if args.n_samples != 4:
501+
save_root += f"_nsample{args.n_samples}"
490502

491503
save_root += f"_imgnums{args.sample_nums}" + args.add_label
492504
return save_root
@@ -505,7 +517,10 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
505517
sample_steps = args.step if args.step != -1 else sample_steps_dict[args.sampling_algo]
506518
guidance_type = guidance_type_select(guidance_type, args.pag_scale, config.model.attn_type)
507519

508-
save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
520+
if args.output_dir is None:
521+
save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
522+
else:
523+
save_root = args.output_dir
509524
os.makedirs(save_root, exist_ok=True)
510525
if args.if_save_dirname and args.gpu_id == 0:
511526
# save at work_dir/metrics/tmp_xxx.txt for metrics testing
@@ -519,7 +534,10 @@ def guidance_type_select(default_guidance_type, pag_scale, attn_type):
519534
guidance_type = guidance_type_select(guidance_type, args.pag_scale, config.model.attn_type)
520535
logger.info(f"Inference with {weight_dtype}, guidance_type: {guidance_type}, flow_shift: {flow_shift}")
521536

522-
save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
537+
if args.output_dir is None:
538+
save_root = create_save_root(args, args.dataset, epoch_name, step_name, sample_steps, guidance_type)
539+
else:
540+
save_root = args.output_dir
523541
os.makedirs(save_root, exist_ok=True)
524542
if args.if_save_dirname and args.gpu_id == 0:
525543
os.makedirs(f"{work_dir}/metrics", exist_ok=True)
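
The changes to this file make the save location depend on `--output_dir`: when it is unset, `work_dir` is derived from `model_path` by dropping its last two path components and images go under `work_dir/vis`; when it is set, both `work_dir` and `save_root` are the given directory. The sketch below restates that branching as a standalone function to make the behavior easier to check. The name `resolve_dirs` is hypothetical, the guard for short paths is an addition not present in the diff, and the sketch assumes POSIX-style paths.

```python
import os
from typing import Optional, Tuple


def resolve_dirs(model_path: str, output_dir: Optional[str]) -> Tuple[str, Optional[str]]:
    """Return (work_dir, img_save_dir), mirroring the branching in the diff."""
    if output_dir is not None:
        # An explicit output directory is used as-is; no "vis" subfolder is created.
        return output_dir, None
    parts = model_path.split("/")[:-2]  # drop "<checkpoints dir>/<file>.pth"
    if not parts:
        # Added guard (not in the diff): os.path.join() needs at least one component.
        raise ValueError("model_path needs at least two parent directories")
    work_dir = os.path.join(*parts)
    if model_path.startswith("/"):
        work_dir = f"/{work_dir}"  # restore the leading slash lost by split/join
    return work_dir, os.path.join(work_dir, "vis")


print(resolve_dirs("output/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth", None))
print(resolve_dirs("ignored.pth", "output/geneval_generated_path"))
```

Passing `--output_dir` therefore bypasses the long auto-generated folder name built by `create_save_root`, which is what lets the later selection step refer to a fixed path.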

scripts/inference_geneval_diffusers.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def parse_args():
151151
help="skip saving grid",
152152
)
153153

154+
parser.add_argument("--work_dir", default=None, type=str)
154155
parser.add_argument("--sample_nums", default=553, type=int)
155156
parser.add_argument("--add_label", default="", type=str)
156157
parser.add_argument("--exist_time_prefix", default="", type=str)
@@ -193,12 +194,17 @@ def parse_args():
193194
logger.info(f"Eval {len(metadatas)} samples")
194195

195196
# save path
196-
work_dir = (
197-
f"/{os.path.join(*args.model_path.split('/')[:-1])}"
198-
if args.model_path.startswith("/")
199-
else os.path.join(*args.model_path.split("/")[:-1])
200-
)
197+
if args.work_dir is None:
198+
work_dir = (
199+
f"/{os.path.join(*args.model_path.split('/')[:-1])}"
200+
if args.model_path.startswith("/")
201+
else os.path.join(*args.model_path.split("/")[:-1])
202+
)
203+
else:
204+
work_dir = args.work_dir
205+
args.work_dir = work_dir
201206
img_save_dir = os.path.join(str(work_dir), "vis")
207+
202208
os.umask(0o000)
203209
os.makedirs(img_save_dir, exist_ok=True)
204210

@@ -214,6 +220,7 @@ def parse_args():
214220

215221
if args.if_save_dirname and args.gpu_id == 0:
216222
# save at work_dir/metrics/tmp_xxx.txt for metrics testing
223+
os.makedirs(f"{work_dir}/metrics", exist_ok=True)
217224
with open(f"{work_dir}/metrics/tmp_geneval_{time.time()}.txt", "w") as f:
218225
print(f"save tmp file at {work_dir}/metrics/tmp_geneval_{time.time()}.txt")
219226
f.write(os.path.basename(save_root))
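
One detail in the hunk above: `time.time()` is evaluated once for the `open` call and again for the `print` call, so the logged path can differ slightly from the file that is actually written. A minimal Python sketch of computing the path once follows. The helper name `write_dirname_marker` is hypothetical and not part of the repository.

```python
import os
import tempfile
import time


def write_dirname_marker(work_dir: str, save_root: str) -> str:
    """Write basename(save_root) to work_dir/metrics/tmp_geneval_<timestamp>.txt."""
    metrics_dir = os.path.join(work_dir, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)
    # Build the path once so the log message matches the file on disk.
    marker_path = os.path.join(metrics_dir, f"tmp_geneval_{time.time()}.txt")
    with open(marker_path, "w") as f:
        f.write(os.path.basename(save_root))
    print(f"save tmp file at {marker_path}")
    return marker_path


# Example run in a throwaway directory; "vis/example_run" is a placeholder.
with tempfile.TemporaryDirectory() as tmp_dir:
    written = write_dirname_marker(tmp_dir, "vis/example_run")
    marker_existed = os.path.exists(written)
```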
