Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image Language Models and ImageGeneration task #1060

Open
wants to merge 42 commits into
base: develop
Choose a base branch
from

Conversation

plaguss
Copy link
Contributor

@plaguss plaguss commented Nov 14, 2024

Description

This PR adds a new module to models: models/image_generation to store image models (InferenceEndpointsImageGeneration and OpenAIImageGeneration), with 2 new base classes: ImageGenerationModel and AsyncImageGenerationModel, and a new ImageGeneration task.

Sample pipeline and dataset. Take into account the distiset.transform_columns_to_image method, necessary to push the dataset with the images as objects instead of strings.

from datasets import load_dataset

from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration

ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell")

    img_generation = ImageGeneration(
        name="flux_schnell", image_generation_model=igm, input_mappings={"prompt": "persona"}
    )

    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    img_generation >> keep_columns


if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False, dataset=ds)
    # Save the images as `PIL.Image.Image`
    distiset = distiset.transform_columns_to_image("image")
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")

@plaguss plaguss added the enhancement New feature or request label Nov 14, 2024
@plaguss plaguss added this to the 1.5.0 milestone Nov 14, 2024
@plaguss plaguss self-assigned this Nov 14, 2024
@plaguss plaguss requested a review from gabrielmbmb November 14, 2024 11:58
Copy link

Documentation for this PR has been built. You can view it at: https://distilabel.argilla.io/pr-1060/

Copy link

codspeed-hq bot commented Nov 14, 2024

CodSpeed Performance Report

Merging #1060 will not alter performance

Comparing vision-language-models (e9e6790) with develop (a8d02c2)

Summary

✅ 1 untouched benchmarks

@plaguss plaguss marked this pull request as ready for review November 15, 2024 08:24
@plaguss plaguss requested a review from dvsrepo November 15, 2024 11:51
@davidberenstein1957
Copy link
Member

@burtenshaw

if __name__ == "__main__":
distiset = pipeline.run(use_cache=False, dataset=ds)
# Save the images as `PIL.Image.Image`
+ distiset = distiset.transform_columns_to_image("image")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
+ distiset = distiset.transform_columns_to_image("image")
distiset = distiset.transform_columns_to_image("image")

@@ -102,6 +102,7 @@ text-clustering = [
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3", # For the figure (even though it's optional)
]
vision = ["Pillow >= 10.3.0"] # To work with images.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just go a pil error due to some imports

/Users/davidberenstein/Documents/programming/argilla/distilabel/.venv/lib/python3.11/site-packag │
│ es/distilabel/steps/tasks/text_generation_with_image.py:18 in <module>                           │
│                                                                                                  │
│    15 from typing import TYPE_CHECKING, Any, Literal, Union                                      │
│    16                                                                                            │
│    17 from jinja2 import Template                                                                │
│ ❱  18 from PIL import Image                                                                      │
│    19 from pydantic import Field                                                                 │
│    20                                                                                            │
│    21 from distilabel.steps.tasks.base import Task                                               │
│                                                                                                  │
│ ╭──────────── locals ────────────╮                                                               │
│ │       Literal = typing.Literal │                                                               │
│ │ TYPE_CHECKING = False          │                                                               │
│ │         Union = typing.Union   │                                                               │
│ ╰────────────────────────────────╯                                                               │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ModuleNotFoundError: No module named 'PIL'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants