Skip to content

Commit

Permalink
Optimizations to --fast and scaled fp8.
Browse files (browse the repository at this point in the history)
  • Loading branch information
comfyanonymous committed Oct 22, 2024
1 parent f82314f commit 8ce2a10
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,12 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]:
return None

tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)


if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
Expand All @@ -272,7 +278,11 @@ def fp8_linear(self, input):
if isinstance(o, tuple):
o = o[0]

if tensor_2d:
return o.reshape(input.shape[0], -1)

return o.reshape((-1, input.shape[1], self.weight.shape[0]))

return None

class fp8_ops(manual_cast):
Expand Down Expand Up @@ -316,7 +326,11 @@ def forward_comfy_cast_weights(self, input):
return out

weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)

if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
Expand Down

0 comments on commit 8ce2a10

Please sign in to comment.