Support scaled fp8 t5xxl model.
comfyanonymous committed Oct 21, 2024
1 parent f9f9faf commit 83ca891
Showing 6 changed files with 63 additions and 30 deletions.
comfy/ops.py (11 additions & 2 deletions)
@@ -290,12 +290,21 @@ def forward_comfy_cast_weights(self, input):
         weight, bias = cast_bias_weight(self, input)
         return torch.nn.functional.linear(input, weight, bias)
 
-def scaled_fp8_ops(fp8_matrix_mult=False):
+def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     class scaled_fp8_op(manual_cast):
         class Linear(manual_cast.Linear):
+            def __init__(self, *args, **kwargs):
+                if override_dtype is not None:
+                    kwargs['dtype'] = override_dtype
+                super().__init__(*args, **kwargs)
+
             def reset_parameters(self):
                 if not hasattr(self, 'scale_weight'):
                     self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
 
+                if not scale_input:
+                    self.scale_input = None
+
                 if not hasattr(self, 'scale_input'):
                     self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                 return None
@@ -328,7 +337,7 @@ def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
 
     if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
         return fp8_ops
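
A note for readers new to the scaled-fp8 trick that scale_weight and scale_input implement: the checkpoint stores fp8 weights together with a scalar fp32 scale that restores their magnitude at compute time. Below is a minimal self-contained sketch of the idea; the ScaledFP8Linear class and the max/448 quantization rule are illustrative assumptions, not ComfyUI's actual code.

import torch

class ScaledFP8Linear(torch.nn.Module):
    # Hypothetical example class, not part of comfy.ops.
    def __init__(self, weight_fp32, compute_dtype=torch.bfloat16):
        super().__init__()
        self.compute_dtype = compute_dtype
        # Quantize: choose a scale so the largest weight maps to fp8's maximum
        # (+/-448 for torch.float8_e4m3fn), then store weight / scale in fp8.
        scale = weight_fp32.abs().max().float() / 448.0
        self.scale_weight = torch.nn.Parameter(scale, requires_grad=False)
        self.weight = torch.nn.Parameter(
            (weight_fp32 / scale).to(torch.float8_e4m3fn), requires_grad=False)

    def forward(self, x):
        # Dequantize on the fly: cast fp8 up, then multiply the scale back in.
        w = self.weight.to(self.compute_dtype) * self.scale_weight.to(self.compute_dtype)
        return torch.nn.functional.linear(x.to(self.compute_dtype), w)

lin = ScaledFP8Linear(torch.randn(8, 16))
print(lin(torch.randn(2, 16)).shape)  # torch.Size([2, 8])

One scalar per tensor keeps the memory overhead negligible while avoiding the severe clipping a raw fp8 cast would cause for tensors whose values exceed the format's range.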
comfy/sd.py (9 additions & 10 deletions)
@@ -432,16 +432,15 @@ def detect_te_model(sd):
     return None
 
 
-def t5xxl_weight_dtype(clip_data):
+def t5xxl_detect(clip_data):
     weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
 
-    dtype_t5 = None
     for sd in clip_data:
-        weight = sd.get(weight_name, None)
-        if weight is not None:
-            dtype_t5 = weight.dtype
-            break
-    return dtype_t5
+        if weight_name in sd:
+            return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd)
+
+    return {}
 
 
 def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -475,7 +474,7 @@ class EmptyClass:
             clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
             clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
         elif te_model == TEModel.T5_XXL:
-            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=t5xxl_weight_dtype(clip_data))
+            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif te_model == TEModel.T5_XL:
             clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
@@ -493,19 +492,19 @@ class EmptyClass:
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
             te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
-            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, dtype_t5=t5xxl_weight_dtype(clip_data))
+            clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=TEModel.CLIP_L in te_models, clip_g=TEModel.CLIP_G in te_models, t5=TEModel.T5_XXL in te_models, **t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HUNYUAN_DIT:
             clip_target.clip = comfy.text_encoders.hydit.HyditModel
             clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
         elif clip_type == CLIPType.FLUX:
-            clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
+            clip_target.clip = comfy.text_encoders.flux.flux_clip(**t5xxl_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLClipModel
             clip_target.tokenizer = sdxl_clip.SDXLTokenizer
     elif len(clip_data) == 3:
-        clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(dtype_t5=t5xxl_weight_dtype(clip_data))
+        clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
 
     parameters = 0
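
To make the renamed helper concrete, here is a hypothetical illustration of what t5xxl_detect() now returns and splats into sd3_clip()/flux_clip(). The state dicts are fabricated stand-ins, and the snippet assumes a ComfyUI checkout on the import path; the key names come from t5_xxl_detect() in comfy/text_encoders/sd3_clip.py further down.

import torch
import comfy.text_encoders.sd3_clip

# Fabricated stand-ins for a loaded t5xxl state dict.
fp16_sd = {"encoder.final_layer_norm.weight": torch.zeros(4096, dtype=torch.float16)}
scaled_fp8_sd = {
    "encoder.final_layer_norm.weight": torch.zeros(4096, dtype=torch.float8_e4m3fn),
    "scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn),  # marker tensor
}

print(comfy.text_encoders.sd3_clip.t5_xxl_detect(fp16_sd))
# -> {'dtype_t5': torch.float16}
print(comfy.text_encoders.sd3_clip.t5_xxl_detect(scaled_fp8_sd))
# -> {'dtype_t5': torch.float8_e4m3fn, 't5xxl_scaled_fp8': torch.float8_e4m3fn}

So a regular checkpoint still yields only dtype_t5, while a scaled-fp8 checkpoint additionally switches the text encoder onto the scaled_fp8_ops path.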
comfy/sd1_clip.py (10 additions & 1 deletion)
@@ -94,11 +94,20 @@ def __init__(self, device="cpu", max_length=77,
             config = json.load(f)
 
         operations = model_options.get("custom_operations", None)
+        scaled_fp8 = None
+
         if operations is None:
-            operations = comfy.ops.manual_cast
+            scaled_fp8 = model_options.get("scaled_fp8", None)
+            if scaled_fp8 is not None:
+                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            else:
+                operations = comfy.ops.manual_cast
 
         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
+        if scaled_fp8 is not None:
+            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
+
         self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
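
The empty tensor attached to the transformer deserves a comment: it holds no data, only a dtype, so once it is serialized under the scaled_fp8 key it acts as a persistent marker that t5_xxl_detect() can read back on the next load. A short sketch of that round trip:

import torch

# The marker carries metadata only: zero elements, but a meaningful dtype.
marker = torch.nn.Parameter(torch.tensor([], dtype=torch.float8_e4m3fn))
print(marker.numel())  # 0
print(marker.dtype)    # torch.float8_e4m3fn

# After saving, detection just inspects the dtype of that key, mirroring
# t5_xxl_detect() in comfy/text_encoders/sd3_clip.py:
state_dict = {"scaled_fp8": marker.data}
print(state_dict["scaled_fp8"].dtype)  # torch.float8_e4m3fn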
comfy/supported_models.py (5 additions & 9 deletions)
@@ -529,12 +529,11 @@ def clip_target(self, state_dict={}):
             clip_l = True
         if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
             clip_g = True
-        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
-        if t5_key in state_dict:
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        if "dtype_t5" in t5_detect:
             t5 = True
-            dtype_t5 = state_dict[t5_key].dtype
 
-        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5))
+        return supported_models_base.ClipTarget(comfy.text_encoders.sd3_clip.SD3Tokenizer, comfy.text_encoders.sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
 
 class StableAudio(supported_models_base.BASE):
     unet_config = {
@@ -653,11 +652,8 @@ def get_model(self, state_dict, prefix="", device=None):
 
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
-        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
-        dtype_t5 = None
-        if t5_key in state_dict:
-            dtype_t5 = state_dict[t5_key].dtype
-        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
 
 class FluxSchnell(Flux):
     unet_config = {
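
Both clip_target() implementations now splat the detection result instead of threading dtype arguments through by hand; with hypothetical values the splat expands as follows:

import torch

# Hypothetical detection result for a scaled-fp8 t5xxl checkpoint:
t5_detect = {"dtype_t5": torch.float8_e4m3fn,
             "t5xxl_scaled_fp8": torch.float8_e4m3fn}

# flux_clip(**t5_detect) is then equivalent to:
#   flux_clip(dtype_t5=torch.float8_e4m3fn,
#             t5xxl_scaled_fp8=torch.float8_e4m3fn)
# For a plain fp16 checkpoint, t5_detect would be {"dtype_t5": torch.float16}
# and t5xxl_scaled_fp8 would keep its default of None.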
comfy/text_encoders/flux.py (6 additions & 7 deletions)
@@ -1,15 +1,11 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
+import comfy.text_encoders.sd3_clip
 import comfy.model_management
 from transformers import T5TokenizerFast
 import torch
 import os
 
-class T5XXLModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, model_options=model_options)
-
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
@@ -41,7 +37,7 @@ def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
         clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
         self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
-        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
+        self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])
 
     def set_clip_options(self, options):
@@ -66,8 +62,11 @@ def load_sd(self, sd):
         else:
             return self.t5xxl.load_sd(sd)
 
-def flux_clip(dtype_t5=None):
+def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
     class FluxClipModel_(FluxClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_
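
With the duplicate class gone, Flux reuses the scaled-fp8-aware T5XXLModel from sd3_clip.py, so both loaders share one code path. A hypothetical instantiation, assuming a ComfyUI checkout on the import path:

import torch
import comfy.text_encoders.flux

# Bake the detected fp8 format into the returned class (values hypothetical).
FluxClip = comfy.text_encoders.flux.flux_clip(
    dtype_t5=torch.float8_e4m3fn, t5xxl_scaled_fp8=torch.float8_e4m3fn)

# The loader later instantiates it with runtime arguments only; the subclass
# injects model_options["t5xxl_scaled_fp8"] itself before the shared
# sd3_clip.T5XXLModel is constructed (commented out: loads a real model):
# clip = FluxClip(device="cpu", dtype=torch.float16)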
comfy/text_encoders/sd3_clip.py (22 additions & 1 deletion)
@@ -10,8 +10,26 @@
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
+        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
+        if t5xxl_scaled_fp8 is not None:
+            model_options = model_options.copy()
+            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
+
+def t5_xxl_detect(state_dict, prefix=""):
+    out = {}
+    t5_key = "{}encoder.final_layer_norm.weight".format(prefix)
+    if t5_key in state_dict:
+        out["dtype_t5"] = state_dict[t5_key].dtype
+
+    scaled_fp8_key = "{}scaled_fp8".format(prefix)
+    if scaled_fp8_key in state_dict:
+        out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+
+    return out
+
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
@@ -139,8 +157,11 @@ def load_sd(self, sd):
         else:
             return self.t5xxl.load_sd(sd)
 
-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
     class SD3ClipModel_(SD3ClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
     return SD3ClipModel_
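
Taken together, the hunks above thread one setting from checkpoint to compute ops. A paraphrased trace of the flow (not verbatim code):

# 1. The t5xxl checkpoint carries an empty "scaled_fp8" marker tensor, so
#    t5_xxl_detect() returns:
#        {"dtype_t5": torch.float8_e4m3fn, "t5xxl_scaled_fp8": torch.float8_e4m3fn}
# 2. sd3_clip(**detected) / flux_clip(**detected) bake the value into
#        model_options["t5xxl_scaled_fp8"]
# 3. T5XXLModel renames it for the encoder it wraps:
#        model_options["scaled_fp8"] = t5xxl_scaled_fp8
# 4. sd1_clip.SDClipModel sees model_options["scaled_fp8"], swaps in
#        comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=...),
#    and re-attaches the marker tensor so the setting survives a re-save.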
