# Adapted from https://github.com/GaParmar/img2img-turbo/blob/main/gradio_sketch2image.py
import argparse
import os
import random
import socket
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import safety_check
from app.sana_controlnet_pipeline import SanaControlNetPipeline

STYLES = {
    "None": "{prompt}",
    "Cinematic": "cinematic still {prompt}. emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    "3D Model": "professional 3d model {prompt}. octane render, highly detailed, volumetric, dramatic lighting",
    "Anime": "anime artwork {prompt}. anime style, key visual, vibrant, studio anime, highly detailed",
    "Digital Art": "concept art {prompt}. digital artwork, illustrative, painterly, matte painting, highly detailed",
    "Photographic": "cinematic photo {prompt}. 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    "Pixel art": "pixel-art {prompt}. low-res, blocky, pixel art style, 8-bit graphics",
    "Fantasy art": "ethereal fantasy concept art of {prompt}. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    "Neonpunk": "neonpunk style {prompt}. cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    "Manga": "manga style {prompt}. vibrant, high-energy, detailed, iconic, Japanese comic style",
}
DEFAULT_STYLE_NAME = "None"
STYLE_NAMES = list(STYLES.keys())
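
# Illustration (not executed): a style template wraps the raw user prompt via str.format, e.g.
#   STYLES["Anime"].format(prompt="a castle on a hill")
#   -> "anime artwork a castle on a hill. anime style, key visual, vibrant, studio anime, highly detailed"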

MAX_SEED = 1000000000
DEFAULT_SKETCH_GUIDANCE = 0.28
DEMO_PORT = int(os.getenv("DEMO_PORT", "15432"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# All-white 1024x1024 canvas used as the default sketchpad content and as a fallback output.
blank_image = Image.new("RGB", (1024, 1024), (255, 255, 255))


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="Path to the pipeline config file.")
    parser.add_argument(
        "--model_path",
        nargs="?",
        default="hf://Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
        type=str,
        help="Path to the model checkpoint.",
    )
    parser.add_argument("--output", default="./", type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--cfg_scale", default=5.0, type=float)
    parser.add_argument("--pag_scale", default=2.0, type=float)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--step", default=-1, type=int)
    parser.add_argument("--custom_image_size", default=None, type=int)
    parser.add_argument("--share", action="store_true")
    parser.add_argument(
        "--shield_model_path",
        type=str,
        help="Path to the shield model; ShieldGemma-2B is used by default.",
        default="google/shieldgemma-2b",
    )

    # parse_known_args ignores unrecognized flags, so the demo tolerates extra
    # arguments passed in by wrapper scripts.
    return parser.parse_known_args()[0]


args = get_args()

if torch.cuda.is_available():
    model_path = args.model_path
    pipe = SanaControlNetPipeline(args.config)
    pipe.from_pretrained(model_path)
    pipe.register_progress_bar(gr.Progress())

    # Safety checker: ShieldGemma scores prompts for unsafe content before generation.
    # device_map="auto" already places the model on the GPU, so no extra .to(device) is needed.
    safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path)
    safety_checker_model = AutoModelForCausalLM.from_pretrained(
        args.shield_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )


def save_image(img):
    """Save a PIL image (or a Gradio editor dict) to a temporary PNG and return its path."""
    if isinstance(img, dict):
        img = img["composite"]
    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    img.save(temp_file.name)
    return temp_file.name


def norm_ip(img, low, high):
    # Clamp to [low, high], then rescale in-place to [0, 1].
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))
    return img
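
# Illustration (not executed): norm_ip rescales a tensor in-place from [low, high] to [0, 1], e.g.
#   t = torch.tensor([-1.0, 0.0, 1.0])
#   norm_ip(t, -1, 1)  # t is now tensor([0.0, 0.5, 1.0])
# Values outside [low, high] are clamped first, and the 1e-5 floor guards against
# division by zero when high == low.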

@torch.no_grad()
@torch.inference_mode()
def run(
    image,
    prompt: str,
    prompt_template: str,
    sketch_thickness: int,
    guidance_scale: float,
    inference_steps: int,
    seed: int,
    blend_alpha: float,
) -> tuple[Image.Image, str]:

    print(f"Prompt: {prompt}")
    image_numpy = np.array(image["composite"].convert("RGB"))

    # Reject empty input: no prompt and a canvas that is (almost) entirely white or black.
    # A 1024x1024 RGB image has 1024 * 1024 * 3 = 3145728 values; the threshold below
    # leaves a tolerance of 100 stray values.
    if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
        return blank_image, "Please input the prompt or draw something."

    # Replace unsafe prompts with a harmless placeholder instead of refusing outright.
    if safety_check.is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt, threshold=0.2):
        prompt = "A red heart."

    prompt = prompt_template.format(prompt=prompt)
    pipe.set_blend_alpha(blend_alpha)
    start_time = time.time()
    images = pipe(
        prompt=prompt,
        ref_image=image["composite"],
        guidance_scale=guidance_scale,
        num_inference_steps=inference_steps,
        num_images_per_prompt=1,
        sketch_thickness=sketch_thickness,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    latency = time.time() - start_time

    if latency < 1:
        latency_str = f"{latency * 1000:.2f}ms"
    else:
        latency_str = f"{latency:.2f}s"
    torch.cuda.empty_cache()

    # Map each output tensor from [-1, 1] to 8-bit RGB and convert to a PIL image.
    output_images = [
        Image.fromarray(
            norm_ip(tensor, -1, 1)
            .mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
        )
        for tensor in images
    ]
    return output_images[0], latency_str
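
# A minimal headless invocation sketch (assumes a CUDA build and a local "sketch.png";
# the Gradio Sketchpad passes a dict whose "composite" key holds the drawn PIL image,
# which is why run() indexes image["composite"]):
#
#   out_img, latency = run(
#       image={"composite": Image.open("sketch.png").convert("RGB")},
#       prompt="a cat",
#       prompt_template=STYLES["None"],
#       sketch_thickness=2,
#       guidance_scale=4.5,
#       inference_steps=20,
#       seed=233,
#       blend_alpha=0.0,
#   )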


model_size = "1.6" if "1600M" in args.model_path else "0.6"
title = f"""
    <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
        <img src="https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/logo.png" width="50%" alt="logo"/>
    </div>
"""
DESCRIPTION = f"""
    <p><span style="font-size: 36px; font-weight: bold;">Sana-ControlNet-{model_size}B</span><span style="font-size: 20px; font-weight: bold;">{args.image_size}px</span></p>
    <p style="font-size: 18px; font-weight: bold;">Sana: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer</p>
    <p><span style="font-size: 16px;"><a href="https://arxiv.org/abs/2410.10629">[Paper]</a></span> <span style="font-size: 16px;"><a href="https://github.com/NVlabs/Sana">[Github]</a></span> <span style="font-size: 16px;"><a href="https://nvlabs.github.io/Sana">[Project]</a></span></p>
    <p style="font-size: 18px; font-weight: bold;">Powered by <a href="https://hanlab.mit.edu/projects/dc-ae">DC-AE</a> with 32x latent space, running on node {socket.gethostname()}.</p>
    <p style="font-size: 16px; font-weight: bold;">Unsafe prompts are replaced, so the output image shows a red heart instead.</p>
    """
if model_size == "0.6":
    DESCRIPTION += "\n<p>The 0.6B model's text rendering ability is limited.</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


with gr.Blocks(css_paths="asset/app_styles/controlnet_app_style.css", title="Sana Sketch-to-Image Demo") as demo:
    gr.Markdown(title)
    gr.HTML(DESCRIPTION)

    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            with gr.Group():
                canvas = gr.Sketchpad(
                    value=blank_image,
                    height=640,
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                    type="pil",
                    label="Sketch",
                    show_label=False,
                    show_download_button=True,
                    interactive=True,
                    transforms=[],
                    canvas_size=(1024, 1024),
                    scale=1,
                    brush=gr.Brush(default_size=3, colors=["#000000"], color_mode="fixed"),
                    format="png",
                    layers=False,
                )
                with gr.Row():
                    prompt = gr.Text(label="Prompt", placeholder="Enter your prompt", scale=6)
                    run_button = gr.Button("Run", scale=1, elem_id="run_button")
                    download_sketch = gr.DownloadButton("Download Sketch", scale=1, elem_id="download_sketch")
            with gr.Row():
                style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, scale=1)
                prompt_template = gr.Textbox(
                    label="Prompt Style Template", value=STYLES[DEFAULT_STYLE_NAME], scale=2, max_lines=1
                )

            with gr.Row():
                sketch_thickness = gr.Slider(
                    label="Sketch Thickness",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                )
            with gr.Row():
                inference_steps = gr.Slider(
                    label="Sampling steps",
                    minimum=5,
                    maximum=40,
                    step=1,
                    value=20,
                )
                guidance_scale = gr.Slider(
                    label="CFG Guidance scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=4.5,
                )
                blend_alpha = gr.Slider(
                    label="Blend Alpha",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0,
                )
            with gr.Row():
                seed = gr.Slider(label="Seed", show_label=True, minimum=0, maximum=MAX_SEED, value=233, step=1, scale=4)
                randomize_seed = gr.Button("Random Seed", scale=1, min_width=50, elem_id="random_seed")

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            with gr.Group():
                result = gr.Image(
                    format="png",
                    height=640,
                    image_mode="RGB",
                    type="pil",
                    label="Result",
                    show_label=False,
                    show_download_button=True,
                    interactive=False,
                    elem_id="output_image",
                )
                latency_result = gr.Text(label="Inference Latency", show_label=True)

            download_result = gr.DownloadButton("Download Result", elem_id="download_result")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. a cat)")
            gr.Markdown("**2**. Start sketching or upload a reference image")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Try different seeds to generate different results")

    run_inputs = [canvas, prompt, prompt_template, sketch_thickness, guidance_scale, inference_steps, seed, blend_alpha]
    run_outputs = [result, latency_result]

    # Re-run generation whenever the seed is randomized or the style changes, and on
    # prompt submit, Run click, or any canvas edit.
    randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        api_name=False,
        queue=False,
    ).then(run, inputs=run_inputs, outputs=run_outputs, api_name=False)

    style.change(
        lambda x: STYLES[x],
        inputs=[style],
        outputs=[prompt_template],
        api_name=False,
        queue=False,
    ).then(fn=run, inputs=run_inputs, outputs=run_outputs, api_name=False)
    gr.on(
        triggers=[prompt.submit, run_button.click, canvas.change],
        fn=run,
        inputs=run_inputs,
        outputs=run_outputs,
        api_name=False,
    )

    download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
    download_result.click(fn=save_image, inputs=result, outputs=download_result)
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")


if __name__ == "__main__":
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=DEMO_PORT, debug=False, share=args.share)
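
# Example launch command (script and config paths below are placeholders; --model_path
# falls back to the Sana_1600M_1024px checkpoint on the Hugging Face Hub):
#   python <this_script>.py --config <path/to/controlnet_config.yaml> --share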