
Commit 93c9f9d

lawrence-cj and xieenze authored
Nv labs GitHub repo/nv labs GitHub repo main adding controlnet (#23) (#177)
* Nv labs GitHub repo/nv labs GitHub repo main adding controlnet (#23)

  * add scripts for controlnet; pre-commit;
  * add everything needed for controlnet inference; runs successfully;
  * move sample txt files into one dir; update readme;
  * add readme for controlnet; update readme;
  * add 1.6B controlnet-related model and config files;
  * update readme && pre-commit;
  * update readme.md;
  * add everything needed for the online controlnet demo; runs successfully;
  * fix a small bug;
  * code update && pre-commit;
  * update controlnet readme; pre-commit;
  * update controlnet readme; pre-commit;

* add controlnet test to CI; fix controlnet config bug;
* add ref image for controlnet;
* update controlnet readme;
* update controlnet CI;

Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Enze Xie <[email protected]>
1 parent dd38c12 commit 93c9f9d

32 files changed: +2054 −10 lines
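The demo script added in this commit exposes its options via argparse. A minimal launch sketch follows; the `--config` path is an assumption (point it at the ControlNet YAML shipped with this commit), `--model_path` matches the default baked into the script, and the server listens on `DEMO_PORT` (default 15432):

```bash
# Hedged sketch: launch the new Sana ControlNet sketch-to-image demo.
# The --config path below is an assumption, not a path confirmed by this diff;
# use the ControlNet config added in this commit. Port comes from DEMO_PORT.
python app/app_sana_controlnet_hed.py \
    --config configs/sana_controlnet_config/Sana_1600M_1024px_controlnet.yaml \
    --model_path hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth
```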

README.md

Lines changed: 5 additions & 4 deletions
````diff
@@ -38,8 +38,9 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion models (e
 
 ## 🔥🔥 News
 
+- (🔥 New) \[2025/2/10\] 🚀Sana + ControlNet is released. [\[Guidance\]](asset/docs/sana_controlnet) | [\[Model\]](asset/docs/model_zoo.md)
 - (🔥 New) \[2025/1/30\] Release CAME-8bit optimizer code. Saving more GPU memory during training. [\[How to config\]](https://github.com/NVlabs/Sana/blob/main/configs/sana_config/1024ms/Sana_1600M_img1024_CAME8bit.yaml#L86)
-- (🔥 New) \[2025/1/29\] 🎉 🎉 🎉**SANA 1.5 is out! Figure out how to do efficient training & inference scaling!** 🚀[\[Tech Report\]](asset/SANA_1.5.pdf)
+- (🔥 New) \[2025/1/29\] 🎉 🎉 🎉**SANA 1.5 is out! Figure out how to do efficient training & inference scaling!** 🚀[\[Tech Report\]](https://arxiv.org/abs/2501.18427)
 - (🔥 New) \[2025/1/24\] 4bit-Sana is released, powered by [SVDQuant and Nunchaku](https://github.com/mit-han-lab/nunchaku) inference engine. Now run your Sana within **8GB** GPU VRAM [\[Guidance\]](asset/docs/4bit_sana.md) [\[Demo\]](https://svdquant.mit.edu/) [\[Model\]](asset/docs/model_zoo.md)
 - (🔥 New) \[2025/1/24\] DCAE-1.1 is released, better reconstruction quality. [\[Model\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1) [\[diffusers\]](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers)
 - (🔥 New) \[2025/1/23\] **Sana is accepted by ICLR-2025.** 🎉🎉🎉
@@ -271,16 +272,16 @@ docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
 python scripts/inference.py \
       --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
       --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
-      --txt_file=asset/samples_mini.txt
+      --txt_file=asset/samples/samples_mini.txt
 
 # Run samples in a json file
 python scripts/inference.py \
       --config=configs/sana_config/1024ms/Sana_1600M_img1024.yaml \
       --model_path=hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth \
-      --json_file=asset/samples_mini.json
+      --json_file=asset/samples/samples_mini.json
 ```
 
-where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a prompt to generate
+where each line of [`asset/samples/samples_mini.txt`](asset/samples/samples_mini.txt) contains a prompt to generate
 
 # 🔥 3. How to Train Sana
 
````
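For context, the relocated `asset/samples/samples_mini.txt` is a plain prompt list, one prompt per line. A hypothetical sketch of its shape (these example prompts are illustrative, not the file's actual contents):

```bash
# Illustrative only — the prompts below are hypothetical stand-ins.
$ head -2 asset/samples/samples_mini.txt
a cyberpunk cat with a neon sign that says "Sana"
a serene mountain lake at sunrise, photorealistic, 4k
```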

app/app_sana.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -199,7 +199,6 @@ def get_args():
 args = get_args()
 
 if torch.cuda.is_available():
-    weight_dtype = torch.float16
     model_path = args.model_path
     pipe = SanaPipeline(args.config)
     pipe.from_pretrained(model_path)
```
app/app_sana_controlnet_hed.py

Lines changed: 306 additions & 0 deletions
(new file)

```python
# Changed from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import argparse
import os
import random
import socket
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_controlnet_pipeline import SanaControlNetPipeline

STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())

MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="config")
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
        type=str,
        help="Path to the model file (positional)",
    )
    parser.add_argument("--output", default="./", type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--cfg_scale", default=5.0, type=float)
    parser.add_argument("--pag_scale", default=2.0, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--step", default=-1, type=int)
    parser.add_argument("--custom_image_size", default=None, type=int)
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="The path to shield model, we employ ShieldGemma-2B by default.",
        default="google/shieldgemma-2b",
    )

    return parser.parse_known_args()[0]


args = get_args()

if torch.cuda.is_available():
    model_path = args.model_path
    pipe = SanaControlNetPipeline(args.config)
    pipe.from_pretrained(model_path)
    pipe.register_progress_bar(gr.Progress())

    # safety checker
    safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
    safety_checker_model = AutoModelForCausalLM.from_pretrained(
        args.shield_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ).to(device)


def save_image(img):
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def norm_ip(img, low, high):
    # Clamp to [low, high], then rescale in place to [0, 1].
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))
    return img


@torch.no_grad()
@torch.inference_mode()
def run(
    image,
    prompt: str,
    prompt_template: str,
    sketch_thickness: int,
    guidance_scale: float,
    inference_steps: int,
    seed: int,
    blend_alpha: float,
) -> tuple[Image.Image, str]:

    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))

    # Skip generation when the prompt is empty and the canvas is (almost)
    # entirely white or black; 3145628 = 3 * 1024 * 1024 - 100 channel values.
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."

    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        prompt = "A red heart."

    prompt = prompt_template.format(prompt=prompt)
    pipe.set_blend_alpha(blend_alpha)
    start_time = time.time()
    images = pipe(
        prompt=prompt,
        ref_image=image["composite"],
        guidance_scale=guidance_scale,
        num_inference_steps=inference_steps,
        num_images_per_prompt=1,
        sketch_thickness=sketch_thickness,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    latency = time.time() - start_time

    if latency < 1:
        latency = latency * 1000
        latency_str = f"{latency:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    torch.cuda.empty_cache()

    # Map each output tensor from [-1, 1] to uint8 [0, 255] and convert to PIL.
    imgs = [
        Image.fromarray(
            norm_ip(img, -1, 1)
            .mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
            .astype(np.uint8)
        )
        for img in images
    ]
    return imgs[0], latency_str


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = """
<div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
    <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
</div>
"""
DESCRIPTION = f"""
<p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
<p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
<p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
<p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, </p>running on node {socket.gethostname()}.
<p style="font-size: 16px; font-weight: bold;">Unsafe words will give you a 'Red Heart' in the image instead.</p>
"""
if model_size == "0.6":
    DESCRIPTION += "\n<p>0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title="Sana Sketch-to-Image Demo") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
                with gr.Row():
                    style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                    prompt_template = gr.Textbox(
                        label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                    )

                with gr.Row():
                    sketch_thickness = gr.Slider(
                        label="Sketch Thickness",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=2,
                    )
                with gr.Row():
                    inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=5,
                        maximum=40,
                        step=1,
                        value=20,
                    )
                    guidance_scale = gr.Slider(
                        label="CFG Guidance scale",
                        minimum=1,
                        maximum=10,
                        step=0.1,
                        value=4.5,
                    )
                    blend_alpha = gr.Slider(
                        label="Blend Alpha",
                        minimum=0,
                        maximum=1,
                        step=0.1,
                        value=0,
                    )
                with gr.Row():
                    seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                    randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
            gr.Markdown("**2**. Start sketching or upload a reference image")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
    run_outputs = [result, latency_result]

    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
```
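Outside the Gradio UI, the demo's inference path reduces to a handful of calls on `SanaControlNetPipeline`. A minimal sketch under stated assumptions: CUDA is available, the config and reference-image paths below are hypothetical, and the keyword arguments simply mirror the demo script above.

```python
# Minimal sketch of the demo's inference core, without the Gradio UI.
# Assumes CUDA; the config and ref-image paths below are hypothetical.
import torch
from PIL import Image

from app.sana_controlnet_pipeline import SanaControlNetPipeline

device = torch.device("cuda:0")

pipe = SanaControlNetPipeline("configs/sana_controlnet_config/Sana_1600M_1024px_controlnet.yaml")
pipe.from_pretrained("hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth")
pipe.set_blend_alpha(0.0)  # demo slider default

images = pipe(
    prompt="cinematic photo of a lighthouse on a cliff",  # illustrative prompt
    ref_image=Image.open("my_sketch.png").convert("RGB"),  # hypothetical input
    guidance_scale=4.5,
    num_inference_steps=20,
    num_images_per_prompt=1,
    sketch_thickness=2,
    generator=torch.Generator(device=device).manual_seed(233),
)
# Outputs come back as tensors in [-1, 1]; the demo's norm_ip/mul(255) chain
# converts them to uint8 PIL images for display.
```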
