WIP: Regional Prompting #5868

Closed
wants to merge 49 commits into from
16e5748
Fix avoid storing extra conditioning info in two places.
RyanJDick Feb 13, 2024
e866e3b
Remove use of **kwargs in do_unet_step(...), where full parameter lis…
RyanJDick Feb 14, 2024
bf72cee
Remove outdated comments related to T2I-Adapters and ControlNets.
RyanJDick Feb 14, 2024
ee3abc1
Merge sequential conditioning and cac conditioning logic to eliminate…
RyanJDick Feb 14, 2024
382fa57
Remove unused constructor declared with typo in name: __int__.
RyanJDick Feb 14, 2024
58277c6
Add a mask to the ConditioningField primitive type.
RyanJDick Feb 13, 2024
f590b39
Add support for a list of ConditioningFields in DenoiseLatents.
RyanJDick Feb 15, 2024
7b0326d
Delete unused functions from shared_invokeai_diffusion.py.
RyanJDick Feb 15, 2024
ef51005
Remove unused code for attention map saving.
RyanJDick Feb 15, 2024
ba47880
Initialize a RegionalPromptAttnProcessor2_0 class by copying AttnProc…
RyanJDick Feb 15, 2024
38248b9
Fix a minor bug in the logic of the IPAttnProcessor2_0. The change wo…
RyanJDick Feb 16, 2024
caa690e
Add concatenation of multiple text conditioning tensors, and patching…
RyanJDick Feb 16, 2024
878bbc3
Add RectangleMaskInvocation.
RyanJDick Feb 16, 2024
2d5d370
Route masks into the RegionalPromptAttnProcessor2_0 processors.
RyanJDick Feb 17, 2024
d132fb4
Get RegionalPromptAttnProcessor2_0 working with a ton of hacks.
RyanJDick Feb 18, 2024
b0fcbe5
Tidy invocation interfaces for RectangleMaskInvocation and AddConditi…
RyanJDick Feb 26, 2024
2966c8d
Handle conditioned and unconditioned text conditioning in the same wa…
RyanJDick Feb 27, 2024
cfba51a
Removed unused function: _prepare_text_embeddings(...)
RyanJDick Feb 28, 2024
54971af
Add symmetric support for regional negative text prompts.
RyanJDick Feb 28, 2024
845c4e9
Update various comments related to regional prompting, and delete dup…
RyanJDick Feb 28, 2024
cad3e5d
Remove dead code related to an old symmetry feature.
RyanJDick Feb 28, 2024
e7ec13f
Remove scheduler_args from ConditioningData structure.
RyanJDick Feb 28, 2024
ee1b315
Split ip_adapter_conditioning out from ConditioningData.
RyanJDick Feb 28, 2024
53ebca5
Rename ConditioningData to TextConditioningData.
RyanJDick Feb 28, 2024
5f49e7a
Move regional prompt concatenation further up the stack. This solves …
RyanJDick Feb 29, 2024
e132afb
Make regional prompting work with sequential conditioning.
RyanJDick Feb 29, 2024
e7f7ae6
Raise a clear error message if prompt-to-prompt cross-attention contr…
RyanJDick Feb 29, 2024
bdf3691
Improve the logic for selecting SDXL pooled embeds when handling mult…
RyanJDick Feb 29, 2024
1bbd4f7
Fixup logic around compatibility of prompt-to-prompt, IP-Adapter, reg…
RyanJDick Feb 29, 2024
f44d3da
Add CustomAttnProcessor2_0 class with simultaneous support for IP-Ada…
RyanJDick Feb 29, 2024
8989a6c
Get multi-prompt attention working simultaneously with IP-adapter.
RyanJDick Feb 29, 2024
4a1acd4
Fix avoid storing extra conditioning info in two places.
RyanJDick Feb 13, 2024
7d96710
Remove use of **kwargs in do_unet_step(...), where full parameter lis…
RyanJDick Feb 14, 2024
d87ff3a
Remove outdated comments related to T2I-Adapters and ControlNets.
RyanJDick Feb 14, 2024
8721926
Merge sequential conditioning and cac conditioning logic to eliminate…
RyanJDick Feb 14, 2024
3e14bd6
Remove unused constructor declared with typo in name: __int__.
RyanJDick Feb 14, 2024
a5c94fb
Delete unused functions from shared_invokeai_diffusion.py.
RyanJDick Feb 15, 2024
5b3adf0
Remove unused code for attention map saving.
RyanJDick Feb 15, 2024
ffc4ebb
Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-co…
RyanJDick Mar 1, 2024
942efa0
Implement (very slow) self-attention regional masking.
RyanJDick Mar 1, 2024
ad18429
Very experimentation with various regional prompting tuning params.
RyanJDick Mar 2, 2024
5fad379
Add ability to control regional prompt region weights.
RyanJDick Mar 3, 2024
271f8f2
Merge branch 'main' into ryan/regional-conditioning-tuning
RyanJDick Mar 4, 2024
d313e5e
Remove AddConditioningMaskInvocaton.
RyanJDick Mar 4, 2024
a665f20
Add positive_self_attn_mask_score and self_attn_adjustment_end_step_p…
RyanJDick Mar 4, 2024
bcfb43e
(minor) Remove commented code.
RyanJDick Mar 5, 2024
41e1a9f
Use the correct device / dtype for RegionalPromptData calculations.
RyanJDick Mar 5, 2024
57266d3
Remove dispatch_progress() function that was added accidentally durin…
RyanJDick Mar 5, 2024
b5c334d
Fix _negative_cross_attn_mask_score.
RyanJDick Mar 5, 2024
49 changes: 43 additions & 6 deletions invokeai/app/invocations/compel.py
@@ -5,7 +5,15 @@
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    FieldDescriptions,
+    Input,
+    InputField,
+    MaskField,
+    OutputField,
+    UIComponent,
+)
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -36,7 +44,7 @@
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.1",
+    version="1.1.0",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -51,6 +59,12 @@ class CompelInvocation(BaseInvocation):
         description=FieldDescriptions.clip,
         input=Input.Connection,
     )
+    mask: Optional[MaskField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
+    positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
+    positive_self_attn_mask_score: float = InputField(default=1.0, description="")
+    self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -118,7 +132,15 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+                mask=self.mask,
+                positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
+                positive_self_attn_mask_score=self.positive_self_attn_mask_score,
+                self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
+            )
+        )


 class SDXLPromptInvocationBase:
@@ -232,7 +254,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.1.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -256,6 +278,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")

+    mask: Optional[MaskField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
+    positive_cross_attn_mask_score: float = InputField(default=0.0, description="")
+    positive_self_attn_mask_score: float = InputField(default=1.0, description="")
+    self_attn_adjustment_end_step_percent: float = InputField(default=0.0, description="")
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
@@ -317,7 +346,15 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+                mask=self.mask,
+                positive_cross_attn_mask_score=self.positive_cross_attn_mask_score,
+                positive_self_attn_mask_score=self.positive_self_attn_mask_score,
+                self_attn_adjustment_end_step_percent=self.self_attn_adjustment_end_step_percent,
+            )
+        )


 @invocation(
@@ -366,7 +403,7 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))


 @invocation_output("clip_skip_output")
40 changes: 40 additions & 0 deletions invokeai/app/invocations/conditioning.py
@@ -0,0 +1,40 @@
+import torch
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    InvocationContext,
+    invocation,
+)
+from invokeai.app.invocations.fields import InputField, WithMetadata
+from invokeai.app.invocations.primitives import MaskField, MaskOutput
+
+
+@invocation(
+    "rectangle_mask",
+    title="Create Rectangle Mask",
+    tags=["conditioning"],
+    category="conditioning",
+    version="1.0.0",
+)
+class RectangleMaskInvocation(BaseInvocation, WithMetadata):
+    """Create a rectangular mask."""
+
+    height: int = InputField(description="The height of the entire mask.")
+    width: int = InputField(description="The width of the entire mask.")
+    y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
+    x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
+    rectangle_height: int = InputField(description="The height of the rectangular masked region.")
+    rectangle_width: int = InputField(description="The width of the rectangular masked region.")
+
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
+        mask[
+            :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
+        ] = True
+
+        mask_name = context.tensors.save(mask)
+        return MaskOutput(
+            mask=MaskField(mask_name=mask_name),
+            width=self.width,
+            height=self.height,
+        )
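The slice assignment in `RectangleMaskInvocation.invoke` treats `y_top` and `x_left` as inclusive and the bottom and right edges as exclusive. A rectangle that runs past the mask bounds is clipped rather than rejected, since that is how slice indexing behaves. Below is a minimal pure-Python sketch of that behaviour, with nested lists standing in for the `torch.bool` tensor. `make_rectangle_mask` is a hypothetical helper for illustration and is not part of this PR.

```python
from typing import List


def make_rectangle_mask(
    height: int, width: int, y_top: int, x_left: int, rectangle_height: int, rectangle_width: int
) -> List[List[bool]]:
    """Return a height x width grid that is True inside the rectangle and False elsewhere."""
    mask = [[False] * width for _ in range(height)]
    # Slicing excludes the end index, so the bottom and right edges are exclusive.
    for row in mask[y_top : y_top + rectangle_height]:
        covered = len(row[x_left : x_left + rectangle_width])  # shrinks if the rectangle overruns
        row[x_left : x_left + rectangle_width] = [True] * covered
    return mask


mask = make_rectangle_mask(height=4, width=6, y_top=1, x_left=2, rectangle_height=2, rectangle_width=3)
for row in mask:
    print("".join("#" if cell else "." for cell in row))
```

Because overruns are clipped silently, a node like this may want explicit bounds validation if an out-of-range rectangle should be treated as a user error. That is a design choice the diff leaves open.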
19 changes: 18 additions & 1 deletion invokeai/app/invocations/fields.py
@@ -194,6 +194,12 @@ class BoardField(BaseModel):
     board_id: str = Field(description="The id of the board")


+class MaskField(BaseModel):
+    """A mask primitive field."""
+
+    mask_name: str = Field(description="The name of the mask.")
+
+
 class DenoiseMaskField(BaseModel):
     """An inpaint mask field"""

@@ -225,7 +231,18 @@ class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

     conditioning_name: str = Field(description="The name of conditioning tensor")
-    # endregion
+    mask: Optional[MaskField] = Field(
+        default=None,
+        description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
+        "included regions should be set to True.",
+    )
+    positive_cross_attn_mask_score: float = Field(
+        default=0.0,
+        # TODO(ryand): Add more details to this description
+        description="The weight of this conditioning tensor's mask relative to overlapping masks.",
+    )
+    positive_self_attn_mask_score: float = Field(default=1.0, description="")
+    self_attn_adjustment_end_step_percent: float = Field(default=0.0, description="")


 class MetadataField(RootModel):
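Every field added to `ConditioningField` has a default, so a call site that passes only `conditioning_name` (as the last `invoke` changed in `compel.py` does) still builds a valid field with no mask attached. The sketch below illustrates this, with stdlib dataclasses standing in for the pydantic models. `is_regional` is a hypothetical helper, and reading `mask=None` as "applies to the whole image" is an assumption based on the field description, not something the diff states.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class MaskField:
    """Stand-in for the pydantic MaskField: a reference to a saved mask tensor."""

    mask_name: str


@dataclass(frozen=True)
class ConditioningField:
    """Stand-in for the pydantic ConditioningField, using the defaults shown in the diff."""

    conditioning_name: str
    mask: Optional[MaskField] = None
    positive_cross_attn_mask_score: float = 0.0
    positive_self_attn_mask_score: float = 1.0
    self_attn_adjustment_end_step_percent: float = 0.0


def is_regional(cond: ConditioningField) -> bool:
    """Hypothetical helper: treat a conditioning as regional only when a mask is attached."""
    return cond.mask is not None


# Existing-style call site: only the name is supplied, everything else falls back to defaults.
global_cond = ConditioningField(conditioning_name="cond_a")
# Regional call site: the prompt node forwards its mask.
regional_cond = ConditioningField(conditioning_name="cond_b", mask=MaskField(mask_name="mask_b"))

print(is_regional(global_cond), is_regional(regional_cond))
```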