Skip to content

Commit

Permalink
Make --fast work on pytorch nightly.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 21, 2024
1 parent 5f50263 commit 904bf58
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,33 @@ def fp8_linear(self, input):
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()

scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)

if self.bias is not None:
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
else:
o, _ = torch._scaled_mm(inn, w, out_dtype=input.dtype)
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)

if isinstance(o, tuple):
o = o[0]

return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None

class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
self.scale_weight = None
self.scale_input = None
return None

def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
Expand Down

0 comments on commit 904bf58

Please sign in to comment.