diff --git a/comfy/ops.py b/comfy/ops.py index 5e7c668eb03..3c5ba0124b5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -264,10 +264,14 @@ def fp8_linear(self, input): scale_input = self.scale_input if scale_weight is None: scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + else: + scale_weight = scale_weight.to(input.device) + if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) inn = input.reshape(-1, input.shape[2]).to(dtype) else: + scale_input = scale_input.to(input.device) inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) if bias is not None: