Skip to content

Commit 62fe6ba

Browse files
[ctrlnet] fix gradio (#229)
* [ctrlnet] fix gradio * fix controlnet bug; Signed-off-by: lawrence-cj <[email protected]> --------- Signed-off-by: lawrence-cj <[email protected]> Co-authored-by: lawrence-cj <[email protected]>
1 parent c4bb626 commit 62fe6ba

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

app/sana_controlnet_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dataclasses import dataclass, field
1818
from typing import Optional, Tuple
1919

20+
import cv2
2021
import numpy as np
2122
import pyrallis
2223
import torch
@@ -282,7 +283,6 @@ def forward(
282283
)
283284
else:
284285
z = latents.to(self.device)
285-
model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
286286

287287
# control signal
288288
if isinstance(ref_image, str):
@@ -302,7 +302,10 @@ def forward(
302302
self.config.vae.vae_type, self.vae, control_signal, self.config.vae.sample_posterior, self.device
303303
)
304304

305-
model_kwargs["control_signal"] = control_signal_latent
305+
model_kwargs = dict(
306+
data_info={"img_hw": hw, "aspect_ratio": ar, "control_signal": control_signal_latent},
307+
mask=emb_masks,
308+
)
306309

307310
if self.vis_sampler == "flow_euler":
308311
flow_solver = FlowEuler(

configs/sana_controlnet_config/Sana_1600M_1024px_controlnet_bf16.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ vae:
4040
vae_latent_dim: 32
4141
vae_downsample_rate: 32
4242
sample_posterior: true
43+
weight_dtype: bf16
4344
# text encoder
4445
text_encoder:
4546
text_encoder_name: gemma-2-2b-it

0 commit comments

Comments
 (0)