Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply proper serializer when pushing feedback dataset in Argilla #3192

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ These are the section headers that we use:

- Resolve breaking issue with `ArgillaSpanMarkerTrainer` for Named Entity Recognition with `span_marker` v1.1.x onwards.
- Move `ArgillaDatasetCard` import under `@requires_version` decorator, so that the `ImportError` on `huggingface_hub` is handled properly ([#3174](https://github.com/argilla-io/argilla/pull/3174))
- Allow flow `FeedbackDataset.from_argilla` -> `FeedbackDataset.push_to_argilla` under different dataset names and/or workspaces ([#3192](https://github.com/argilla-io/argilla/issues/3192))

## [1.9.0](https://github.com/argilla-io/argilla/compare/v1.8.0...v1.9.0)

Expand Down
17 changes: 8 additions & 9 deletions src/argilla/client/feedback/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
import logging
import tempfile
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from uuid import UUID

try:
from typing import Literal
Expand Down Expand Up @@ -142,7 +143,7 @@ class FeedbackDataset:
[FeedbackRecord(fields={"text": "This is the first record", "label": "positive"}, responses=[ResponseSchema(user_id=None, values={"question-1": ValueSchema(value="This is the first answer"), "question-2": ValueSchema(value=5), "question-3": ValueSchema(value="positive"), "question-4": ValueSchema(value=["category-1"])})], external_id="entry-1")]
"""

argilla_id: Optional[str] = None
argilla_id: Optional[UUID] = None

def __init__(
self,
Expand Down Expand Up @@ -465,7 +466,7 @@ def add_records(
else:
self.__new_records = records

def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[FeedbackRecord]:
def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List[FeedbackRecord]]:
"""Returns an iterator over the records in the dataset.

Args:
Expand Down Expand Up @@ -543,7 +544,7 @@ def push_to_argilla(self, name: Optional[str] = None, workspace: Optional[Union[
f"Failed while creating the `FeedbackTask` dataset in Argilla with exception: {e}"
) from e

def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
def delete_and_raise_exception(dataset_id: UUID, exception: Exception) -> None:
try:
datasets_api_v1.delete_dataset(client=httpx_client, id=dataset_id)
except Exception as e:
Expand All @@ -555,7 +556,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:

for field in self.fields:
try:
datasets_api_v1.add_field(client=httpx_client, id=argilla_id, field=field.dict())
datasets_api_v1.add_field(client=httpx_client, id=argilla_id, field=json.loads(field.json()))
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
delete_and_raise_exception(
dataset_id=argilla_id,
Expand All @@ -568,9 +569,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
for question in self.questions:
try:
datasets_api_v1.add_question(
client=httpx_client,
id=argilla_id,
question=question.dict(include={"name", "title", "description", "required", "settings"}),
client=httpx_client, id=argilla_id, question=json.loads(question.json())
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
)
except Exception as e:
delete_and_raise_exception(
Expand All @@ -596,7 +595,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
datasets_api_v1.add_records(
client=httpx_client,
id=argilla_id,
records=[record.dict() for record in batch],
records=[json.loads(record.json()) for record in batch],
)
except Exception as e:
delete_and_raise_exception(
Expand Down
21 changes: 11 additions & 10 deletions src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import warnings
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

import httpx

Expand All @@ -35,7 +36,7 @@
def create_dataset(
client: httpx.Client,
name: str,
workspace_id: str,
workspace_id: UUID,
guidelines: Optional[str] = None,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
"""Sends a POST reques to `/api/v1/datasets` endpoint to create a new `FeedbackTask` dataset.
Expand Down Expand Up @@ -71,7 +72,7 @@ def create_dataset(

def get_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}` endpoint to retrieve a `FeedbackTask` dataset.

Expand Down Expand Up @@ -100,7 +101,7 @@ def get_dataset(

def delete_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a DELETE request to `/api/v1/datasets/{id}` endpoint to delete a `FeedbackTask` dataset.

Expand All @@ -126,7 +127,7 @@ def delete_dataset(

def publish_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
"""Sends a PUT request to `/api/v1/datasets/{id}/publish` endpoint to publish a `FeedbackTask` dataset.
Publishing in Argilla means setting the status of the dataset from `draft` to `ready`, so that
Expand Down Expand Up @@ -183,7 +184,7 @@ def list_datasets(

def get_records(
client: httpx.Client,
id: str,
id: UUID,
offset: int = 0,
limit: int = 50,
) -> Response[Union[FeedbackRecordsModel, ErrorMessage, HTTPValidationError]]:
Expand Down Expand Up @@ -219,7 +220,7 @@ def get_records(

def add_records(
client: httpx.Client,
id: str,
id: UUID,
records: List[Dict[str, Any]],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/records` endpoint to add a list of `FeedbackTask` records.
Expand Down Expand Up @@ -269,7 +270,7 @@ def add_records(

def get_fields(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[List[FeedbackFieldModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}/fields` endpoint to retrieve a list of `FeedbackTask` fields.

Expand Down Expand Up @@ -298,7 +299,7 @@ def get_fields(

def add_field(
client: httpx.Client,
id: str,
id: UUID,
field: Dict[str, Any],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/fields` endpoint to add a `FeedbackTask` field.
Expand Down Expand Up @@ -326,7 +327,7 @@ def add_field(

def get_questions(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[List[FeedbackQuestionModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}/questions` endpoint to retrieve a list of `FeedbackTask` questions.

Expand Down Expand Up @@ -355,7 +356,7 @@ def get_questions(

def add_question(
client: httpx.Client,
id: str,
id: UUID,
question: Dict[str, Any],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/questions` endpoint to add a question to the `FeedbackTask` dataset.
Expand Down
26 changes: 26 additions & 0 deletions tests/client/feedback/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import datasets
import pytest
from argilla.client import api
from argilla.server.models import User

if TYPE_CHECKING:
from argilla.client.feedback.schemas import AllowedFieldTypes, AllowedQuestionTypes
Expand Down Expand Up @@ -333,6 +334,31 @@ def test_push_to_argilla_and_from_argilla(
assert len(dataset_from_argilla.records[-1].responses) == 1 # Since the second one was discarded as `user_id=None`


def test_copy_dataset_in_argilla(
mocked_client,
argilla_user: User,
feedback_dataset_guidelines: str,
feedback_dataset_fields: List["AllowedFieldTypes"],
feedback_dataset_questions: List["AllowedQuestionTypes"],
feedback_dataset_records: List[FeedbackRecord],
) -> None:
api.active_api()
api.init(api_key=argilla_user.api_key)

dataset = FeedbackDataset(
guidelines=feedback_dataset_guidelines,
fields=feedback_dataset_fields,
questions=feedback_dataset_questions,
)
dataset.add_records(records=feedback_dataset_records)
dataset.push_to_argilla(name="test-dataset")

same_dataset = FeedbackDataset.from_argilla("test-dataset")
same_dataset.push_to_argilla("copy-dataset")

assert same_dataset.argilla_id is not None


@pytest.mark.usefixtures(
"feedback_dataset_guidelines",
"feedback_dataset_fields",
Expand Down