Fixed model merging issue with scaled fp8.
comfyanonymous committed Oct 20, 2024
1 parent 471cd3e commit f9f9faf
Showing 3 changed files with 40 additions and 29 deletions.
2 changes: 1 addition & 1 deletion comfy/lora.py
@@ -415,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight *= strength_model

if isinstance(v, list):
v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

if len(v) == 1:
patch_type = "diff"
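Note: after this change, the first entry of a nested patch list carries a (weight, convert_func) pair instead of a bare weight, and calculate_weight runs the convert function on the device-cast copy before recursing. The following is a minimal sketch of that consumption pattern — toy values and a stand-in convert function, not ComfyUI's own ops:

import torch

def toy_convert(weight, inplace=False, **kwargs):
    # stand-in for a scaled-fp8 op's convert_weight: undo the storage scaling
    scale = torch.tensor(2.0)
    if inplace:
        weight *= scale
        return weight
    return weight * scale

stored = torch.full((4,), 0.5)                 # what the checkpoint actually holds
entry = (stored, toy_convert)                  # the new (weight, convert_func) pairing
# roughly what the patched calculate_weight does with v[0]: copy, then convert in place
real = entry[1](entry[0].clone(), inplace=True)
print(real)                                    # tensor([1., 1., 1., 1.])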
59 changes: 33 additions & 26 deletions comfy/model_patcher.py
@@ -94,6 +94,31 @@ def __call__(self, weight):
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))

return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)

def get_key_weight(model, key):
set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(model, key)
else:
op = comfy.utils.get_attr(model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass

try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass

weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(model, key)

return weight, set_func, convert_func

class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -294,14 +319,16 @@ def get_key_patches(self, filter_prefix=None):
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
weight, set_func, convert_func = get_key_weight(self.model, k)
if bk is not None:
weight = bk.weight
else:
weight = model_sd[k]
if convert_func is None:
convert_func = lambda a, **kwargs: a

if k in self.patches:
p[k] = [weight] + self.patches[k]
p[k] = [(weight, convert_func)] + self.patches[k]
else:
p[k] = (weight,)
p[k] = [(weight, convert_func)]
return p

def model_state_dict(self, filter_prefix=None):
@@ -317,27 +344,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
if key not in self.patches:
return

set_func = None
convert_func = None
op_keys = key.rsplit('.', 1)
if len(op_keys) < 2:
weight = comfy.utils.get_attr(self.model, key)
else:
op = comfy.utils.get_attr(self.model, op_keys[0])
try:
set_func = getattr(op, "set_{}".format(op_keys[1]))
except AttributeError:
pass

try:
convert_func = getattr(op, "convert_{}".format(op_keys[1]))
except AttributeError:
pass

weight = getattr(op, op_keys[1])
if convert_func is not None:
weight = comfy.utils.get_attr(self.model, key)

weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update

if key not in self.backup:
Expand All @@ -348,7 +355,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight)
temp_weight = convert_func(temp_weight, inplace=True)

out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
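For context, the get_key_weight helper factored out above resolves the owning op from a dotted key and probes it for optional set_<attr> / convert_<attr> hooks, so both patch_weight_to_device and get_key_patches (used for model merging) share the same lookup. This is a toy reproduction of that lookup pattern with an assumed ToyScaledLinear module, not ComfyUI's scaled_fp8 op, just to show how the hooks are discovered:

import torch

class ToyScaledLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(2, 2) * 0.5)
        self.scale_weight = torch.tensor(2.0)

    def convert_weight(self, weight, inplace=False, **kwargs):
        if inplace:
            weight *= self.scale_weight
            return weight
        return weight * self.scale_weight

def toy_get_key_weight(model, key):
    # split "0.weight" into the owning op ("0") and the attribute name ("weight")
    op_path, _, attr = key.rpartition('.')
    op = model.get_submodule(op_path) if op_path else model
    set_func = getattr(op, "set_{}".format(attr), None)          # absent here -> None
    convert_func = getattr(op, "convert_{}".format(attr), None)  # found on the toy op
    return getattr(op, attr), set_func, convert_func

model = torch.nn.Sequential(ToyScaledLinear())
weight, set_func, convert_func = toy_get_key_weight(model, "0.weight")
print(convert_func(weight.detach().clone(), inplace=True))  # de-scaled weight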
8 changes: 6 additions & 2 deletions comfy/ops.py
@@ -309,8 +309,12 @@ def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)

def convert_weight(self, weight):
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
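The ops.py change is the scaled-fp8 half of the fix: convert_weight can now rescale the stored weight in place when the caller has already made a copy, instead of always allocating a second temporary. A hedged sketch of the round trip this relies on, using plain float32 tensors and a made-up scale rather than a real fp8 dtype:

import torch

scale_weight = torch.tensor(4.0)
stored = torch.tensor([0.25, 0.5])            # scaled value kept in the checkpoint

# convert: stored * scale -> the weight the patching/merging math should see
full = stored.clone()
full *= scale_weight                           # the inplace=True path added here
print(full)                                    # tensor([1., 2.])

# ... LoRA / merge patches would be applied to `full` at this point ...

# set: divide back by the scale before writing (the real set_weight also
# stochastically rounds the result to the op's fp8 dtype)
restored = full / scale_weight
print(torch.allclose(restored, stored))        # True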
