implemented generic type validation for all nodes. #4149

Open
wants to merge 6 commits into base: master

Conversation

@shawnington (Contributor) commented Jul 31, 2024

This change started as an attempt to transparently implement BHW3 compliance for IMAGE type inputs and outputs for all nodes, and has since expanded into a general way to check type conformance for all types, with the goal of enabling a much more robust typing system. No changes need to be made to existing nodes; the conformance checking is done before data is passed into a node and after it leaves the node.

In addition, validation functions can be attached to type checks to perform data transformations, allowing a custom type to parse multiple datatypes into a format that conforms with that type's definition.

The current implementation uses a CLI arg to set the desired level of type conformance and warning (a sketch of how such a flag might be registered follows the list below).

--type-conformance=["all", "image", "error", "none"]

  • "error": will raise an error at the first node with a type error, with the name of the node, and the input/output variable name, expected type, and type received.
  • "image": will only convert BHW and HW mono channel images in to IMAGE conforming BHW3 images that will work with the base compositing nodes. This will not force a BHW4 image into BHW3
  • "all": will force conversion to the type expected, using either defaults such as int|float|string, or custom formatting functions supplied to the validation generator function.
  • "none": this is the default value, and will simply soft error the same information as the "error" flag to console. The name of the node, the name of the input/output, the expected type, and the received type.

This is what a raised error looks like when wrong types are passed; by default the same message is printed to the console without interrupting workflow execution.

(Screenshots: example error messages showing the node name, the input/output name, the expected type, and the received type.)

The "image" option currently serves as en example of how custom validation can be done for any type by injecting a custom validation and error handling function.

These changes were discussed on the #backend-development Discord with @comfyanonymous @mcmonkey4eva @guill and several others, as well as in this PR.

More rigorous testing needs to be done to ensure that there are no egregious failure modes and that all formats are covered.

Changes being looked at include the creation of a TypeValidator class to replace the dict that controls the generate_type_validator function, which would open up the possibility for node packs to define their own types and validation functions that apply globally to any node using that type.

Replacing this:

    validation_funcs = { 
        "IMAGE": generate_type_validator((torch.Tensor,), validator=validate_image_shape),
        "INT": generate_type_validator((int,)), 
        "FLOAT": generate_type_validator((float,)), 
        "STRING": generate_type_validator((str,)), 
    }

With something like this:

class TypeValidator:
    validation_funcs = {
        "IMAGE": generate_type_validator((torch.Tensor,)),
        "INT": generate_type_validator((int,)),
        "FLOAT": generate_type_validator((float,)),
        "STRING": generate_type_validator((str,)),
    }

    @classmethod
    def add_type(cls, type_name, valid_types, validator=None):
        cls.validation_funcs[type_name] = generate_type_validator(valid_types, validator=validator)
    
    def has_type(self, type_name):
        return type_name in self.validation_funcs

    def __getitem__(self, type_name):
        return self.validation_funcs[type_name]

    def __call__(self, type_name, value):
        return self.validation_funcs[type_name](value) if self.has_type(type_name) else value

and called like:

TypeValidator.add_type("IMAGE", (torch.Tensor,), validator=validate_image_shape)

validation_funcs = TypeValidator()

value = validation_funcs("IMAGE", value)

# or the more verbose way
value = validation_funcs["IMAGE"](value) if validation_funcs.has_type("IMAGE") else value
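For reference, here is a rough sketch of what a generate_type_validator factory along these lines could look like; the soft-error versus raise behaviour and the TYPE_CONFORMANCE stand-in are assumptions for illustration, not the PR's exact implementation.

import logging

TYPE_CONFORMANCE = "none"  # stand-in for the value parsed from --type-conformance

def generate_type_validator(valid_types, validator=None):
    """Return a closure that optionally coerces a value via a custom `validator`
    and then checks it against `valid_types` (sketch only)."""
    def check(value):
        if validator is not None:
            value = validator(value)  # e.g. validate_image_shape forcing BHW3
        if not isinstance(value, valid_types):
            msg = f"expected {valid_types}, got {type(value).__name__}"
            if TYPE_CONFORMANCE == "error":
                raise TypeError(msg)
            logging.warning(msg)  # "none": soft error to console, value passes through
        return value
    return check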

This change transparently implements BHW3 compliance for IMAGE type inputs and outputs for all nodes. No changes need to be made to existing nodes; the conformance is applied before the IMAGE is passed into the node and after it is passed out of the node.

This change was discussed on the #backend_developers Discord with @comfyanonymous and @mcmonkey4eva.

More rigorous testing needs to be done to ensure that there are no egregious failure modes and that all formats are covered. 

I'm sure the `force_bhw3` function can also be made better, but this is currently a working implementation.
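For context, here is a minimal sketch of what a force_bhw3-style conversion could look like for channels-last tensors; it is illustrative only, uses purely dimension-based checks, and sidesteps the ambiguous cases discussed in the review comments below.

import torch

def force_bhw3(image: torch.Tensor) -> torch.Tensor:
    """Promote HW, BHW, or BHW1 tensors to BHW3; leave BHW3/BHW4 untouched (sketch)."""
    if image.dim() == 2:                           # HW  -> 1HW
        image = image.unsqueeze(0)
    if image.dim() == 3:                           # BHW -> BHW1
        image = image.unsqueeze(-1)
    if image.dim() == 4 and image.shape[-1] == 1:  # BHW1 -> BHW3
        image = image.repeat(1, 1, 1, 3)
    return image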
execution.py Outdated
#add batch dimension
image = image.unsqueeze(0)

if image.shape[1] == 3:
A contributor commented:

What if somebody just happens to have a 3x3 image somewhere? This will corrupt it.

I imagine if check [2] and [3] are both >4 then the conversion is confident. But then the conversion only happens on larger images, and quietly doesn't apply to smaller images.
... This type of thing is a good example of why it might be better to warn than to try to autocorrect - I don't think there's actually a 100% reliable detection of format in all cases, just "99% of the time it's right" heuristic checks like this one.

A contributor commented:

(Also mentioned in Discord, but copying here for documentation)
+1 for just warning (and eventually erroring) in any of the slightly ambiguous cases. If we're going to auto-fix anything, it should just be the BHW case (which is easy to differentiate since the tensor only has 3 dimensions). That's probably the most common error anyway.
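To illustrate the narrower auto-fix being suggested here (a sketch, not code from the PR), only the 3-dimensional case would be converted and everything else would be left alone and warned about:

import torch

def maybe_fix_image(image: torch.Tensor) -> torch.Tensor:
    # Unambiguous: a 3-dimensional IMAGE tensor can only be BHW, so promote it to BHW3.
    if image.dim() == 3:
        return image.unsqueeze(-1).repeat(1, 1, 1, 3)
    # Anything already 4-dimensional is potentially ambiguous; leave it untouched and warn.
    return image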

execution.py Outdated
Comment on lines 69 to 75
# Ensure image inputs are in BHW3 format
input_types = obj.INPUT_TYPES()
for _, v in input_types.items():
    if isinstance(v, dict):
        for k2, v2 in v.items():
            if v2[0] == "IMAGE":
                input_data_all[k2] = [force_bhw3(x) for x in input_data_all[k2]]
@guill (Contributor) commented Jul 31, 2024

I think we should only validate things coming out of nodes. map_node_over_list is used for some cases that aren't actual node execution and there may be awkward side-effects. Additionally, if we're going to throw a warning/error, that warning/error would be attributed to a node that isn't the real culprit.

As long as we're validating all outputs, the only additional thing this protects people from is a node with output using the silly __ne__ trick. In that case, we can't even be sure that it was intended to be an image -- it might have been an audio clip.

In either case, validating the input in this way still wouldn't catch undeclared inputs (i.e. any node taking a variable number of inputs).

@shawnington (Contributor, Author) commented Jul 31, 2024

These are all evaluated on execution, so you might be right, or it might work: if the variable number of inputs is codified at execution time, it shouldn't be an issue. All the information on what is being put in is gleaned from the states run through execution.py, and I made sure not to touch the get_input_data function for this reason. It needs rigorous testing with all the different kinds of node presentations it can run up against. However, since it mainly parses data once it's already in a state where it's linked to other node outputs/inputs, I think it will work, unless there is something I am not understanding about how these variable-input nodes work, which is possible.

Also, the issue with evaluating after map_node_over_list is that all the type info is stripped from the output after that, so another variable would need to be added to the return that includes the type information and the index position of that type, or the type information would need to be included in the output, and subsequent code would need a rewrite to handle that information being present. Brighter minds than me can probably figure out a way, but this was the most logical way I could think of after tearing apart the output at various stages.

Also, side note: if someone manages to pass audio through the IMAGE pipe with this in place, I'd not only be curious, I'd wonder why they chose the image pipe. If they do, we can just call it an undocumented feature, lol.

A contributor commented:

These are all evaluated on execution, so you might be right, or it might work: if the variable number of inputs is codified at execution time, it shouldn't be an issue.

To be clear, this is definitely evaluated on execution, but this function is also called at other times, like when calling IS_CHANGED. In that context, it's totally legitimate to have an IMAGE with a value of None (in fact, that's the expected value, since only constant inputs are available when IS_CHANGED is called before the graph has begun execution).

I would put the validation right after the call to get_output_data in recursive_execute. At that point, you have information about the node that was executed like its output types.
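A rough sketch of a helper that could be called at that point; the names (get_output_data, RETURN_TYPES, and the list-of-lists output convention) follow my reading of execution.py at the time, and validation_funcs is the TypeValidator-style callable sketched in the PR description.

def validate_outputs(obj, output_data, validation_funcs):
    """Validate a node's freshly produced outputs against its declared RETURN_TYPES.
    Meant to be called right after get_output_data() inside recursive_execute (sketch)."""
    for i, type_name in enumerate(getattr(obj, "RETURN_TYPES", ())):
        # Each output slot is a list of values (ComfyUI's list-output convention),
        # so every entry in the slot is validated, not just the first.
        output_data[i] = [validation_funcs(type_name, value) for value in output_data[i]]
    return output_data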

execution.py Outdated
Comment on lines 20 to 22
while isinstance(image, list):
    was_list = True
    image = image[0]
A contributor commented:

Isn't this only going to fix the first image in the list? If we get a list of images, we should be fixing all the images in the list (which may all have different dimensions). (I would probably make this change outside of the call to force_bhw3 so that it applies to any other type validation we add in the future.)

@shawnington (Contributor, Author) commented Jul 31, 2024

Sometimes the image tensor was wrapped in an extraneous list. It did not seem to be related to batch size in any way; it's just a difference in formatting between the way it comes in as input and the way it's structured as output. I have no idea why it's like that, but the data does not change between them.

Granted, it's a very ugly hack; I could have gone for an if instead of a while. I just really wanted to get rid of the list wrapper, and I add it back at the end for formatting reasons if it was removed.

I could be wrong. That whole function is likely to evolve considerably as I start to expose it to a wide variety of edge cases, such as 3x3x3x3, and also take into consideration the suggestions and the further Discord discussion we have had.

A contributor commented:

It's not frequently used, but ComfyUI does have functionality for "list" outputs/inputs that are different than batches. Specifically, if a node returns a batch of 5 images, those 5 images will all be passed to the following node for one execution. If a node returns a list of 5 images (each with a batch size of 1), the following node will actually have its execution function called 5 times. In order to support this, outputs are usually passed around wrapped in a list. I believe that's what you were seeing.

To continue to support that functionality, it's important that we process each entry in the list the same way. Someone in Discord/Matrix might have suggestions for real nodes that make use of that functionality so you can test it. I think some of the nodes used for making X/Y plots use it.

@huchenlei (Collaborator) commented:

Are we offering an image-with-alpha type to replace the BHW4 image usage now?

class JoinImageWithAlpha:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "alpha": ("MASK",),
            }
        }

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "join_image_with_alpha"

    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        batch_size = min(len(image), len(alpha))
        out_images = []
        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
        for i in range(batch_size):
            out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
        result = (torch.stack(out_images),)
        return result

added cli args --type-conformance with the following options:

`all` - attempts to convert all values to the input/output types specified by the node

`images` - only converts images to BHW3 format

`error` - raises an error if the type does not match what is specified

`none` - prints error messages to console and does nothing else

A validation generator function has been added, along with the ability to specify custom functions for validating individual types.

cli args have been added in the form of 

--type-conformance='arg'

with the following args: 
all: attempts to convert any mistyped inputs/outputs to the type specified by the node

images: only enforces BHW3 on images, doing the minimum verifiably accurate transforms, such as MASK HW and BHW to BHW3

error: raises an error with information about the type mismatch at the node where it occurs

none: the default value; it outputs errors to console and does nothing else

This should address most of the areas of discussion.
@shawnington (Contributor, Author) commented:

Are we offering an image-with-alpha type to replace the BHW4 image usage now?


As it is currently written, it shouldn't touch images in BHW4 format; it will only convert a BHW mask type to BHW3, per @comfyanonymous's preference.

I personally think BHW4, or another image format specifically for alpha inclusion, would be a good use for this kind of validation, since new functions can be added for any type with the latest changes.
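Using the TypeValidator.add_type pattern sketched in the PR description, a node pack could hypothetically register such an alpha-carrying type; the type name and validator below are made up for illustration.

import torch

def validate_rgba_image(image: torch.Tensor) -> torch.Tensor:
    # Hypothetical: promote BHW3 to BHW4 by appending an opaque alpha channel;
    # pass BHW4 through untouched.
    if image.dim() == 4 and image.shape[-1] == 3:
        alpha = torch.ones_like(image[..., :1])
        image = torch.cat([image, alpha], dim=-1)
    return image

TypeValidator.add_type("RGBA_IMAGE", (torch.Tensor,), validator=validate_rgba_image)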

@shawnington changed the title from "implemented BHW3 compliance for IMAGE types for all nodes." to "implemented generic type validation for all nodes." on Aug 2, 2024
@melMass (Contributor) commented Aug 2, 2024

BHW mask type to BHW3

What is the reasoning? I don't get it; why not BHW1? Currently, it's BHW.

I personally think BHW4, or another image format specifically for alpha inclusion, would be a good use for this kind of validation, since new functions can be added for any type with the latest changes.

The issue with that is that it limits integrations a bit, since you would now need an optional BHW4 and an optional BHW3, or separate nodes, to handle IMAGE.

@shawnington (Contributor, Author) commented Aug 2, 2024

BHW mask type to BHW3

What is the reasoning? I don't get it; why not BHW1? Currently, it's BHW.

I personally think BHW4, or another image format specifically for alpha inclusion, would be a good use for this kind of validation, since new functions can be added for any type with the latest changes.

The issue with that is that it limits integrations a bit, since you would now need an optional BHW4 and an optional BHW3, or separate nodes, to handle IMAGE.

BHW will display properly in the preview node, but will not composite properly without conversion to BHW3, and since the only real use for converting a mask to an image is to do composite operations, having it in the BHW3 format that the IMAGE type is supposed to be in makes it work with all of the nodes that do compositing operations.

The issue is basically that the input and output types currently only limit which nodes can connect to each other; they have no bearing on what type is actually being passed around, or whether it's in a standardized format.

The reasoning is that, at least for base nodes, things should move around in a standard format that is expressed by the type in the node definition, so that custom nodes at least know what they will be getting. If an output is specified as an INT, you shouldn't get a FLOAT, a STRING, or a tensor, but all are currently possible.

This started off as just dealing with mask images not compositing like you would reasonably expect, and has expanded into a broad type conformance system.
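As a small illustration of that point, with --type-conformance=all the scalar types could be coerced by a hypothetical default (here simply the target constructor), using the generate_type_validator factory from the description; the exact defaults in the PR may differ.

# Hypothetical default coercers for scalar types under --type-conformance=all.
int_validator = generate_type_validator((int,), validator=int)
float_validator = generate_type_validator((float,), validator=float)

print(int_validator("42"))   # 42  instead of the string "42"
print(float_validator(3))    # 3.0 instead of the int 3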

@melMass (Contributor) commented Aug 3, 2024

BHW will display properly in the preview node, but will not composite properly without conversion to BHW3, and since the only real use for converting a mask to an image is to do composite operations

I can be slow on these kinds of things and I'm not a native English speaker, but I still don't understand the need to duplicate channels (BHW1 vs BHW3). I see now that it's under a flag so I will try this PR with my nodes to grasp it better

@shawnington (Contributor, Author) commented Aug 3, 2024

BHW will display properly in the preview node, but will not composite properly without conversion to BHW3, and since the only real use for converting a mask to an image is to do composite operations

I can be slow on these kinds of things and I'm not a native English speaker, but I still don't understand the need to duplicate channels (BHW1 vs BHW3). I see now that it's under a flag so I will try this PR with my nodes to grasp it better

So this all started because in comfy_extras/nodes_mask.py the composite function requires both images to be 3 channel

def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    print(f"destination: {destination.shape} source: {source.shape}")
    destination = tensor_to_rgb(destination)
    source = tensor_to_rgb(source)
    
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])

    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    left, top = (x // multiplier, y // multiplier)
    right, bottom = (left + source.shape[3], top + source.shape[2],)

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device, copy=True)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

    # calculate the bounds of the source that will be overlapping the destination
    # this prevents the source trying to overwrite latent pixels that are out of bounds
    # of the destination
    visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)

    mask = mask[:, :, :visible_height, :visible_width]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, :visible_height, :visible_width]
    destination_portion = inverse_mask  * destination[:, :, top:bottom, left:right]

    destination[:, :, top:bottom, left:right] = source_portion + destination_portion
    return destination

This fails because it tries to access destination.shape[3] after doing a movedim(-1, 1) and iterating through the channels. This results in an index-out-of-range error, because it cannot access channel 3 in a 1-channel image.

If you take a mask and convert it to an image with PIL (or even just output it from a node that has RETURN_TYPES = ("IMAGE",)), you will get a tensor of shape (H, W) back after you convert it back to a tensor.
This will display perfectly fine in the PreviewImage node when coming from a node with "IMAGE" as the output type.

After discussion on Discord, we agreed that if something outputs as an "IMAGE" and previews properly, it should behave properly when put into another node with an "IMAGE" input.

However, since the composite function used by multiple default nodes requires both images to be 3-channel to work, it was decided to force HW into BHW3 so that it works with the default nodes that handle images.

In the process of implementing this, I realized that there is no actual type system in place to make sure that 'IMAGE' even outputs a torch.Tensor. It can output absolutely anything it wants, so this PR has gotten a bit broader: it now checks type validation, allows the injection of custom validation functions, and allows the creation of custom types that conform to the expected type structure, so that nodes working with a type get what they expect as input and output the type that is expected.

I edited the initial comment on the PR to reflect the current state of it.

@Amorano commented Aug 3, 2024

So this all started because in comfy_extras/nodes_mask.py the composite function requires both images to be 3 channel

which is why my argument is still: let the node itself convert. You have a compound type that can be 1, 3, or 4 channels, i.e. three "types". There should only exist a CONVERSION function for API users to "convert this 'image input' to the target I want".

Node authors are the ones with the burden of types.

How else is there going to be support for ANYTYPE? Is the core going to ignore it and continue with the current hack, which means I literally have to type-cast it myself?

I don't feel that is a good design; at some point, in order to support an "anytype", you would still need a CONVERSION function to be called to cast the input type into the appropriate thing. Why make everyone do this over and over?

All the work done in my parse_dynamic, parse_value, and parse_parameter functions does the casting for the core types, mine, and other types from authors like Kijai.

As always, don't take the sh!tty code I made, take the concept. The concept covers all current and future use cases and doesn't blow things up every time someone files a bug about whether a type should exist or be supported.

@melMass (Contributor) commented Aug 4, 2024

So this all started because in comfy_extras/nodes_mask.py

I see, I didn't get that it was about the "internal" methods. I don't use most of them, but maybe I will after this change.
I'm going to try the PR now.

@mcmonkey4eva added and then removed the "User Support" label (a user needs help with something, probably not a bug) on Sep 12, 2024
Labels: None yet
Projects: None yet
6 participants