Skip to content

Commit

Permalink
Use the lowvram cast_to function for everything.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Oct 17, 2024
1 parent 7390ff3 commit 6715899
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 32 deletions.
32 changes: 13 additions & 19 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,27 +840,21 @@ def force_channels_last():
#TODO
return False

def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
device_supports_cast = True
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
elif is_intel_xpu():
device_supports_cast = True
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)

non_blocking = device_should_use_non_blocking(device)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r

def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)

if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)

def xformers_enabled():
global directml_enabled
Expand Down
17 changes: 4 additions & 13 deletions comfy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,10 @@
import comfy.model_management
from comfy.cli_args import args

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)

r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
cast_to = comfy.model_management.cast_to #TODO: remove once no more references

def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
Expand All @@ -47,12 +38,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)

has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
Expand Down

0 comments on commit 6715899

Please sign in to comment.