Mixed precision diffusion models with scaled fp8.
This change adds support for diffusion models where all the linears are
scaled fp8 while the other weights are kept in the original precision.
comfyanonymous committed Oct 21, 2024
1 parent 83ca891 commit 0075c6d
Showing 5 changed files with 14 additions and 11 deletions.
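For context, a minimal sketch of the idea behind a "scaled fp8" linear (assumed names and per-tensor scaling; not the actual comfy.ops implementation): the weight is stored in fp8 together with a higher-precision scale and dequantized on the fly, while every other module keeps its original precision.

import torch
import torch.nn.functional as F

class ScaledFP8Linear(torch.nn.Module):
    # Weight stored as fp8 plus a float32 scale; compute happens in the
    # activation dtype. Names here are illustrative, not ComfyUI's.
    def __init__(self, in_features, out_features, fp8_dtype=torch.float8_e4m3fn):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=fp8_dtype),
            requires_grad=False)
        self.scale_weight = torch.nn.Parameter(
            torch.ones((), dtype=torch.float32), requires_grad=False)

    def forward(self, x):
        # Dequantize just-in-time: fp8 is only a storage format here.
        w = self.weight.to(x.dtype) * self.scale_weight.to(x.dtype)
        return F.linear(x, w)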
6 changes: 3 additions & 3 deletions comfy/model_base.py
@@ -96,7 +96,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
@@ -246,8 +246,8 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
 
         unet_state_dict = self.diffusion_model.state_dict()
 
-        if self.model_config.scaled_fp8:
-            unet_state_dict["scaled_fp8"] = torch.tensor([])
+        if self.model_config.scaled_fp8 is not None:
+            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
 
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
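The save path above encodes the fp8 format in the dtype of an empty marker tensor rather than in a boolean flag. A hedged round-trip sketch of that convention:

import torch

# Saving: an empty "scaled_fp8" tensor acts as a marker; its dtype
# records which fp8 format the linear weights use.
scaled_fp8 = torch.float8_e4m3fn  # stands in for model_config.scaled_fp8
sd = {"scaled_fp8": torch.tensor([], dtype=scaled_fp8)}
assert sd["scaled_fp8"].numel() == 0
assert sd["scaled_fp8"].dtype == torch.float8_e4m3fn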
7 changes: 5 additions & 2 deletions comfy/model_detection.py
@@ -288,8 +288,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
 
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)
 
-    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
-        model_config.scaled_fp8 = True
+    scaled_fp8_weight = state_dict.get("{}scaled_fp8".format(unet_key_prefix), None)
+    if scaled_fp8_weight is not None:
+        model_config.scaled_fp8 = scaled_fp8_weight.dtype
+        if model_config.scaled_fp8 == torch.float32:
+            model_config.scaled_fp8 = torch.float8_e4m3fn
 
     return model_config

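The load path is the mirror image: detection keys off the marker tensor's dtype, and a float32 marker (written by checkpoints saved before this change) is treated as the old default of float8_e4m3fn. A standalone sketch with assumed names:

import torch

def detect_scaled_fp8(state_dict, prefix=""):
    # Returns the fp8 dtype of a scaled-fp8 checkpoint, or None.
    marker = state_dict.get("{}scaled_fp8".format(prefix), None)
    if marker is None:
        return None
    if marker.dtype == torch.float32:
        # Legacy marker saved as float32: assume the old e4m3fn default.
        return torch.float8_e4m3fn
    return marker.dtype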
6 changes: 3 additions & 3 deletions comfy/ops.py
@@ -334,10 +334,10 @@ def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
 
     return scaled_fp8_op
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
-    if scaled_fp8:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True)
+    if scaled_fp8 is not None:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
 
     if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
         return fp8_ops
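With the signature change above, callers pass the fp8 dtype itself (or None) rather than a boolean, and the dtype is forwarded to scaled_fp8_ops as override_dtype. A call sketch based only on the signature shown in this diff:

import torch
import comfy.ops

# scaled_fp8=None keeps the normal paths; passing a dtype selects the
# scaled-fp8 ops pinned to that format.
operations = comfy.ops.pick_operations(
    weight_dtype=torch.float16,
    compute_dtype=torch.float16,
    scaled_fp8=torch.float8_e4m3fn,
)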
4 changes: 2 additions & 2 deletions comfy/sd.py
@@ -579,7 +579,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
 
         return None
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)
 
     model_config.custom_operations = model_options.get("custom_operations", None)
@@ -677,7 +677,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
 
     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if weight_dtype is not None:
+    if weight_dtype is not None and model_config.scaled_fp8 is None:
         unet_weight_dtype.append(weight_dtype)
 
     if dtype is None:
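The guard added in both loaders keeps the checkpoint's storage dtype out of the inference-dtype candidates when the model is scaled fp8, since fp8 is then a storage format rather than a compute dtype. A toy sketch with assumed values:

import torch

supported_inference_dtypes = [torch.float16, torch.bfloat16]
weight_dtype = torch.float8_e4m3fn  # dtype detected in the checkpoint
scaled_fp8 = torch.float8_e4m3fn    # stands in for model_config.scaled_fp8

unet_weight_dtype = list(supported_inference_dtypes)
if weight_dtype is not None and scaled_fp8 is None:
    unet_weight_dtype.append(weight_dtype)
# unet_weight_dtype stays [float16, bfloat16]: fp8 is storage-only here.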
2 changes: 1 addition & 1 deletion comfy/supported_models_base.py
@@ -49,7 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
     custom_operations = None
-    scaled_fp8 = False
+    scaled_fp8 = None
    optimizations = {"fp8": False}
 
     @classmethod