Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented generic type validation for all nodes. #4149

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")

parser.add_argument("--type-conformance", type=str, choices=['all', 'images', 'error', 'none'], default='none', help="Level of type conformance for input/output values between nodes. 'all/images' will attempt to convert all or only image values to the expected type, 'error' will raise an error if the types do not match.")


# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

Expand Down
136 changes: 133 additions & 3 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,130 @@

import comfy.model_management

from comfy.cli_args import args


import sys
import copy
import logging
import threading
import heapq
import time
import traceback
import inspect
from typing import List, Literal, NamedTuple, Optional

import torch
import nodes

import comfy.model_management

from comfy.cli_args import args


def validate_image_shape(image, node_type, output_name):
image_len = len(image.shape)

def type_check():
print(f"Checking Image Shape: {image.shape}")
if image_len != 4 or image.shape[-1] > 4 or image.shape[-1] < 3:
if "error" == args.type_conformance:
raise ValueError(f"Image Shape Error: Node: {node_type}, Output: {output_name}, [{image.shape}] does not match the expected RGB format: torch.Size[B, H, W, 3] or the RGBA format: torch.Size[B, H, W, 4]")
else:
print(f"Image Shape Error: Node: {node_type}, Output: {output_name}, [{image.shape}] does not match the expected RGB format: torch.Size[B, H, W, 3] or the RGBA format: torch.Size[B, H, W, 4]")
return False
return True

if type_check() != True and args.type_conformance in ["all", "images"]:
transforms = {
"HW": lambda t: t.unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3),
"BHW": lambda t: t.unsqueeze(-1).expand(-1, -1, -1, 3),
"HWC": lambda t: t.unsqueeze(0),
}

if image_len == 2:
#HW -> add Batch and Channel dimensions
image = transforms["HW"](image)
return image

if image_len == 3:
if image.shape[-1] > 4 or image.shape[-1] < 3:
#BHW -> add Channel dimension
image = transforms["BHW"](image)
return image

if image.shape[-1] == 3 or image.shape[-1] == 4:
#HW3 or HW4 -> add Batch dimension --- can misbehave in edge cases where image is BHW but W is 3 or 4 px
image = transforms["HWC"](image)
return image

return image


def generate_type_validator(valid_types, validator=None):
def validate_type(node_type, output_name, value):
print(type(validator))
if not isinstance(value, valid_types) and validator is None:
print(f"TypeError in Node: {node_type}. Expected {output_name} to be one of type(s): {[t.__name__ for t in valid_types]}, got {[type(value).__name__]} instead")

if "all" == args.type_conformance:
return valid_types[0](value)

if "error" == args.type_conformance:
raise TypeError(f"Expected: [{output_name}], to be one of type(s): {[vtype.__name__ for vtype in valid_types]}, got {[type(value).__name__]} instead")

return value

if isinstance(value, valid_types) and validator is not None:
print(f"Validating {output_name} in Node: {node_type}")
value = validator(value, node_type, output_name)
return value

return value

return validate_type


def input_validation(input_data_all, obj):
validation_funcs = {
"IMAGE": generate_type_validator((torch.Tensor,), validator=validate_image_shape),
"INT": generate_type_validator((int,)),
"FLOAT": generate_type_validator((float,)),
"STRING": generate_type_validator((str,)),
}
input_types = obj.INPUT_TYPES()
for _, v in input_types.items():
if isinstance(v, dict):
for k2, v2 in v.items():
if tuple(v2[0]) in validation_funcs.keys():
input_data_all[k2] = [validation_funcs[v2[0]](obj.__class__.__name__, k2, x) for x in input_data_all[k2]]

return input_data_all


def output_validation(results, obj):
validation_funcs = {
"IMAGE": generate_type_validator((torch.Tensor,), validator=validate_image_shape),
"INT": generate_type_validator((int,)),
"FLOAT": generate_type_validator((float,)),
"STRING": generate_type_validator((str,)),
}
if hasattr(obj, "RETURN_NAMES") and hasattr(obj, "RETURN_TYPES"):
return_indexs = {}
formatted_results = []

for i, return_type in enumerate(obj.RETURN_TYPES):
return_indexs[i] = return_type

for i, result in enumerate(results[0]):
return_type = return_indexs[i]
formatted_results.append(validation_funcs[return_type](obj.__class__.__name__, obj.RETURN_NAMES[i], result) if return_type in validation_funcs.keys() else result)

results = [tuple(formatted_results)]
del formatted_results
return results


def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
Expand All @@ -28,7 +152,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
input_data_all[x] = [input_data]

if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
Expand All @@ -39,9 +163,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
if h[x] == "UNIQUE_ID":
input_data_all[x] = [unique_id]

return input_data_all


def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
input_data_all = input_validation(input_data_all, obj)

# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
Expand Down Expand Up @@ -73,14 +201,16 @@ def slice_dict(d, i):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))

results = output_validation(results, obj)
return results


def get_output_data(obj, input_data_all):

results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)

for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
Expand Down
Loading