Review TextClassification task #1073

Open
wants to merge 5 commits into base: develop
80 changes: 58 additions & 22 deletions src/distilabel/steps/tasks/text_classification.py
@@ -17,9 +17,16 @@

import orjson
from jinja2 import Template
from pydantic import BaseModel, Field, PositiveInt, PrivateAttr
from typing_extensions import override

from pydantic import (
BaseModel,
Field,
PositiveInt,
PrivateAttr,
model_validator,
)
from typing_extensions import Self, override

from distilabel.errors import DistilabelUserError
from distilabel.steps.tasks import Task

if TYPE_CHECKING:
@@ -41,7 +48,7 @@
```

## Output Format
Now, please give me the labels in JSON format, do not include any other text in your response:
Now, {{ output_message }}, do not include any other text in your response:
```
{
"labels": {{ labels_format }}
@@ -74,16 +81,18 @@ class TextClassification(Task):
Attributes:
system_prompt: A prompt to display to the user before the task starts. Contains a default
message to make the model behave like a classifier specialist.
n: Number of labels to generate If only 1 is required, corresponds to a label
n: Number of labels to generate. If only 1 is required, corresponds to a label
classification problem, if >1 it will return the "n" labels most representative
for the text. Defaults to 1.
is_multilabel: Indicates whether the task allows multiple labels or a single label.
context: Context to use when generating the labels. By default contains a generic message,
but can be used to customize the context for the task.
examples: List of examples to help the model understand the task, few shots.
available_labels: List of available labels to choose from when classifying the text, or
a dictionary with the labels and their descriptions.
default_label: Default label to use when the text is ambiguous or lacks sufficient information for
classification. Can be a list in case of multiple labels (n>1).
query_title: Title of the query used to show the example/s to classify.

Examples:
Assigning a sentiment to a text:
@@ -128,7 +137,7 @@ class TextClassification(Task):

text_classification = TextClassification(
llm=llm,
n=1,
is_multilabel=False,
context="Determine the intent of the text.",
available_labels={
"complaint": "A statement expressing dissatisfaction or annoyance about a product, service, or experience. It's a negative expression of discontent, often with the intention of seeking a resolution or compensation.",
@@ -164,7 +173,7 @@ class TextClassification(Task):

text_classification = TextClassification(
llm=llm,
n=3,
is_multilabel=True,
context=(
"Describe the main themes, topics, or categories that could describe the "
"following type of persona."
@@ -199,7 +208,11 @@ class TextClassification(Task):
)
n: PositiveInt = Field(
default=1,
description="Number of labels to generate. Defaults to 1.",
description="Number of labels to generate. Only used for TextClustering.",
)
is_multilabel: bool = Field(
default=False,
description="Indicates whether the task allows multiple labels or a single label. Only used for TextClassification.",
)
context: Optional[str] = Field(
default="Generate concise, relevant labels that accurately represent the text's main themes, topics, or categories.",
@@ -220,7 +233,7 @@
default="Unclassified",
description=(
"Default label to use when the text is ambiguous or lacks sufficient information for "
"classification. Can be a list in case of multiple labels (n>1)."
"classification. Can be a list in case of multiple labels."
),
)
query_title: str = Field(
@@ -231,21 +244,43 @@

_template: Optional[Template] = PrivateAttr(default=None)

@model_validator(mode="after")
def multilabel_validation(self) -> Self:
if self.n > 1 and self.is_multilabel:
raise DistilabelUserError(
"Only one of 'is_multilabel' for TextClassifiaction or 'n' for TextClustering can be set at the same time.",
page="components-gallery/tasks/textclassification/",
)
return self

def load(self) -> None:
super().load()
self._template = Template(TEXT_CLASSIFICATION_TEMPLATE)
self._labels_format: str = (
'"label"'
if self.n == 1
else "[" + ", ".join([f'"label_{i}"' for i in range(self.n)]) + "]"
if self.n == 1 and not self.is_multilabel
else "["
+ ", ".join(
[f'"label_{i}"' for i in range(3 if self.is_multilabel else self.n)]
)
+ "]"
)
self._labels_message: str = (
"Provide the label that best describes the text."
if self.n == 1
else f"Provide a list of {self.n} labels that best describe the text."
if self.n == 1 and not self.is_multilabel
else (
f"Provide a list of {self.n} labels that best describe the text."
if not self.is_multilabel
else "Provide a list with the label or labels that best describe the text. Do not include any label that do not apply."
)
)
self._available_labels_message: str = self._get_available_labels_message()
self._examples: str = self._get_examples_message()
self._output_message: str = (
"please give me only the correct label in JSON format"
if self.n == 1 and not self.is_multilabel
else ("please give me only the relevant labels in JSON format")
)

def _get_available_labels_message(self) -> str:
"""Prepares the message to display depending on the available labels (if any),
@@ -320,6 +355,7 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
labels_message=self._labels_message,
available_labels=self._available_labels_message,
examples=self._examples,
output_message=self._output_message,
default_label=self.default_label,
labels_format=self._labels_format,
query_title=self.query_title,
@@ -346,17 +382,17 @@ def get_structured_output(self) -> Dict[str, Any]:
Returns:
JSON Schema of the response to enforce.
"""
if self.n > 1:
if self.n == 1 and not self.is_multilabel:

class MultiLabelSchema(BaseModel):
labels: List[str]
class SingleLabelSchema(BaseModel):
labels: str

return MultiLabelSchema.model_json_schema()
return SingleLabelSchema.model_json_schema()

class SingleLabelSchema(BaseModel):
labels: str
class MultiLabelSchema(BaseModel):
labels: List[str]

return SingleLabelSchema.model_json_schema()
return MultiLabelSchema.model_json_schema()

def _format_structured_output(
self, output: str
@@ -373,6 +409,6 @@ def _format_structured_output(
try:
return orjson.loads(output)
except orjson.JSONDecodeError:
if self.n > 1:
if self.n > 1 and not self.is_multilabel:
return {"labels": [None for _ in range(self.n)]}
return {"labels": None}
return {"labels": [None] if self.is_multilabel else None}
76 changes: 50 additions & 26 deletions tests/unit/steps/tasks/test_text_classification.py
@@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import pytest
from pydantic import ValidationError

from distilabel.steps.tasks.text_classification import TextClassification
from tests.unit.conftest import DummyAsyncLLM
@@ -26,15 +27,15 @@


class TextClassificationLLM(DummyAsyncLLM):
n: int = 1
is_multilabel: bool = False

async def agenerate( # type: ignore
self, input: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
if self.n == 1:
labels = "label"
else:
if self.is_multilabel:
labels = ["label_0", "label_1", "label_2"]
else:
labels = "label"
return {
"generations": [
json.dumps({"labels": labels}) for _ in range(num_generations)
@@ -48,20 +49,20 @@ async def agenerate(  # type: ignore

class TestTextClassification:
@pytest.mark.parametrize(
"n, context, examples, available_labels, default_label, query_title",
"is_multilabel, context, examples, available_labels, default_label, query_title",
[
(1, "context", None, None, "Unclassified", "User Query"),
(1, "", ["example"], ["label1", "label2"], "default", "User Query"),
(False, "context", None, None, "Unclassified", "User Query"),
(False, "", ["example"], ["label1", "label2"], "default", "User Query"),
(
1,
False,
"",
["example"],
{"label1": "explanation 1", "label2": "explanation 2"},
"default",
"User Query",
),
(
3,
True,
"",
["example", "other example"],
None,
@@ -72,7 +73,7 @@ class TestTextClassification:
)
def test_format_input(
self,
n: int,
is_multilabel: bool,
context: str,
examples: Optional[List[str]],
available_labels: Optional[Union[List[str], Dict[str, str]]],
@@ -81,7 +82,7 @@ def test_format_input(
) -> None:
task = TextClassification(
llm=DummyAsyncLLM(),
n=n,
is_multilabel=is_multilabel,
context=context,
examples=examples,
available_labels=available_labels,
@@ -96,17 +97,15 @@
assert f'respond with "{default_label}"' in content
assert "## User Query\n```\nSAMPLE_TEXT\n```" in content
assert f'respond with "{default_label}"' in content
if n == 1:
if not is_multilabel:
assert "Provide the label that best describes the text." in content
assert '```\n{\n "labels": "label"\n}\n```' in content
else:
assert (
f"Provide a list of {n} labels that best describe the text." in content
)
assert (
'```\n{\n "labels": ["label_0", "label_1", "label_2"]\n}\n```'
"Provide a list with the label or labels that best describe the text. Do not include any label that do not apply."
in content
)
assert '```\n{\n "labels": [' in content
if available_labels:
if isinstance(available_labels, list):
assert 'Use the available labels to classify the user query:\navailable_labels = [\n "label1",\n "label2"\n]'
@@ -127,21 +126,46 @@ def test_format_input(
assert f"## {query_title}" in content

@pytest.mark.parametrize(
"n, expected",
"is_multilabel, expected",
[
(1, json.dumps({"labels": "label"})),
(3, json.dumps({"labels": ["label_0", "label_1", "label_2"]})),
(False, json.dumps({"labels": "label"})),
(
True,
[
json.dumps({"labels": ["label_0"]}),
json.dumps({"labels": ["label_0", "label_1"]}),
json.dumps({"labels": ["label_0", "label_1", "label_2"]}),
],
),
],
)
def test_process(self, n: int, expected: str) -> None:
def test_process(self, is_multilabel: bool, expected: str) -> None:
task = TextClassification(
llm=TextClassificationLLM(n=n), n=n, use_default_structured_output=True
llm=TextClassificationLLM(is_multilabel=is_multilabel),
is_multilabel=is_multilabel,
use_default_structured_output=True,
)
task.load()
result = next(task.process([{"text": "SAMPLE_TEXT"}]))
assert result[0]["text"] == "SAMPLE_TEXT"
assert result[0]["labels"] == json.loads(expected)["labels"]
assert (
result[0]["distilabel_metadata"]["raw_output_text_classification_0"]
== expected
)
if is_multilabel:
assert result[0]["labels"] in [
json.loads(opt)["labels"] for opt in expected
]
assert (
result[0]["distilabel_metadata"]["raw_output_text_classification_0"]
in expected
)
else:
assert result[0]["labels"] == json.loads(expected)["labels"]
assert (
result[0]["distilabel_metadata"]["raw_output_text_classification_0"]
== expected
)

def test_multilabel_error(self) -> None:
with pytest.raises(
ValidationError,
match=r"Only one of \'is_multilabel\' for TextClassifiaction or \'n\' for TextClustering can be set at the same time.",
):
TextClassification(llm=DummyAsyncLLM(), is_multilabel=True, n=2)
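
As a side note for review, the two response schemas that `get_structured_output` switches between can be reproduced with plain pydantic; this standalone sketch mirrors the diff above and is only meant as a quick reference.

```python
# Standalone sketch mirroring get_structured_output in this PR (plain pydantic, no distilabel).
from typing import List

from pydantic import BaseModel


class SingleLabelSchema(BaseModel):
    labels: str  # used when n == 1 and is_multilabel is False


class MultiLabelSchema(BaseModel):
    labels: List[str]  # used when is_multilabel is True, or n > 1 (TextClustering)


# The JSON Schemas handed to the LLM for structured generation:
print(SingleLabelSchema.model_json_schema())
print(MultiLabelSchema.model_json_schema())
```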