[WIP] The Modular Diffusers #9672

@yiyixuxu (Collaborator) commented on Oct 14, 2024

Modular Diffusers

This PR experiments with some initial designs for a modular pipeline building system that we plan to support in diffusers officially.

Key components

CustomPipelineBuilder

CustomPipelineBuilder is the main user interface for creating and running custom pipelines. Example usage:

from diffusers import CustomPipelineBuilder

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(encode_prompt_block, prepare_latent_block, ...)
builder.run_pipeline("a cat", num_inference_steps=15)

PipelineBlock

PipelineBlock is the building block for custom pipelines. Each user-defined block has to inherit from PipelineBlock and should:

  • define its components, inputs, and outputs
  • define a __call__ method that performs a specific part of the pipeline's operation; it should always take and return the same two variables: pipeline (CustomPipeline) and state (PipelineState)
  • be easily composable and rearrangeable with other blocks to create custom pipeline flows

Define a PipelineBlock

You will need to implement new features using PipelineBlock. Here is an example of how to write a PipelineBlock that encodes the text prompts:

# imports needed for this example (the PipelineBlock/PipelineState import path follows the module layout used elsewhere in this PR)
from typing import Any, List, Optional, Tuple

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers.pipelines.custom_pipeline_builder import PipelineBlock, PipelineState


class TextEncoderStep(PipelineBlock):

    # [1] define components:
    #   specify all the model components used in this block.
    #   There are two attributes you need to define: `optional_components` and `required_components`.
    #   If the block can run without a component, you should list it as optional (this logic is aligned with the
    #   _optional_components attribute in our regular pipelines).
    #   In our example, since this block can optionally take already encoded text prompts as inputs,
    #   all the text encoders and tokenizers are optional.
    optional_components = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]

    # [2] define inputs and their default values:
    #   inputs are equivalent to the `__call__` parameters in our regular pipelines; only include the inputs needed in this block
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            ("prompt", None),
            ("prompt_2", None),
            ("negative_prompt", None),
            ("negative_prompt_2", None),
            ("cross_attention_kwargs", None),
            ("prompt_embeds", None),
            ("negative_prompt_embeds", None),
            ("pooled_prompt_embeds", None),
            ("negative_pooled_prompt_embeds", None),
            ("num_images_per_prompt", 1),
            ("guidance_scale", 5.0),
            ("clip_skip", None),
        ]

    # [3] define intermediate inputs/outputs: these are outputs produced for, or inputs consumed from, other blocks
    # [TO-DO]: see if I can remove the need to specify intermediates and only use `inputs` and `outputs`
    @property
    def intermediates_outputs(self) -> List[str]:
        return [
            "prompt_embeds",
            "negative_prompt_embeds",
            "pooled_prompt_embeds",
            "negative_pooled_prompt_embeds",
        ]

    def __init__(
        self,
        text_encoder: Optional[CLIPTextModel] = None,
        text_encoder_2: Optional[CLIPTextModelWithProjection] = None,
        tokenizer: Optional[CLIPTokenizer] = None,
        tokenizer_2: Optional[CLIPTokenizer] = None,
        force_zeros_for_empty_prompt: bool = True,
    ):
        super().__init__(
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
        )

    @torch.no_grad()
    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
        # [4] implement the __call__ method!
        # It should be very similar to how you would write the __call__ method for our regular pipelines, with a few differences:
        #   (1) you need to explicitly get inputs from the state, and then add the outputs back to the state
        #   (2) all the model components are stored on the pipeline object, e.g., you need to use `pipeline.text_encoder` to get your text_encoder
        #   (3) all the common pipeline methods defined in our official pipelines are accessible on the `pipeline` object, e.g., you should use `pipeline.encode_prompt` instead of `self.encode_prompt`
        #   (4) you can define custom methods on the block too, e.g., we defined a `check_inputs` method to validate the text encoder inputs, and we can call it as `self.check_inputs(...)`
        prompt = state.get_input("prompt")
        prompt_2 = state.get_input("prompt_2")
        negative_prompt = state.get_input("negative_prompt")
        negative_prompt_2 = state.get_input("negative_prompt_2")
        cross_attention_kwargs = state.get_input("cross_attention_kwargs")
        prompt_embeds = state.get_input("prompt_embeds")
        negative_prompt_embeds = state.get_input("negative_prompt_embeds")
        pooled_prompt_embeds = state.get_input("pooled_prompt_embeds")
        negative_pooled_prompt_embeds = state.get_input("negative_pooled_prompt_embeds")
        num_images_per_prompt = state.get_input("num_images_per_prompt")
        guidance_scale = state.get_input("guidance_scale")
        clip_skip = state.get_input("clip_skip")

        do_classifier_free_guidance = guidance_scale > 1.0
        device = pipeline._execution_device

        self.check_inputs(...)

        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(...)
       
        state.add_intermediate("prompt_embeds", prompt_embeds)
        state.add_intermediate("negative_prompt_embeds", negative_prompt_embeds)
        state.add_intermediate("pooled_prompt_embeds", pooled_prompt_embeds)
        state.add_intermediate("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds)
        return pipeline, state

We will also release pre-defined pipeline blocks so that you can build all the official pipelines that we maintain with the pipeline builder system. E.g., for SDXL, in this PR we have a set of pipeline blocks such as InputStep, TextEncoderStep, SetTimestepsStep, PrepareLatentsStep, PrepareAdditionalConditioningStep, PrepareGuidance, DenoiseStep, and DecodeLatentsStep.

PipelineState

A new PipelineState is created when you run builder.run_pipeline(). It is then passed through each block and tracks the state of the pipeline throughout its execution, including inputs, outputs, and all the intermediate inputs/outputs.

CustomPipeline

CustomPipeline is the base class for all custom pipelines built using CustomPipelineBuilder. It is used only as a container for pipeline components, pipeline-level config/attributes, and common pipeline methods.

Note that unlike DiffusionPipeline, CustomPipeline does not handle loading and saving; it also does not implement a __call__ method and should only be run through the CustomPipelineBuilder.run_pipeline method.

At run time, the builder will pass the CustomPipeline object to each block so these pipeline-level methods and components can be used within these blocks.

The main motivations to have this class are:

  1. different pipeline blocks may use the same components, so it is better to host them at the pipeline level
  2. to be able to very easily re-use the methods that are currently implemented on our regular pipelines, e.g. we can use a # Copied from statement to copy all the methods to the SDXLCustomPipeline without re-implementing them for blocks (see the sketch after this list)
  3. this class can be combined with the builder class into one, i.e. it can just be a ModularDiffusionPipeline that can build itself - we can iterate on this design later

Overall Objectives

a user-friendly API to compose a pipeline

  • It will be released as an experimental feature and iterated with the community!

performance

  • all the optimization methods that currently work on our regular pipelines should work on the custom pipeline (offloading, device map, etc.), without any performance difference whatsoever
  • we will add a comprehensive test suite for all the pipeline blocks that we support and make sure the results are the same as our regular pipelines

community

  • we should support loading custom pipeline blocks from the hub
  • serialization - allow saving, loading, and sharing of the workflows (in the form of custom pipelines in our case!)

Testing this PR

To build a pre-defined pipeline block

We will implement a from_pretrained() method on PipelineBlock that allows you to load a pipeline block from a hub repo, similar to how you would load a DiffusionPipeline. For now, we need to load a DiffusionPipeline first and reuse its components and configuration to initialize a pipeline block.

import torch
from diffusers import StableDiffusionXLPipeline
device = torch.device("cuda")
# step1. create SDXL pipeline so we can reuse its components for the custom pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to(device)

Once we have that, there are two ways to create a PipelineBlock. We will use TextEncoderStep as an example.

Method 1: create it by passing all the __init__ arguments

from diffusers.pipelines.custom_pipeline_builder import TextEncoderStep
encode_prompt = TextEncoderStep(
    text_encoder=pipe.text_encoder, 
    text_encoder_2=pipe.text_encoder_2, 
    tokenizer=pipe.tokenizer, 
    tokenizer_2=pipe.tokenizer_2
)

Method 2: use the from_pipe API (recommended)

from diffusers.pipelines.custom_pipeline_builder import TextEncoderStep
encode_prompt = TextEncoderStep.from_pipe(pipe)

You can print the PipelineBlock object to get information about its components and configuration, as well as its inputs/outputs:

TextEncoderStep(
  components: text_encoder=CLIPTextModel, text_encoder_2=CLIPTextModelWithProjection, tokenizer=CLIPTokenizer, tokenizer_2=CLIPTokenizer
  auxiliaries: 
  configs: force_zeros_for_empty_prompt=True
  inputs: prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, num_images_per_prompt=1, guidance_scale=5.0, clip_skip=None
  intermediates_inputs: 
  intermediates_outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
)

To run a pipeline block

builder.add_blocks()
builder.run_blocks()

Our pipeline blocks are designed to be composed with other pipeline blocks, so we designed a CustomPipelineBuilder class that is responsible for composing the blocks together and running them in the correct order with the correct inputs. You can use the builder to run a pipeline block in a standalone manner as well.

encode_prompt example

Let's first take a look at how to run the encode_prompt block on its own, because it is pretty common to generate prompt embeddings as a separate step.

First, create the builder and add the block

from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(encode_prompt)
print(builder)

print(builder) gives you information about what the builder has built so far, i.e., the pipeline blocks, each block's outputs, and their components; it also puts together a list of "call parameters" so you know which arguments you need to pass to run the blocks.

CustomPipeline Configuration:
==============================

Pipeline Blocks:
----------------
1. TextEncoderStep
   -> prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


Registered Components:
----------------------
text_encoder: CLIPTextModel
text_encoder_2: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer
tokenizer_2: CLIPTokenizer

Default Call Parameters:
------------------------
prompt: None
prompt_2: None
negative_prompt: None
negative_prompt_2: None
cross_attention_kwargs: None
prompt_embeds: None
negative_prompt_embeds: None
pooled_prompt_embeds: None
negative_pooled_prompt_embeds: None
num_images_per_prompt: 1
guidance_scale: 5.0
clip_skip: None

Required Call Parameters:
--------------------------

Note: These are the default values. Actual values may be different when running the pipeline.

Let's define a prompt and use the run_blocks method to run the block!

prompt = "a photo of an astronaut riding a horse on mars"
state = builder.run_blocks(prompt=prompt)
print(state)

The run_blocks method always returns the entire pipeline state. You can get specific tensor outputs from the state, e.g., we can use state.get_intermediate("prompt_embeds") to get prompt_embeds.

PipelineState(
  inputs={
    prompt: 'a photo of an astronaut riding a horse on mars'
    prompt_2: None
    negative_prompt: None
    negative_prompt_2: None
    cross_attention_kwargs: None
    prompt_embeds: None
    negative_prompt_embeds: None
    pooled_prompt_embeds: None
    negative_pooled_prompt_embeds: None
    num_images_per_prompt: 1
    guidance_scale: 5.0
    clip_skip: None
  },
  intermediates={
    prompt_embeds: Tensor(
      dtype=torch.float16, shape=torch.Size([1, 77, 2048])
      tensor([[[-3.8926, -2.5137,  4.7148,  ...,  0.1898,  0.4189, -0.2971],
         [ 0.0889, -0.6201, -0.4875,  ...,  0.5005, -0.0376, -0.1573],
         [ 0.7329,  0.4199, -0.1284,  ...,  0.5713,  0.7275,  0.2302],
         ...,
         [-0.6655,  0.7178, -0.4092,  ...,  0.1685,  0.5654,  0.1741],
         [-0.6685,  0.7197, -0.4263,  ...,  0.0764,  0.4136,  0.0657],
         [-0.6382,  0.7500, -0.4238,  ..., -0.0156,  0.4626,  0.0447]]],
       device='cuda:0', dtype=torch.float16))
    negative_prompt_embeds: Tensor(
      dtype=torch.float16, shape=torch.Size([1, 77, 2048])
      tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float16))
    pooled_prompt_embeds: Tensor(
      dtype=torch.float16, shape=torch.Size([1, 1280])
      tensor([[ 1.2764,  0.5522,  0.4302,  ..., -0.8579, -0.0495,  0.0177]],
       device='cuda:0', dtype=torch.float16))
    negative_pooled_prompt_embeds: Tensor(
      dtype=torch.float16, shape=torch.Size([1, 1280])
      tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16))
  },
  outputs={

  }
)

You can get the default call parameters in a dict with builder.default_call_parameters, edit it, and pass it directly to run_blocks.

call_params = builder.default_call_parameters
call_params['prompt'] = " a cat"
call_params["guidance_scale"] = 7.0
state = builder.run_blocks(**call_params)

decode_latent example

Let's look at another example, decode_latent, where we pass the generated latents to the block to get the image.

from diffusers.pipelines.custom_pipeline_builder import DecodeLatentsStep
decoder_step = DecodeLatentsStep.from_pipe(pipe)
print(f" decoder_step: {decoder_step}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(decoder_step)
print(builder)

Based on the printed-out builder info, we know the DecodeLatentsStep takes a latents input and outputs images; it comes with a vae component and has two optional call arguments, output_type and return_dict (these are our standard pipeline call parameters). It also has a required argument, latents, which we have to pass to run_blocks.

CustomPipeline Configuration:
==============================

Pipeline Blocks:
----------------
1. DecodeLatentsStep
   latents -> images


Registered Components:
----------------------
vae: AutoencoderKL

Default Call Parameters:
------------------------
output_type: 'pil'
return_dict: True

Required Call Parameters:
--------------------------
latents: 

Note: These are the default values. Actual values may be different when running the pipeline.

To decode and get the image, you can run:

state = builder.run_blocks(latents=latents)
state.get_intermediates("images")

Build a Modular Pipeline Incrementally

Using CustomPipelineBuilder, you can build a pipeline block, test it out using the process described in the last section (builder.add_blocks() + builder.run_blocks()), then move on to building the next block and repeat the same process. Note that the run_blocks method also takes a state argument, so you can take the state output from the last step and pass it to the next step. We recommend using PipelineState to manage your inputs between pipeline block runs.

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(encode_prompt)
state = builder.run_blocks(prompt=prompt)

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(set_timesteps)
state = builder.run_blocks(state)
...

Here is an example of how to build the SDXL text2img pipeline incrementally:

example code

from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder, TextEncoderStep, SetTimestepsStep, PrepareLatentsStep, PrepareAdditionalConditioningStep, PrepareGuidance, DenoiseStep, DecodeLatentsStep
import torch
from diffusers import StableDiffusionXLPipeline
device = torch.device("cuda")
# step1. create SDXL pipeline so we can reuse its components for the custom pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to(device)


# (1) encode_prompt
encode_prompt = TextEncoderStep.from_pipe(pipe)
print(f" encode_prompt: {encode_prompt}")
generator = torch.Generator(device="cuda").manual_seed(0)   
prompt = "a photo of an astronaut riding a horse on mars"
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(encode_prompt)
print(builder)

state = builder.run_blocks(prompt=prompt)
print(f"state: {state}")

# set_timesteps
print("* " * 10)
set_timesteps = SetTimestepsStep.from_pipe(pipe)
print(f" set_timesteps: {set_timesteps}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(set_timesteps)
print(builder)

# You can pass the previous state (including the output from the encode_prompt step) to the next block.
# In this example, the `prompt_embeds` are not used in the `set_timesteps` step, so they are just kept in the state
# and passed down to the following blocks, until a future step that needs them uses them.

state = builder.run_blocks(num_inference_steps=28, state=state)
print(f"state: {state}")


# prepare_latents
print("* " * 10)
prepare_latents = PrepareLatentsStep.from_pipe(pipe)
print(f" prepare_latents: {prepare_latents}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(prepare_latents)
print(builder)
state = builder.run_blocks(batch_size=1, dtype=torch.float16, device=device, generator=generator, state=state)
print(f"state: {state}")
# latents = state.get_intermediate("latents")

# prepare_add_cond
print("* " * 10)
from diffusers.pipelines.custom_pipeline_builder import PrepareAdditionalConditioningStep
prepare_add_cond = PrepareAdditionalConditioningStep.from_pipe(pipe)
print(f" prepare_add_cond: {prepare_add_cond}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(prepare_add_cond)
print(builder)
state = builder.run_blocks(state=state)
print(f"state: {state}")

# prepare_guidance
print("* " * 10)
prepare_guidance = PrepareGuidance.from_pipe(pipe)
print(f" prepare_guidance: {prepare_guidance}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(prepare_guidance)
print(builder)
state = builder.run_blocks(state=state)
print(f"state: {state}")

# denoise_step
print("* " * 10)
denoise_step = DenoiseStep.from_pipe(pipe)
print(f" denoise_step: {denoise_step}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(denoise_step)
print(builder)
state = builder.run_blocks(state=state)
latents = state.get_intermediate("latents")
print(f"state: {state}")

# decoder_step
print("* " * 10)
from diffusers.pipelines.custom_pipeline_builder import DecodeLatentsStep
decoder_step = DecodeLatentsStep.from_pipe(pipe)
print(f" decoder_step: {decoder_step}")
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(decoder_step)
print(builder)
state = builder.run_blocks(state=state)
print(f"state: {state}")

images = state.get_output("images").images[0]
images.save("yiyi_test_10_out.png")

You can also build a "partial pipeline" with a subset of blocks and test as you go:

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks([encode_prompt, set_timesteps, prepare_latents])
print(builder)
builder.run_blocks(prompt=prompt, generator=generator)

A complete SDXL text2img pipeline looks like this after it is built:

>>> builder
CustomPipeline Configuration:
==============================

Pipeline Blocks:
----------------
1. InputStep
   -> batch_size

2. TextEncoderStep
   -> prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

3. SetTimestepsStep
   -> timesteps, num_inference_steps

4. PrepareLatentsStep
   batch_size -> latents

5. PrepareAdditionalConditioningStep
   latents, batch_size, pooled_prompt_embeds -> add_time_ids, negative_add_time_ids, timestep_cond

6. PrepareGuidance
   add_time_ids, negative_add_time_ids, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds -> add_text_embeds, add_time_ids, prompt_embeds

7. DenoiseStep
   latents, timesteps, num_inference_steps, add_text_embeds, add_time_ids, timestep_cond, prompt_embeds -> latents

8. DecodeLatentsStep
   latents -> images


Registered Components:
----------------------
text_encoder: CLIPTextModel
text_encoder_2: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer
tokenizer_2: CLIPTokenizer
scheduler: EulerDiscreteScheduler
unet: UNet2DConditionModel
vae: AutoencoderKL

Default Call Parameters:
------------------------
prompt: None
prompt_embeds: None
prompt_2: None
negative_prompt: None
negative_prompt_2: None
cross_attention_kwargs: None
negative_prompt_embeds: None
pooled_prompt_embeds: None
negative_pooled_prompt_embeds: None
num_images_per_prompt: 1
guidance_scale: 5.0
clip_skip: None
num_inference_steps: 50
timesteps: None
sigmas: None
denoising_end: None
height: None
width: None
generator: None
latents: None
device: None
dtype: None
original_size: None
target_size: None
negative_original_size: None
negative_target_size: None
crops_coords_top_left: (0, 0)
negative_crops_coords_top_left: (0, 0)
guidance_rescale: 0.0
eta: 0.0
output_type: 'pil'
return_dict: True

Required Call Parameters:
--------------------------

Note: These are the default values. Actual values may be different when running the pipeline.

builder.run_pipeline()

If you have already built all the pipeline blocks and want to use them to run inference, you can add all the blocks to the builder at once using add_blocks() and use run_pipeline to get the generated image:

builder.add_blocks([.....])
generator = torch.Generator(device="cuda").manual_seed(0)   
out = builder.run_pipeline(prompt="a photo of an astronaut riding a horse on mars", generator=generator)

More Pipeline Examples: text2img, img2img, controlnet, etc.

text2img

Build + run a text-to-image pipeline:


from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder, InputStep, TextEncoderStep, SetTimestepsStep, PrepareLatentsStep, PrepareAdditionalConditioningStep, PrepareGuidance, DenoiseStep, DecodeLatentsStep
import torch
from diffusers import StableDiffusionXLPipeline
device = "cuda"
# step1. create SDXL pipeline so we can reuse its components for the custom pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to(device)
# step2: create a set of pipeline blocks
prepare_input = InputStep.from_pipe(pipe)
encode_prompt = TextEncoderStep.from_pipe(pipe)
set_timesteps = SetTimestepsStep.from_pipe(pipe)
prepare_latents = PrepareLatentsStep.from_pipe(pipe)
prepare_add_cond = PrepareAdditionalConditioningStep.from_pipe(pipe)
prepare_guidance = PrepareGuidance.from_pipe(pipe)
denoise_step = DenoiseStep.from_pipe(pipe)
decoder_step = DecodeLatentsStep.from_pipe(pipe)
# step3: create a builder, add all the pipeline blocks to it and run the built pipeline
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(
    [prepare_input,
     encode_prompt, 
     set_timesteps, 
     prepare_latents, 
     prepare_add_cond, 
     prepare_guidance,
     denoise_step,
     decoder_step
     ]
     )
generator = torch.Generator(device="cuda").manual_seed(0)
out = builder.run_pipeline(prompt="a photo of an astronaut riding a horse on mars", generator=generator)
out.images[0].save("yiyi_text2img_modular.png")
Img2Img

img2img has its own blocks for set_timesteps, prepare_latents, and prepare_add_cond:

  • Image2ImageSetTimestepsStep
  • Image2ImagePrepareLatentsStep
  • Image2ImagePrepareAdditionalConditioningStep

To build an img2img pipeline with our pipeline building system, this is the only code change needed:

- set_timesteps = SetTimestepsStep.from_pipe(pipe)
+ set_timesteps = Image2ImageSetTimestepsStep.from_pipe(pipe)
- prepare_latents = PrepareLatentsStep.from_pipe(pipe )
+ prepare_latents = Image2ImagePrepareLatentsStep.from_pipe(pipe)
- prepare_add_cond = PrepareAdditionalConditioningStep.from_pipe(pipe)
+ prepare_add_cond = Image2ImagePrepareAdditionalConditioningStep.from_pipe(pipe)

To build the pipeline:


from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder, InputStep, TextEncoderStep, Image2ImageSetTimestepsStep, Image2ImagePrepareLatentsStep, Image2ImagePrepareAdditionalConditioningStep, PrepareGuidance, DenoiseStep, DecodeLatentsStep
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
strength = 0.9

device = "cuda"
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)
pipe = pipe.to(device)
prepare_input = InputStep.from_pipe(pipe)
encode_prompt = TextEncoderStep.from_pipe(pipe)
set_timesteps = Image2ImageSetTimestepsStep.from_pipe(pipe)
prepare_latents = Image2ImagePrepareLatentsStep.from_pipe(pipe)
prepare_add_cond = Image2ImagePrepareAdditionalConditioningStep.from_pipe(pipe)
prepare_guidance = PrepareGuidance.from_pipe(pipe)
denoise_step = DenoiseStep.from_pipe(pipe)
decoder_step = DecodeLatentsStep.from_pipe(pipe)

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(
    [prepare_input,
     encode_prompt, 
     set_timesteps, 
     prepare_latents, 
     prepare_add_cond, 
     prepare_guidance,
     denoise_step,
     decoder_step
     ]
     )
print(builder)

To run the modular img2img pipeline:


generator = torch.Generator(device="cuda").manual_seed(0)
out = builder.run_pipeline(prompt="a photo of an astronaut riding a horse on mars", image=init_image, strength=strength, generator=generator)
out.images[0].save("yiyi_img2img_modular.png")
controlnet

To add controlnet to any existing pipeline, you need to replace the DenoiseStep with ControlNetDenoiseStep:

- denoise_step = DenoiseStep.from_pipe(pipe)
+ denoise_step = ControlNetDenoiseStep.from_pipe(pipe, controlnet=controlnet)

To add controlnet to the text-to-image pipeline:


import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder, InputStep, TextEncoderStep, SetTimestepsStep, PrepareLatentsStep, PrepareAdditionalConditioningStep, PrepareGuidance, ControlNetDenoiseStep, DecodeLatentsStep

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
).to("cuda")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

prepare_input = InputStep.from_pipe(pipe)
encode_prompt = TextEncoderStep.from_pipe(pipe)
set_timesteps = SetTimestepsStep.from_pipe(pipe)
prepare_latents = PrepareLatentsStep.from_pipe(pipe)
prepare_add_cond = PrepareAdditionalConditioningStep.from_pipe(pipe)
prepare_guidance = PrepareGuidance.from_pipe(pipe)
denoise_step = ControlNetDenoiseStep.from_pipe(pipe, controlnet=controlnet)
decoder_step = DecodeLatentsStep.from_pipe(pipe)

builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(
    [prepare_input,
     encode_prompt, 
     set_timesteps, 
     prepare_latents, 
     prepare_add_cond, 
     prepare_guidance,
     denoise_step,
     decoder_step
     ]
     )
print(builder)

To run the controlnet text-to-image pipeline:



# prepare inputs
from diffusers.utils import load_image
import numpy as np
import torch
import cv2
from PIL import Image
# prepare canny image
image = load_image(
    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
# prepare prompt
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = "low quality, bad quality, sketches"
# prepare controlnet conditioning scale
controlnet_conditioning_scale = 0.5  # recommended for good generalization

generator = torch.Generator(device="cuda").manual_seed(0)
out = builder.run_pipeline(prompt=prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, control_image=canny_image, generator=generator)
# state = builder.run_blocks(prompt=prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, control_image=canny_image, generator=generator)
out.images[0].save("controlnet_modular_out.png")

To build + run controlnet with img2img:


import torch
from PIL import Image
import numpy as np
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers.utils import load_image
from diffusers.pipelines.custom_pipeline_builder import CustomPipelineBuilder, InputStep, TextEncoderStep, Image2ImageSetTimestepsStep, Image2ImagePrepareLatentsStep, Image2ImagePrepareAdditionalConditioningStep, PrepareGuidance, ControlNetDenoiseStep, DecodeLatentsStep
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
).resize((1024, 1024))
depth_image = get_depth_map(image)
prompt = "A robot, 4k photo"
controlnet_conditioning_scale = 0.5  # recommended for good generalization

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0-small",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

prepare_input = InputStep.from_pipe(pipe)
encode_prompt = TextEncoderStep.from_pipe(pipe)
set_timesteps = Image2ImageSetTimestepsStep.from_pipe(pipe)
prepare_latents = Image2ImagePrepareLatentsStep.from_pipe(pipe)
prepare_add_cond = Image2ImagePrepareAdditionalConditioningStep.from_pipe(pipe)
prepare_guidance = PrepareGuidance.from_pipe(pipe)
denoise_step = ControlNetDenoiseStep.from_pipe(pipe, controlnet=controlnet)
decoder_step = DecodeLatentsStep.from_pipe(pipe)
builder = CustomPipelineBuilder("SDXL")
builder.add_blocks(
    [prepare_input,
     encode_prompt, 
     set_timesteps, 
     prepare_latents, 
     prepare_add_cond, 
     prepare_guidance,
     denoise_step,
     decoder_step
     ]
     )
print(builder)

generator = torch.Generator(device="cuda").manual_seed(0)
out = builder.run_pipeline(
    prompt=prompt,
    image=image,
    generator=generator,
    control_image=depth_image,
    strength=0.99,
    num_inference_steps=50,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
)
out.images[0].save("controlnet_img2img_modular.png")

To-Dos

The PR is in a very early stage, so there are a lot of to-dos left. I will just list a few that I'm working on next:

  • add a from_pipe for PipelineBlock
  • add a from_pretrained for PipelineBlock
  • make a from_pipe/from_pretrained method for CustomPipelineBuilder so that you can create a regular pipeline and pass it to the builder. The builder will map the pipeline to a set of pre-defined pipeline blocks and automatically create the pipeline. You can then use it as a starting point to build your custom pipeline, e.g.
    # create the SDXL pipeline with from_pretrained
    pipe_sdxl = StableDiffusionXLPipeline.from_pretrained()
    # this just maps the pipeline to a set of blocks, creates them, and then does CustomPipelineBuilder.add_blocks(..) 
    builder = CustomPipelineBuilder.from_pipe(pipe_sdxl)
    # from there you can add, remove, and replace blocks to make your new pipeline
    builder.add_blocks()
    builder.remove_blocks()
  • img2img
  • controlnet
  • pag
  • inpaint
  • work on the guidance class to see if we can make it robust and generalize to different guidance methods, e.g. PAG
  • memory management for PipelineState - this feature should only apply to run_pipeline (not run_blocks). Since custom pipeline blocks may not be written in the most memory-efficient way, e.g., users could add too many intermediate variables that are not needed, we can add some guard rails to prevent that
  • see if we can unify "inputs" and "intermediate inputs" - basically, an input required by a block but not provided by previous blocks should be a call parameter for the builder, so I think we should be able to remove the difference and simplify things

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yoland68

Very cool!
