diff --git a/comfy/ops.py b/comfy/ops.py
index bd84a804c92..5fef7cee78d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -254,16 +254,33 @@ def fp8_linear(self, input):
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
         w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        scale_weight = self.scale_weight
+        scale_input = self.scale_input
+        if scale_weight is None:
+            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
+            if scale_input is None:
+                scale_input = scale_weight
+        if scale_input is None:
+            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+
         if self.bias is not None:
-            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
         else:
-            o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+
+        if isinstance(o, tuple):
+            o = o[0]
 
         return o.reshape((-1, input.shape[1], self.weight.shape[0]))
 
     return None
 
 class fp8_ops(manual_cast):
     class Linear(manual_cast.Linear):
+        def reset_parameters(self):
+            self.scale_weight = None
+            self.scale_input = None
+            return None
+
         def forward_comfy_cast_weights(self, input):
             out = fp8_linear(self, input)
             if out is not None:
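
For context, the call pattern the patch moves to is a per-tensor scaled fp8 matmul. Below is a minimal, self-contained sketch of that pattern, not the patch itself: the names `x`, `weight`, `scale_x`, and `scale_w` are illustrative, and it assumes an fp8-capable GPU (compute capability >= 8.9) plus a PyTorch build that exposes `torch._scaled_mm`.

```python
import torch

# Minimal sketch of the torch._scaled_mm call pattern used in the patch.
# Assumes an fp8-capable GPU and a recent PyTorch; tensor names are illustrative.
device = "cuda"

# _scaled_mm wants fp8 operands with dimensions that are multiples of 16,
# the first operand row-major and the second column-major (hence the .t()).
x = torch.randn(32, 64, device=device).to(torch.float8_e4m3fn)        # flattened activations
weight = torch.randn(128, 64, device=device).to(torch.float8_e4m3fn)  # [out_features, in_features]

# Per-tensor dequantization scales; ones() mirrors the patch's fallback
# when no scale_weight / scale_input has been loaded onto the layer.
scale_x = torch.ones((), device=device, dtype=torch.float32)
scale_w = torch.ones((), device=device, dtype=torch.float32)

o = torch._scaled_mm(x, weight.t(), scale_a=scale_x, scale_b=scale_w,
                     out_dtype=torch.bfloat16)

# Older PyTorch builds return (output, amax) while newer ones return only the
# output tensor, which is what the added isinstance(o, tuple) check covers.
if isinstance(o, tuple):
    o = o[0]

print(o.shape)  # torch.Size([32, 128])
```

The `reset_parameters` override in the diff is what gives each `Linear` the `scale_weight` / `scale_input` attributes in the first place, so a checkpoint loader can fill them in later; when they stay `None`, the unit scales shown above are used and the result matches the previous unscaled behaviour.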