Skip to content

Commit

Permalink
Make clip loader nodes support loading sd3 t5xxl in lower precision.
Browse files Browse the repository at this point in the history
Add attention mask support in the SD3 text encoder code.
  • Loading branch information
comfyanonymous committed Oct 10, 2024
1 parent 5f9d5a2 commit 1b80895
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
29 changes: 17 additions & 12 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,19 @@ def detect_te_model(sd):
return TEModel.T5_BASE
return None


def t5xxl_weight_dtype(clip_data):
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"

dtype_t5 = None
for sd in clip_data:
weight = sd.get(weight_name, None)
if weight is not None:
dtype_t5 = weight.dtype
break
return dtype_t5


def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = state_dicts

Expand Down Expand Up @@ -462,9 +475,7 @@ class EmptyClass:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
elif te_model == TEModel.T5_XXL:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif te_model == TEModel.T5_XL:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
Expand All @@ -482,25 +493,19 @@ class EmptyClass:
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models)
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype

clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

parameters = 0
Expand Down
22 changes: 14 additions & 8 deletions comfy/text_encoders/sd3_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import logging

class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
Expand Down Expand Up @@ -39,7 +39,7 @@ def state_dict(self):
return {}

class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
Expand All @@ -57,7 +57,8 @@ def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu

if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.t5_attention_mask = t5_attention_mask
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=self.t5_attention_mask)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
Expand Down Expand Up @@ -87,6 +88,7 @@ def encode_token_weights(self, token_weight_pairs):
lg_out = None
pooled = None
out = None
extra = {}

if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
if self.clip_l is not None:
Expand All @@ -111,7 +113,11 @@ def encode_token_weights(self, token_weight_pairs):
pooled = torch.cat((l_pooled, g_pooled), dim=-1)

if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = t5_output[:2]
if self.t5_attention_mask:
extra["attention_mask"] = t5_output[2]["attention_mask"]

if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
Expand All @@ -123,7 +129,7 @@ def encode_token_weights(self, token_weight_pairs):
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

return out, pooled
return out, pooled, extra

def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
Expand All @@ -133,8 +139,8 @@ def load_sd(self, sd):
else:
return self.t5xxl.load_sd(sd)

def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

0 comments on commit 1b80895

Please sign in to comment.