Skip to content

Commit

Permalink
Apply proper serializer when pushing feedback dataset in Argilla (#3192)
Browse files Browse the repository at this point in the history
# Description

By default, the `model.dict()` method won't apply custom serialization
to the `UUID` or `datetime` fields. Since pydantic provides a `.json`
method which applies serialization properly, we should use it in order
to support pushing data to Argilla.

The current solution may be improved since we're serializing more times
than needed. But for that, we should change the sdk.v1.api layer. See
issue #3191

Closes #3189 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

New tests have been added

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
frascuchon and alvarobartt authored Jun 14, 2023
1 parent b86d49b commit e33a2df
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ These are the section headers that we use:

- Resolve breaking issue with `ArgillaSpanMarkerTrainer` for Named Entity Recognition with `span_marker` v1.1.x onwards.
- Move `ArgillaDatasetCard` import under `@requires_version` decorator, so that the `ImportError` on `huggingface_hub` is handled properly ([#3174](https://github.com/argilla-io/argilla/pull/3174))
- Allow flow `FeedbackDataset.from_argilla` -> `FeedbackDataset.push_to_argilla` under different dataset names and/or workspaces ([#3192](https://github.com/argilla-io/argilla/issues/3192))

## [1.9.0](https://github.com/argilla-io/argilla/compare/v1.8.0...v1.9.0)

Expand Down
17 changes: 8 additions & 9 deletions src/argilla/client/feedback/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import tempfile
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from uuid import UUID

try:
from typing import Literal
Expand Down Expand Up @@ -142,7 +143,7 @@ class FeedbackDataset:
[FeedbackRecord(fields={"text": "This is the first record", "label": "positive"}, responses=[ResponseSchema(user_id=None, values={"question-1": ValueSchema(value="This is the first answer"), "question-2": ValueSchema(value=5), "question-3": ValueSchema(value="positive"), "question-4": ValueSchema(value=["category-1"])})], external_id="entry-1")]
"""

argilla_id: Optional[str] = None
argilla_id: Optional[UUID] = None

def __init__(
self,
Expand Down Expand Up @@ -465,7 +466,7 @@ def add_records(
else:
self.__new_records = records

def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[FeedbackRecord]:
def iter(self, batch_size: Optional[int] = FETCHING_BATCH_SIZE) -> Iterator[List[FeedbackRecord]]:
"""Returns an iterator over the records in the dataset.
Args:
Expand Down Expand Up @@ -543,7 +544,7 @@ def push_to_argilla(self, name: Optional[str] = None, workspace: Optional[Union[
f"Failed while creating the `FeedbackTask` dataset in Argilla with exception: {e}"
) from e

def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
def delete_and_raise_exception(dataset_id: UUID, exception: Exception) -> None:
try:
datasets_api_v1.delete_dataset(client=httpx_client, id=dataset_id)
except Exception as e:
Expand All @@ -555,7 +556,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:

for field in self.fields:
try:
datasets_api_v1.add_field(client=httpx_client, id=argilla_id, field=field.dict())
datasets_api_v1.add_field(client=httpx_client, id=argilla_id, field=json.loads(field.json()))
except Exception as e:
delete_and_raise_exception(
dataset_id=argilla_id,
Expand All @@ -568,9 +569,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
for question in self.questions:
try:
datasets_api_v1.add_question(
client=httpx_client,
id=argilla_id,
question=question.dict(include={"name", "title", "description", "required", "settings"}),
client=httpx_client, id=argilla_id, question=json.loads(question.json())
)
except Exception as e:
delete_and_raise_exception(
Expand All @@ -596,7 +595,7 @@ def delete_and_raise_exception(dataset_id: str, exception: Exception) -> None:
datasets_api_v1.add_records(
client=httpx_client,
id=argilla_id,
records=[record.dict() for record in batch],
records=[json.loads(record.json()) for record in batch],
)
except Exception as e:
delete_and_raise_exception(
Expand Down
21 changes: 11 additions & 10 deletions src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import warnings
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

import httpx

Expand All @@ -35,7 +36,7 @@
def create_dataset(
client: httpx.Client,
name: str,
workspace_id: str,
workspace_id: UUID,
guidelines: Optional[str] = None,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
    """Sends a POST request to `/api/v1/datasets` endpoint to create a new `FeedbackTask` dataset.
Expand Down Expand Up @@ -71,7 +72,7 @@ def create_dataset(

def get_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}` endpoint to retrieve a `FeedbackTask` dataset.
Expand Down Expand Up @@ -100,7 +101,7 @@ def get_dataset(

def delete_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a DELETE request to `/api/v1/datasets/{id}` endpoint to delete a `FeedbackTask` dataset.
Expand All @@ -126,7 +127,7 @@ def delete_dataset(

def publish_dataset(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[FeedbackDatasetModel, ErrorMessage, HTTPValidationError]]:
"""Sends a PUT request to `/api/v1/datasets/{id}/publish` endpoint to publish a `FeedbackTask` dataset.
Publishing in Argilla means setting the status of the dataset from `draft` to `ready`, so that
Expand Down Expand Up @@ -183,7 +184,7 @@ def list_datasets(

def get_records(
client: httpx.Client,
id: str,
id: UUID,
offset: int = 0,
limit: int = 50,
) -> Response[Union[FeedbackRecordsModel, ErrorMessage, HTTPValidationError]]:
Expand Down Expand Up @@ -219,7 +220,7 @@ def get_records(

def add_records(
client: httpx.Client,
id: str,
id: UUID,
records: List[Dict[str, Any]],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/records` endpoint to add a list of `FeedbackTask` records.
Expand Down Expand Up @@ -269,7 +270,7 @@ def add_records(

def get_fields(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[List[FeedbackFieldModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}/fields` endpoint to retrieve a list of `FeedbackTask` fields.
Expand Down Expand Up @@ -298,7 +299,7 @@ def get_fields(

def add_field(
client: httpx.Client,
id: str,
id: UUID,
field: Dict[str, Any],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/fields` endpoint to add a `FeedbackTask` field.
Expand Down Expand Up @@ -326,7 +327,7 @@ def add_field(

def get_questions(
client: httpx.Client,
id: str,
id: UUID,
) -> Response[Union[List[FeedbackQuestionModel], ErrorMessage, HTTPValidationError]]:
"""Sends a GET request to `/api/v1/datasets/{id}/questions` endpoint to retrieve a list of `FeedbackTask` questions.
Expand Down Expand Up @@ -355,7 +356,7 @@ def get_questions(

def add_question(
client: httpx.Client,
id: str,
id: UUID,
question: Dict[str, Any],
) -> Response[Union[ErrorMessage, HTTPValidationError]]:
"""Sends a POST request to `/api/v1/datasets/{id}/questions` endpoint to add a question to the `FeedbackTask` dataset.
Expand Down
26 changes: 26 additions & 0 deletions tests/client/feedback/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import datasets
import pytest
from argilla.client import api
from argilla.server.models import User

if TYPE_CHECKING:
from argilla.client.feedback.schemas import AllowedFieldTypes, AllowedQuestionTypes
Expand Down Expand Up @@ -333,6 +334,31 @@ def test_push_to_argilla_and_from_argilla(
assert len(dataset_from_argilla.records[-1].responses) == 1 # Since the second one was discarded as `user_id=None`


def test_copy_dataset_in_argilla(
    mocked_client,
    argilla_user: User,
    feedback_dataset_guidelines: str,
    feedback_dataset_fields: List["AllowedFieldTypes"],
    feedback_dataset_questions: List["AllowedQuestionTypes"],
    feedback_dataset_records: List[FeedbackRecord],
) -> None:
    """A dataset fetched via `from_argilla` can be pushed again under a new name."""
    # Initialize the client singleton with the test user's credentials.
    api.active_api()
    api.init(api_key=argilla_user.api_key)

    original = FeedbackDataset(
        guidelines=feedback_dataset_guidelines,
        fields=feedback_dataset_fields,
        questions=feedback_dataset_questions,
    )
    original.add_records(records=feedback_dataset_records)
    original.push_to_argilla(name="test-dataset")

    # Round-trip: fetch the dataset back, then push it under a different name.
    fetched = FeedbackDataset.from_argilla("test-dataset")
    fetched.push_to_argilla("copy-dataset")

    assert fetched.argilla_id is not None


@pytest.mark.usefixtures(
"feedback_dataset_guidelines",
"feedback_dataset_fields",
Expand Down

0 comments on commit e33a2df

Please sign in to comment.