From f62d58a2f91e16eb4d02e56c4039a432070349c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?=
Date: Mon, 1 Jul 2024 12:31:02 +0200
Subject: [PATCH 01/36] feat: add dataset support to be created using distribution settings (#5013)

# Description

This PR is the first one related to the distribution task feature, adding the following changes:

* Added `distribution` JSON column to the `datasets` table:
  * This column is non-nullable, so a value is always required when a dataset is created.
  * By default, old datasets will have the value `{"strategy": "overlap", "min_submitted": 1}`.
* Added `distribution` attribute to the `DatasetCreate` schema:
  * `None` is not a valid value.
  * If no value is specified for this attribute, `DatasetOverlapDistributionCreate` with `min_submitted` set to `1` is used.
  * `DatasetOverlapDistributionCreate` only allows values greater than or equal to `1` for the `min_submitted` attribute.
* The context `create_dataset` function now receives a dictionary instead of a `DatasetCreate` schema.
* Moved dataset creation validations to a new `DatasetCreateValidator` class.

Updating the `distribution` attribute of existing datasets will be addressed in a different issue.

Closes #5005

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- [x] Added new tests and kept the existing ones passing.
- [x] Checked that the migration works as expected with old datasets on SQLite.
- [x] Checked that the migration works as expected with old datasets on PostgreSQL.
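For reference, below is a minimal sketch of how a client can pass the new `distribution` settings when creating a dataset through the v1 API. The endpoint and payload shape follow the new `DatasetCreate` schema and the tests added in this PR; the base URL, API key value, and workspace id are placeholders, and the `X-Argilla-Api-Key` header name is assumed from the server constants rather than shown in this diff.

```python
# Sketch only: create a dataset with explicit distribution settings.
# Placeholders: server URL, API key, workspace id.
import httpx

client = httpx.Client(
    base_url="http://localhost:6900",
    headers={"X-Argilla-Api-Key": "<owner-api-key>"},  # assumed header name
)

response = client.post(
    "/api/v1/datasets",
    json={
        "name": "my-dataset",
        "workspace_id": "<workspace-uuid>",
        # Optional: when omitted, the server defaults to
        # {"strategy": "overlap", "min_submitted": 1}
        "distribution": {"strategy": "overlap", "min_submitted": 2},
    },
)
response.raise_for_status()

print(response.json()["distribution"])  # {'strategy': 'overlap', 'min_submitted': 2}
```

A `min_submitted` value below `1`, a `null` distribution, or an unknown `strategy` is rejected with a `422`, matching the new validation tests.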
**Checklist**

- [ ] I added relevant documentation
- [ ] I followed the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paco Aranda
---
 .../repositories/RecordRepository.ts          | 30 ++-
 argilla-server/CHANGELOG.md                   | 7 +-
 ...4d74_add_status_column_to_records_table.py | 60 ++++++
 ...7_add_metadata_column_to_records_table.py} | 6 +-
 ...d_distribution_column_to_datasets_table.py | 45 +++++
 ...xtra_metadata_column_to_datasets_table.py} | 6 +-
 .../api/handlers/v1/datasets/datasets.py      | 4 +-
 .../api/handlers/v1/responses.py              | 8 +-
 .../argilla_server/api/schemas/v1/datasets.py | 38 +++-
 .../argilla_server/api/schemas/v1/records.py  | 5 +-
 .../src/argilla_server/bulk/records_bulk.py   | 4 +
 .../src/argilla_server/contexts/datasets.py   | 56 +++---
 .../argilla_server/contexts/distribution.py   | 42 +++++
 argilla-server/src/argilla_server/enums.py    | 9 +
 .../src/argilla_server/models/database.py     | 31 ++-
 .../src/argilla_server/search_engine/base.py  | 4 +
 .../argilla_server/search_engine/commons.py   | 6 +
 .../src/argilla_server/validators/datasets.py | 48 +++++
 argilla-server/tests/factories.py             | 3 +-
 .../records_bulk/test_dataset_records_bulk.py | 3 +-
 .../v1/datasets/test_create_dataset.py        | 139 ++++++++++++++
 ...est_search_current_user_dataset_records.py | 5 +-
 .../datasets/test_search_dataset_records.py   | 4 +-
 .../v1/datasets/test_update_dataset.py        | 178 ++++++++++++++++++
 .../test_create_dataset_records_bulk.py       | 145 ++++++++++++++
 .../v1/records/test_create_record_response.py | 100 ++++++++--
 .../test_upsert_dataset_records_bulk.py       | 153 +++++++++++++++
 ...test_create_current_user_responses_bulk.py | 10 +-
 .../v1/responses/test_delete_response.py      | 66 +++++++
 .../v1/responses/test_update_response.py      | 71 ++++++-
 .../unit/api/handlers/v1/test_datasets.py     | 39 +++-
 .../handlers/v1/test_list_dataset_records.py  | 6 +-
 .../unit/api/handlers/v1/test_records.py      | 10 +-
 .../tests/unit/search_engine/test_commons.py  | 12 +-
 argilla/src/argilla/_models/_search.py        | 6 +
 argilla/src/argilla/records/_search.py        | 4 +-
 36 files changed, 1274 insertions(+), 89 deletions(-)
 create mode 100644 argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py
 rename argilla-server/src/argilla_server/alembic/versions/{3ff6484f8b37_add_record_metadata_column.py => 3ff6484f8b37_add_metadata_column_to_records_table.py} (82%)
 create mode 100644 argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py
 rename argilla-server/src/argilla_server/alembic/versions/{b8458008b60e_add_allow_extra_metadata_column_to_.py => b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py} (81%)
 create mode 100644 argilla-server/src/argilla_server/contexts/distribution.py
 create mode 100644 argilla-server/src/argilla_server/validators/datasets.py
 create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py
 create mode 100644 argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py
 create mode 100644
argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py create mode 100644 argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py diff --git a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts index e0e30adfd3..40ce2645eb 100644 --- a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts @@ -42,10 +42,8 @@ export class RecordRepository { constructor(private readonly axios: NuxtAxiosInstance) {} getRecords(criteria: RecordCriteria): Promise { - if (criteria.isFilteringByAdvanceSearch) - return this.getRecordsByAdvanceSearch(criteria); - - return this.getRecordsByDatasetId(criteria); + return this.getRecordsByAdvanceSearch(criteria); + // return this.getRecordsByDatasetId(criteria); } async getRecord(recordId: string): Promise { @@ -264,6 +262,30 @@ export class RecordRepository { }; } + body.filters = { + and: [ + { + type: "terms", + scope: { + entity: "response", + property: "status", + }, + values: [status], + }, + ], + }; + + if (status === "pending") { + body.filters.and.push({ + type: "terms", + scope: { + entity: "record", + property: "status", + }, + values: ["pending"], + }); + } + if ( isFilteringByMetadata || isFilteringByResponse || diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index de84587e41..827037a2c3 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -16,12 +16,17 @@ These are the section headers that we use: ## [Unreleased]() -## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) +### Added + +- Added support to specify `distribution` attribute when creating a dataset. ([#5013](https://github.com/argilla-io/argilla/pull/5013)) +- Added support to change `distribution` attribute when updating a dataset. ([#5028](https://github.com/argilla-io/argilla/pull/5028)) ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) +## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) + ### Removed - Removed all API v0 endpoints. ([#4852](https://github.com/argilla-io/argilla/pull/4852)) diff --git a/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py new file mode 100644 index 0000000000..767b277573 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/237f7c674d74_add_status_column_to_records_table.py @@ -0,0 +1,60 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""add status column to records table + +Revision ID: 237f7c674d74 +Revises: 45a12f74448b +Create Date: 2024-06-18 17:59:36.992165 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "237f7c674d74" +down_revision = "45a12f74448b" +branch_labels = None +depends_on = None + + +record_status_enum = sa.Enum("pending", "completed", name="record_status_enum") + + +def upgrade() -> None: + record_status_enum.create(op.get_bind()) + + op.add_column("records", sa.Column("status", record_status_enum, server_default="pending", nullable=False)) + op.create_index(op.f("ix_records_status"), "records", ["status"], unique=False) + + # NOTE: Updating existent records to have "completed" status when they have + # at least one response with "submitted" status. + op.execute(""" + UPDATE records + SET status = 'completed' + WHERE id IN ( + SELECT DISTINCT record_id + FROM responses + WHERE status = 'submitted' + ); + """) + + +def downgrade() -> None: + op.drop_index(op.f("ix_records_status"), table_name="records") + op.drop_column("records", "status") + + record_status_enum.drop(op.get_bind()) diff --git a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py similarity index 82% rename from argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py rename to argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py index 7ac80ad895..b5949f5364 100644 --- a/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_record_metadata_column.py +++ b/argilla-server/src/argilla_server/alembic/versions/3ff6484f8b37_add_metadata_column_to_records_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add record metadata column +"""add metadata column to records table Revision ID: 3ff6484f8b37 Revises: ae5522b4c674 @@ -31,12 +31,8 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True)) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.drop_column("records", "metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py new file mode 100644 index 0000000000..791da07439 --- /dev/null +++ b/argilla-server/src/argilla_server/alembic/versions/45a12f74448b_add_distribution_column_to_datasets_table.py @@ -0,0 +1,45 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""add distribution column to datasets table + +Revision ID: 45a12f74448b +Revises: d00f819ccc67 +Create Date: 2024-06-13 11:23:43.395093 + +""" + +import json + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "45a12f74448b" +down_revision = "d00f819ccc67" +branch_labels = None +depends_on = None + +DISTRIBUTION_VALUE = json.dumps({"strategy": "overlap", "min_submitted": 1}) + + +def upgrade() -> None: + op.add_column("datasets", sa.Column("distribution", sa.JSON(), nullable=True)) + op.execute(f"UPDATE datasets SET distribution = '{DISTRIBUTION_VALUE}'") + with op.batch_alter_table("datasets") as batch_op: + batch_op.alter_column("distribution", nullable=False) + + +def downgrade() -> None: + op.drop_column("datasets", "distribution") diff --git a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py similarity index 81% rename from argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py rename to argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py index 8b23340448..f8fa87536e 100644 --- a/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_.py +++ b/argilla-server/src/argilla_server/alembic/versions/b8458008b60e_add_allow_extra_metadata_column_to_datasets_table.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""add allow_extra_metadata column to dataset table +"""add allow_extra_metadata column to datasets table Revision ID: b8458008b60e Revises: 7cbcccf8b57a @@ -31,14 +31,10 @@ def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### op.add_column( "datasets", sa.Column("allow_extra_metadata", sa.Boolean(), server_default=sa.text("true"), nullable=False) ) - # ### end Alembic commands ### def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### op.drop_column("datasets", "allow_extra_metadata") - # ### end Alembic commands ### diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 63f95391e1..0590b41bb4 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -189,7 +189,7 @@ async def create_dataset( ): await authorize(current_user, DatasetPolicy.create(dataset_create.workspace_id)) - return await datasets.create_dataset(db, dataset_create) + return await datasets.create_dataset(db, dataset_create.dict()) @router.post("/datasets/{dataset_id}/fields", status_code=status.HTTP_201_CREATED, response_model=Field) @@ -302,4 +302,4 @@ async def update_dataset( await authorize(current_user, DatasetPolicy.update(dataset)) - return await datasets.update_dataset(db, dataset, dataset_update) + return await datasets.update_dataset(db, dataset, dataset_update.dict(exclude_unset=True)) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/responses.py b/argilla-server/src/argilla_server/api/handlers/v1/responses.py index 56cb695c95..ddc389563a 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/responses.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/responses.py @@ -64,7 +64,9 @@ async def update_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.update(response)) @@ -83,7 +85,9 @@ async def delete_response( response = await Response.get_or_raise( db, response_id, - options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)], + options=[ + selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions), + ], ) await authorize(current_user, ResponsePolicy.delete(response)) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 5cac33bdb7..1e1b69d836 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -13,11 +13,11 @@ # limitations under the License. 
from datetime import datetime -from typing import List, Optional +from typing import List, Literal, Optional, Union from uuid import UUID from argilla_server.api.schemas.v1.commons import UpdateSchema -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from argilla_server.pydantic_v1 import BaseModel, Field, constr try: @@ -44,6 +44,32 @@ ] +class DatasetOverlapDistribution(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int + + +DatasetDistribution = DatasetOverlapDistribution + + +class DatasetOverlapDistributionCreate(BaseModel): + strategy: Literal[DatasetDistributionStrategy.overlap] + min_submitted: int = Field( + ge=1, + description="Minimum number of submitted responses to consider a record as completed", + ) + + +DatasetDistributionCreate = DatasetOverlapDistributionCreate + + +class DatasetOverlapDistributionUpdate(DatasetDistributionCreate): + pass + + +DatasetDistributionUpdate = DatasetOverlapDistributionUpdate + + class RecordMetrics(BaseModel): count: int @@ -74,6 +100,7 @@ class Dataset(BaseModel): guidelines: Optional[str] allow_extra_metadata: bool status: DatasetStatus + distribution: DatasetDistribution workspace_id: UUID last_activity_at: datetime inserted_at: datetime @@ -91,6 +118,10 @@ class DatasetCreate(BaseModel): name: DatasetName guidelines: Optional[DatasetGuidelines] allow_extra_metadata: bool = True + distribution: DatasetDistributionCreate = DatasetOverlapDistributionCreate( + strategy=DatasetDistributionStrategy.overlap, + min_submitted=1, + ) workspace_id: UUID @@ -98,5 +129,6 @@ class DatasetUpdate(UpdateSchema): name: Optional[DatasetName] guidelines: Optional[DatasetGuidelines] allow_extra_metadata: Optional[bool] + distribution: Optional[DatasetDistributionUpdate] - __non_nullable_fields__ = {"name", "allow_extra_metadata"} + __non_nullable_fields__ = {"name", "allow_extra_metadata", "distribution"} diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 13f37c3ae0..0cf215954a 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -23,7 +23,7 @@ from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope -from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder +from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator from argilla_server.pydantic_v1.utils import GetterDict from argilla_server.search_engine import TextQuery @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any: class Record(BaseModel): id: UUID + status: RecordStatus fields: Dict[str, Any] metadata: Optional[Dict[str, Any]] external_id: Optional[str] @@ -196,7 +197,7 @@ def _has_relationships(self): class RecordFilterScope(BaseModel): entity: Literal["record"] - property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]] + property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]] class Records(BaseModel): diff --git 
a/argilla-server/src/argilla_server/bulk/records_bulk.py b/argilla-server/src/argilla_server/bulk/records_bulk.py index 0e3d372be5..6acbc30031 100644 --- a/argilla-server/src/argilla_server/bulk/records_bulk.py +++ b/argilla-server/src/argilla_server/bulk/records_bulk.py @@ -29,6 +29,7 @@ ) from argilla_server.api.schemas.v1.responses import UserResponseCreate from argilla_server.api.schemas.v1.suggestions import SuggestionCreate +from argilla_server.contexts import distribution from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict from argilla_server.contexts.records import ( fetch_records_by_external_ids_as_dict, @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr await self._upsert_records_relationships(records, bulk_create.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp await self._upsert_records_relationships(records, bulk_upsert.items) await _preload_records_relationships_before_index(self._db, records) + await distribution.update_records_status(self._db, records) await self._search_engine.index_records(dataset, records) await self._db.commit() @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record .filter(Record.id.in_([record.id for record in records])) .options( selectinload(Record.responses).selectinload(Response.user), + selectinload(Record.responses_submitted), selectinload(Record.suggestions).selectinload(Suggestion.question), selectinload(Record.vectors), ) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 34468c2b18..1dbf52fc53 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,10 +37,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import ( - DatasetCreate, - DatasetProgress, -) +from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -63,7 +60,7 @@ VectorSettingsCreate, ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema -from argilla_server.contexts import accounts +from argilla_server.contexts import accounts, distribution from argilla_server.enums import DatasetStatus, RecordInclude, UserRole from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( @@ -82,6 +79,7 @@ ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine +from argilla_server.validators.datasets import DatasetCreateValidator, DatasetUpdateValidator from argilla_server.validators.responses import ( ResponseCreateValidator, ResponseUpdateValidator, @@ -122,22 +120,18 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() -async def create_dataset(db: AsyncSession, dataset_create: DatasetCreate): - if await 
Workspace.get(db, dataset_create.workspace_id) is None: - raise UnprocessableEntityError(f"Workspace with id `{dataset_create.workspace_id}` not found") +async def create_dataset(db: AsyncSession, dataset_attrs: dict): + dataset = Dataset( + name=dataset_attrs["name"], + guidelines=dataset_attrs["guidelines"], + allow_extra_metadata=dataset_attrs["allow_extra_metadata"], + distribution=dataset_attrs["distribution"], + workspace_id=dataset_attrs["workspace_id"], + ) - if await Dataset.get_by(db, name=dataset_create.name, workspace_id=dataset_create.workspace_id): - raise NotUniqueError( - f"Dataset with name `{dataset_create.name}` already exists for workspace with id `{dataset_create.workspace_id}`" - ) + await DatasetCreateValidator.validate(db, dataset) - return await Dataset.create( - db, - name=dataset_create.name, - guidelines=dataset_create.guidelines, - allow_extra_metadata=dataset_create.allow_extra_metadata, - workspace_id=dataset_create.workspace_id, - ) + return await dataset.save(db) async def _count_required_fields_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: @@ -176,6 +170,12 @@ async def publish_dataset(db: AsyncSession, search_engine: SearchEngine, dataset return dataset +async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> Dataset: + await DatasetUpdateValidator.validate(db, dataset, dataset_attrs) + + return await dataset.update(db, **dataset_attrs) + + async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: Dataset) -> Dataset: async with db.begin_nested(): dataset = await dataset.delete(db, autocommit=False) @@ -186,11 +186,6 @@ async def delete_dataset(db: AsyncSession, search_engine: SearchEngine, dataset: return dataset -async def update_dataset(db: AsyncSession, dataset: Dataset, dataset_update: "DatasetUpdate") -> Dataset: - params = dataset_update.dict(exclude_unset=True) - return await dataset.update(db, **params) - - async def create_field(db: AsyncSession, dataset: Dataset, field_create: FieldCreate) -> Field: if dataset.is_ready: raise UnprocessableEntityError("Field cannot be created for a published dataset") @@ -945,6 +940,9 @@ async def create_response( await db.flush([response]) await _touch_dataset_last_activity_at(db, record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -968,6 +966,9 @@ async def update_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(response.record, status=response.record.status) await db.commit() @@ -997,6 +998,9 @@ async def upsert_response( await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.update_record_response(response) + await db.refresh(record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, record) + await search_engine.partial_record_update(record, status=record.status) await db.commit() @@ -1006,9 +1010,13 @@ async def upsert_response( async def delete_response(db: AsyncSession, 
search_engine: SearchEngine, response: Response) -> Response: async with db.begin_nested(): response = await response.delete(db, autocommit=False) + await _load_users_from_responses(response) await _touch_dataset_last_activity_at(db, response.record.dataset) await search_engine.delete_record_response(response) + await db.refresh(response.record, attribute_names=[Record.responses_submitted.key]) + await distribution.update_record_status(db, response.record) + await search_engine.partial_record_update(record=response.record, status=response.record.status) await db.commit() diff --git a/argilla-server/src/argilla_server/contexts/distribution.py b/argilla-server/src/argilla_server/contexts/distribution.py new file mode 100644 index 0000000000..92973801ce --- /dev/null +++ b/argilla-server/src/argilla_server/contexts/distribution.py @@ -0,0 +1,42 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus +from argilla_server.models import Record + + +# TODO: Do this with one single update statement for all records if possible to avoid too many queries. +async def update_records_status(db: AsyncSession, records: List[Record]): + for record in records: + await update_record_status(db, record) + + +async def update_record_status(db: AsyncSession, record: Record) -> Record: + if record.dataset.distribution_strategy == DatasetDistributionStrategy.overlap: + return await _update_record_status_with_overlap_strategy(db, record) + + raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`") + + +async def _update_record_status_with_overlap_strategy(db: AsyncSession, record: Record) -> Record: + if len(record.responses_submitted) >= record.dataset.distribution["min_submitted"]: + record.status = RecordStatus.completed + else: + record.status = RecordStatus.pending + + return await record.save(db, autocommit=False) diff --git a/argilla-server/src/argilla_server/enums.py b/argilla-server/src/argilla_server/enums.py index 13b4843280..fcf0b3142f 100644 --- a/argilla-server/src/argilla_server/enums.py +++ b/argilla-server/src/argilla_server/enums.py @@ -43,12 +43,21 @@ class DatasetStatus(str, Enum): ready = "ready" +class DatasetDistributionStrategy(str, Enum): + overlap = "overlap" + + class UserRole(str, Enum): owner = "owner" admin = "admin" annotator = "annotator" +class RecordStatus(str, Enum): + pending = "pending" + completed = "completed" + + class RecordInclude(str, Enum): responses = "responses" suggestions = "suggestions" diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 468b682467..37bd7730c9 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -29,9 +29,12 @@ DatasetStatus, MetadataPropertyType, QuestionType, 
+ RecordStatus, ResponseStatus, SuggestionType, UserRole, + DatasetDistributionStrategy, + RecordStatus, ) from argilla_server.models.base import DatabaseModel from argilla_server.models.metadata_properties import MetadataPropertySettings @@ -180,11 +183,17 @@ def __repr__(self) -> str: ) +RecordStatusEnum = SAEnum(RecordStatus, name="record_status_enum") + + class Record(DatabaseModel): __tablename__ = "records" fields: Mapped[dict] = mapped_column(JSON, default={}) metadata_: Mapped[Optional[dict]] = mapped_column("metadata", MutableDict.as_mutable(JSON), nullable=True) + status: Mapped[RecordStatus] = mapped_column( + RecordStatusEnum, default=RecordStatus.pending, server_default=RecordStatus.pending, index=True + ) external_id: Mapped[Optional[str]] = mapped_column(index=True) dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True) @@ -195,6 +204,13 @@ class Record(DatabaseModel): passive_deletes=True, order_by=Response.inserted_at.asc(), ) + responses_submitted: Mapped[List["Response"]] = relationship( + back_populates="record", + cascade="all, delete-orphan", + passive_deletes=True, + primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')", + order_by=Response.inserted_at.asc(), + ) suggestions: Mapped[List["Suggestion"]] = relationship( back_populates="record", cascade="all, delete-orphan", @@ -210,17 +226,17 @@ class Record(DatabaseModel): __table_args__ = (UniqueConstraint("external_id", "dataset_id", name="record_external_id_dataset_id_uq"),) + def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: + for vector in self.vectors: + if vector.vector_settings_id == vector_settings.id: + return vector.value + def __repr__(self): return ( f"Record(id={str(self.id)!r}, external_id={self.external_id!r}, dataset_id={str(self.dataset_id)!r}, " f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) - def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]: - for vector in self.vectors: - if vector.vector_settings_id == vector_settings.id: - return vector.value - class Question(DatabaseModel): __tablename__ = "questions" @@ -304,6 +320,7 @@ class Dataset(DatabaseModel): guidelines: Mapped[Optional[str]] = mapped_column(Text) allow_extra_metadata: Mapped[bool] = mapped_column(default=True, server_default=sql.true()) status: Mapped[DatasetStatus] = mapped_column(DatasetStatusEnum, default=DatasetStatus.draft, index=True) + distribution: Mapped[dict] = mapped_column(MutableDict.as_mutable(JSON)) workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True) inserted_at: Mapped[datetime] = mapped_column(default=datetime.utcnow) updated_at: Mapped[datetime] = mapped_column(default=inserted_at_current_value, onupdate=datetime.utcnow) @@ -353,6 +370,10 @@ def is_draft(self): def is_ready(self): return self.status == DatasetStatus.ready + @property + def distribution_strategy(self) -> DatasetDistributionStrategy: + return DatasetDistributionStrategy(self.distribution["strategy"]) + def metadata_property_by_name(self, name: str) -> Union["MetadataProperty", None]: for metadata_property in self.metadata_properties: if metadata_property.name == name: diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 7c9146cafe..ee1dbcc386 100644 --- 
a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -317,6 +317,10 @@ async def configure_metadata_property(self, dataset: Dataset, metadata_property: async def index_records(self, dataset: Dataset, records: Iterable[Record]): pass + @abstractmethod + async def partial_record_update(self, record: Record, **update): + pass + @abstractmethod async def delete_records(self, dataset: Dataset, records: Iterable[Record]): pass diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 2030b59ae5..b328224f19 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -346,6 +346,10 @@ async def index_records(self, dataset: Dataset, records: Iterable[Record]): await self._bulk_op_request(bulk_actions) + async def partial_record_update(self, record: Record, **update): + index_name = await self._get_dataset_index(record.dataset) + await self._update_document_request(index_name=index_name, id=str(record.id), body={"doc": update}) + async def delete_records(self, dataset: Dataset, records: Iterable[Record]): index_name = await self._get_dataset_index(dataset) @@ -552,6 +556,7 @@ def _map_record_to_es_document(self, record: Record) -> Dict[str, Any]: document = { "id": str(record.id), "fields": record.fields, + "status": record.status, "inserted_at": record.inserted_at, "updated_at": record.updated_at, } @@ -712,6 +717,7 @@ def _configure_index_mappings(self, dataset: Dataset) -> dict: "properties": { # See https://www.elastic.co/guide/en/elasticsearch/reference/current/explicit-mapping.html "id": {"type": "keyword"}, + "status": {"type": "keyword"}, RecordSortField.inserted_at.value: {"type": "date_nanos"}, RecordSortField.updated_at.value: {"type": "date_nanos"}, "responses": {"dynamic": True, "type": "object"}, diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py new file mode 100644 index 0000000000..aae2a5fc83 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -0,0 +1,48 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError +from argilla_server.models import Dataset, Workspace + + +class DatasetCreateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset) -> None: + await cls._validate_workspace_is_present(db, dataset.workspace_id) + await cls._validate_name_is_not_duplicated(db, dataset.name, dataset.workspace_id) + + @classmethod + async def _validate_workspace_is_present(cls, db: AsyncSession, workspace_id: UUID) -> None: + if await Workspace.get(db, workspace_id) is None: + raise UnprocessableEntityError(f"Workspace with id `{workspace_id}` not found") + + @classmethod + async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, workspace_id: UUID) -> None: + if await Dataset.get_by(db, name=name, workspace_id=workspace_id): + raise NotUniqueError(f"Dataset with name `{name}` already exists for workspace with id `{workspace_id}`") + + +class DatasetUpdateValidator: + @classmethod + async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None: + cls._validate_distribution(dataset, dataset_attrs) + + @classmethod + def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: + if dataset.is_ready and dataset_attrs.get("distribution") is not None: + raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset") diff --git a/argilla-server/tests/factories.py b/argilla-server/tests/factories.py index 5c77b9a0f5..c429fed9af 100644 --- a/argilla-server/tests/factories.py +++ b/argilla-server/tests/factories.py @@ -16,7 +16,7 @@ import random import factory -from argilla_server.enums import FieldType, MetadataPropertyType, OptionsOrder +from argilla_server.enums import DatasetDistributionStrategy, FieldType, MetadataPropertyType, OptionsOrder from argilla_server.models import ( Dataset, Field, @@ -203,6 +203,7 @@ class Meta: model = Dataset name = factory.Sequence(lambda n: f"dataset-{n}") + distribution = {"strategy": DatasetDistributionStrategy.overlap, "min_submitted": 1} workspace = factory.SubFactory(WorkspaceFactory) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py index d7e95520d5..3d1f0bf6da 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_dataset_records_bulk.py @@ -15,7 +15,7 @@ from uuid import UUID import pytest -from argilla_server.enums import DatasetStatus +from argilla_server.enums import DatasetStatus, RecordStatus from argilla_server.models import Dataset, Record from httpx import AsyncClient from sqlalchemy import func, select @@ -87,6 +87,7 @@ async def test_create_dataset_records_bulk( "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "dataset_id": str(dataset.id), "external_id": record.external_id, "fields": record.fields, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py new file mode 100644 index 0000000000..4261145d0c --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_create_dataset.py @@ -0,0 +1,139 @@ +# Copyright 2021-present, the 
Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from argilla_server.models import Dataset +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from tests.factories import WorkspaceFactory + + +@pytest.mark.asyncio +class TestCreateDataset: + def url(self) -> str: + return "/api/v1/datasets" + + async def test_create_dataset_with_default_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + }, + ) + + dataset = (await db.execute(select(Dataset))).scalar_one() + + assert response.status_code == 201 + assert response.json() == { + "id": str(dataset.id), + "name": "Dataset Name", + "guidelines": None, + "allow_extra_metadata": True, + "status": DatasetStatus.draft, + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + "workspace_id": str(workspace.id), + "last_activity_at": dataset.last_activity_at.isoformat(), + "inserted_at": dataset.inserted_at.isoformat(), + "updated_at": dataset.updated_at.isoformat(), + } + + async def test_create_dataset_with_overlap_distribution_using_invalid_min_submitted_value( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset name", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 + + async def 
test_create_dataset_with_invalid_distribution_strategy( + self, db: AsyncSession, async_client: AsyncClient, owner_auth_header: dict + ): + workspace = await WorkspaceFactory.create() + + response = await async_client.post( + self.url(), + headers=owner_auth_header, + json={ + "name": "Dataset Name", + "distribution": { + "strategy": "invalid_strategy", + }, + "workspace_id": str(workspace.id), + }, + ) + + assert response.status_code == 422 + assert (await db.execute(select(func.count(Dataset.id)))).scalar_one() == 0 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py index e70072d814..8d4981e828 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_current_user_dataset_records.py @@ -16,7 +16,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import UserRole +from argilla_server.enums import UserRole, RecordStatus from argilla_server.search_engine import SearchEngine, SearchResponseItem, SearchResponses from httpx import AsyncClient @@ -71,6 +71,7 @@ async def test_search_with_filtered_metadata( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": record.metadata_, "external_id": record.external_id, @@ -122,6 +123,7 @@ async def test_search_with_filtered_metadata_as_annotator( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"annotator_meta": "value"}, "external_id": record.external_id, @@ -173,6 +175,7 @@ async def test_search_with_filtered_metadata_as_admin( { "record": { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": {"admin_meta": "value", "annotator_meta": "value", "extra": "value"}, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 3d22527c3b..73077c4381 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -17,7 +17,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_LE from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, SortOrder +from argilla_server.enums import RecordInclude, SortOrder, RecordStatus from argilla_server.search_engine import ( AndFilter, Order, @@ -118,6 +118,7 @@ async def test_with_include_responses( { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", @@ -153,6 +154,7 @@ async def test_with_include_responses( { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": { "sentiment": "neutral", "text": "This is a text", diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py new file mode 100644 index 0000000000..cdb9b06ea2 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -0,0 +1,178 @@ +# Copyright 2021-present, the Recognai 
S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID + +import pytest +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus +from httpx import AsyncClient + +from tests.factories import DatasetFactory + + +@pytest.mark.asyncio +class TestUpdateDataset: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}" + + async def test_update_dataset_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + } + + async def test_update_dataset_without_distribution(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_without_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"name": "Dataset updated name"}, + ) + + assert response.status_code == 200 + assert response.json()["distribution"] == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + assert dataset.name == "Dataset updated name" + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_for_published_dataset( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 422 + assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"} + + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def 
test_update_dataset_distribution_with_invalid_strategy( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": "invalid_strategy", + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_with_invalid_min_submitted_value( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 0, + }, + }, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + + async def test_update_dataset_distribution_as_none(self, async_client: AsyncClient, owner_auth_header: dict): + dataset = await DatasetFactory.create() + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={"distribution": None}, + ) + + assert response.status_code == 422 + assert dataset.distribution == { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py new file mode 100644 index 0000000000..1aae133535 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_dataset_records_bulk.py @@ -0,0 +1,145 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from uuid import UUID +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.models import User, Record +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus, DatasetStatus + +from tests.factories import AnnotatorFactory, DatasetFactory, TextFieldFactory, TextQuestionFactory + + +@pytest.mark.asyncio +class TestCreateDatasetRecordsBulk: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records/bulk" + + async def test_create_dataset_records_bulk_updates_records_status( + self, db: AsyncSession, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextFieldFactory.create(name="prompt", dataset=dataset) + await TextFieldFactory.create(name="response", dataset=dataset) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + response = await async_client.post( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "fields": { + "prompt": "Does exercise help reduce stress?", + "response": "Exercise can definitely help reduce stress.", + }, + }, + ], + }, + ) + + assert response.status_code == 201 + + response_items = response.json()["items"] + assert response_items[0]["status"] == RecordStatus.completed + assert response_items[1]["status"] == RecordStatus.pending + assert response_items[2]["status"] == RecordStatus.pending + assert response_items[3]["status"] == RecordStatus.pending + + assert (await Record.get(db, UUID(response_items[0]["id"]))).status == RecordStatus.completed + assert (await Record.get(db, UUID(response_items[1]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[2]["id"]))).status == RecordStatus.pending + assert (await Record.get(db, UUID(response_items[3]["id"]))).status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py 
b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py index 98b3a864b9..ce433d036d 100644 --- a/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_create_record_response.py @@ -16,13 +16,15 @@ from uuid import UUID import pytest -from argilla_server.enums import ResponseStatusFilter -from argilla_server.models import Response, User + from httpx import AsyncClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory +from argilla_server.enums import ResponseStatus, RecordStatus, DatasetDistributionStrategy +from argilla_server.models import Response, User + +from tests.factories import DatasetFactory, RecordFactory, SpanQuestionFactory, TextQuestionFactory @pytest.mark.asyncio @@ -52,7 +54,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -72,7 +74,7 @@ async def test_create_record_response_for_span_question( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -101,7 +103,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -121,7 +123,7 @@ async def test_create_record_response_for_span_question_with_additional_value_at ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -146,7 +148,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -162,7 +164,7 @@ async def test_create_record_response_for_span_question_with_empty_value( "value": [], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, "record_id": str(record.id), "user_id": str(owner.id), "inserted_at": datetime.fromisoformat(response_json["inserted_at"]).isoformat(), @@ -189,7 +191,7 @@ async def test_create_record_response_for_span_question_with_record_not_providin ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -219,7 +221,7 @@ async def test_create_record_response_for_span_question_with_invalid_value( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -244,7 +246,7 @@ async def test_create_record_response_for_span_question_with_start_greater_than_ "value": [{"label": "label-a", "start": 5, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -273,7 +275,7 @@ async def test_create_record_response_for_span_question_with_end_greater_than_ex "value": [{"label": "label-a", "start": 4, "end": 6}], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -304,7 +306,7 @@ async def test_create_record_response_for_span_question_with_invalid_start( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -331,7 +333,7 @@ async def 
test_create_record_response_for_span_question_with_invalid_end( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -358,7 +360,7 @@ async def test_create_record_response_for_span_question_with_equal_start_and_end ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -385,7 +387,7 @@ async def test_create_record_response_for_span_question_with_end_smaller_than_st ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -412,7 +414,7 @@ async def test_create_record_response_for_span_question_with_non_existent_label( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -446,7 +448,7 @@ async def test_create_record_response_for_span_question_with_overlapped_values( ], }, }, - "status": ResponseStatusFilter.submitted, + "status": ResponseStatus.submitted, }, ) @@ -454,3 +456,63 @@ async def test_create_record_response_for_span_question_with_overlapped_values( assert response.json() == {"detail": "overlapping values found between spans at index idx=0 and idx=2"} assert (await db.execute(select(func.count(Response.id)))).scalar() == 0 + + async def test_create_record_response_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.completed + + async def test_create_record_response_does_not_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + + response = await async_client.post( + self.url(record.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert response.status_code == 201 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py new file mode 100644 index 0000000000..82b035a58a --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/records/test_upsert_dataset_records_bulk.py @@ -0,0 +1,153 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from uuid import UUID +from httpx import AsyncClient + +from argilla_server.models import User +from argilla_server.enums import DatasetDistributionStrategy, ResponseStatus, DatasetStatus, RecordStatus + +from tests.factories import DatasetFactory, RecordFactory, TextQuestionFactory, ResponseFactory, AnnotatorFactory + + +@pytest.mark.asyncio +class TestUpsertDatasetRecordsBulk: + def url(self, dataset_id: UUID) -> str: + return f"/api/v1/datasets/{dataset_id}/records/bulk" + + async def test_upsert_dataset_records_bulk_updates_records_status( + self, async_client: AsyncClient, owner: User, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + status=DatasetStatus.ready, + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + }, + ) + + user = await AnnotatorFactory.create(workspaces=[dataset.workspace]) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record_a = await RecordFactory.create(dataset=dataset) + assert record_a.status == RecordStatus.pending + + await ResponseFactory.create( + user=owner, + record=record_a, + status=ResponseStatus.submitted, + values={ + "text-question": { + "value": "text question response", + }, + }, + ) + + record_b = await RecordFactory.create(dataset=dataset) + assert record_b.status == RecordStatus.pending + + record_c = await RecordFactory.create(dataset=dataset) + assert record_c.status == RecordStatus.pending + + record_d = await RecordFactory.create(dataset=dataset) + assert record_d.status == RecordStatus.pending + + response = await async_client.put( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "items": [ + { + "id": str(record_a.id), + "responses": [ + { + "user_id": str(user.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_b.id), + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.submitted, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_c.id), + "responses": [ + { + "user_id": str(owner.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + { + "user_id": str(user.id), + "status": ResponseStatus.draft, + "values": { + "text-question": { + "value": "text question response", + }, + }, + }, + ], + }, + { + "id": str(record_d.id), + "responses": [], + }, + ], + }, + ) + + assert response.status_code == 200 + + respose_items = response.json()["items"] + assert respose_items[0]["status"] == RecordStatus.completed + assert respose_items[1]["status"] == RecordStatus.pending + assert respose_items[2]["status"] == RecordStatus.pending + assert respose_items[3]["status"] == RecordStatus.pending + + assert record_a.status == RecordStatus.completed + assert record_b.status == RecordStatus.pending + 
assert record_c.status == RecordStatus.pending + assert record_d.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py index 009cec7d2e..07b4bf0199 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_create_current_user_responses_bulk.py @@ -18,7 +18,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus +from argilla_server.enums import ResponseStatus, RecordStatus from argilla_server.models import Response, User from argilla_server.search_engine import SearchEngine from argilla_server.use_cases.responses.upsert_responses_in_bulk import UpsertResponsesInBulkUseCase @@ -111,7 +111,7 @@ async def test_multiple_responses( "item": { "id": str(response_to_create_id), "values": {"prompt-quality": {"value": 5}}, - "status": ResponseStatus.submitted.value, + "status": ResponseStatus.submitted, "record_id": str(records[0].id), "user_id": str(annotator.id), "inserted_at": datetime.fromisoformat(resp_json["items"][0]["item"]["inserted_at"]).isoformat(), @@ -123,7 +123,7 @@ async def test_multiple_responses( "item": { "id": str(response_to_update.id), "values": {"prompt-quality": {"value": 10}}, - "status": ResponseStatus.submitted.value, + "status": ResponseStatus.submitted, "record_id": str(records[1].id), "user_id": str(annotator.id), "inserted_at": datetime.fromisoformat(resp_json["items"][1]["item"]["inserted_at"]).isoformat(), @@ -146,6 +146,10 @@ async def test_multiple_responses( ], } + assert records[0].status == RecordStatus.completed + assert records[1].status == RecordStatus.completed + assert records[2].status == RecordStatus.pending + assert (await db.execute(select(func.count(Response.id)))).scalar() == 2 response_to_create = (await db.execute(select(Response).filter_by(id=response_to_create_id))).scalar_one() diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py new file mode 100644 index 0000000000..6b9d4ec749 --- /dev/null +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_delete_response.py @@ -0,0 +1,66 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from uuid import UUID + +import pytest + +from httpx import AsyncClient + +from argilla_server.models import User +from argilla_server.enums import DatasetDistributionStrategy, RecordStatus, ResponseStatus + +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, TextQuestionFactory + + +@pytest.mark.asyncio +class TestDeleteResponse: + def url(self, response_id: UUID) -> str: + return f"/api/v1/responses/{response_id}" + + async def test_delete_response_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + } + ) + + record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset) + response = await ResponseFactory.create(record=record) + + resp = await async_client.delete(self.url(response.id), headers=owner_auth_header) + + assert resp.status_code == 200 + assert record.status == RecordStatus.pending + + async def test_delete_response_does_not_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 2, + } + ) + + record = await RecordFactory.create(status=RecordStatus.completed, dataset=dataset) + responses = await ResponseFactory.create_batch(3, record=record) + + resp = await async_client.delete(self.url(responses[0].id), headers=owner_auth_header) + + assert resp.status_code == 200 + assert record.status == RecordStatus.completed diff --git a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py index f5ffab7b31..d5097f8c7b 100644 --- a/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py +++ b/argilla-server/tests/unit/api/handlers/v1/responses/test_update_response.py @@ -16,13 +16,15 @@ from uuid import UUID import pytest -from argilla_server.enums import ResponseStatus -from argilla_server.models import Response, User from httpx import AsyncClient + from sqlalchemy import select from sqlalchemy.ext.asyncio.session import AsyncSession -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory +from argilla_server.enums import ResponseStatus, DatasetDistributionStrategy, RecordStatus +from argilla_server.models import Response, User + +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, SpanQuestionFactory, TextQuestionFactory @pytest.mark.asyncio @@ -560,3 +562,66 @@ async def test_update_response_for_span_question_with_non_existent_label( } assert (await db.execute(select(Response).filter_by(id=response.id))).scalar_one().values == response_values + + async def test_update_response_updates_record_status_to_completed( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset) + response = await ResponseFactory.create(record=record, status=ResponseStatus.draft) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={ + "values": { + "text-question": { + "value": "text question 
updated response", + }, + }, + "status": ResponseStatus.submitted, + }, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.completed + + async def test_update_response_updates_record_status_to_pending( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create( + distribution={ + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, + ) + + await TextQuestionFactory.create(name="text-question", dataset=dataset) + + record = await RecordFactory.create(fields={"field-a": "Hello"}, dataset=dataset, status=RecordStatus.completed) + response = await ResponseFactory.create( + values={ + "text-question": { + "value": "text question response", + }, + }, + record=record, + status=ResponseStatus.submitted, + ) + + resp = await async_client.put( + self.url(response.id), + headers=owner_auth_header, + json={"status": ResponseStatus.draft}, + ) + + assert resp.status_code == 200 + assert record.status == RecordStatus.pending diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 650e9f3808..e0c9fe4d5e 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -34,11 +34,13 @@ ) from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import ( + DatasetDistributionStrategy, DatasetStatus, OptionsOrder, RecordInclude, ResponseStatusFilter, SimilarityOrder, + RecordStatus, ) from argilla_server.models import ( Dataset, @@ -116,6 +118,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_a.workspace_id), "last_activity_at": dataset_a.last_activity_at.isoformat(), "inserted_at": dataset_a.inserted_at.isoformat(), @@ -127,6 +133,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": "guidelines", "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_b.workspace_id), "last_activity_at": dataset_b.last_activity_at.isoformat(), "inserted_at": dataset_b.inserted_at.isoformat(), @@ -138,6 +148,10 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own "guidelines": None, "allow_extra_metadata": True, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset_c.workspace_id), "last_activity_at": dataset_c.last_activity_at.isoformat(), "inserted_at": dataset_c.inserted_at.isoformat(), @@ -653,8 +667,6 @@ async def test_list_dataset_vectors_settings_without_authentication(self, async_ assert response.status_code == 401 - # Helper function to create records with responses - async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create(name="dataset") @@ -667,6 +679,10 @@ async def test_get_dataset(self, async_client: "AsyncClient", owner_auth_header: "guidelines": None, "allow_extra_metadata": True, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": 
dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), @@ -839,13 +855,16 @@ async def test_create_dataset(self, async_client: "AsyncClient", db: "AsyncSessi await db.refresh(workspace) response_body = response.json() - assert (await db.execute(select(func.count(Dataset.id)))).scalar() == 1 assert response_body == { "id": str(UUID(response_body["id"])), "name": "name", "guidelines": "guidelines", "allow_extra_metadata": False, "status": "draft", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(workspace.id), "last_activity_at": datetime.fromisoformat(response_body["last_activity_at"]).isoformat(), "inserted_at": datetime.fromisoformat(response_body["inserted_at"]).isoformat(), @@ -3644,6 +3663,7 @@ async def test_search_current_user_dataset_records( { "record": { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "external_id": records[0].external_id, @@ -3656,6 +3676,7 @@ async def test_search_current_user_dataset_records( { "record": { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "external_id": records[1].external_id, @@ -3997,6 +4018,7 @@ async def test_search_current_user_dataset_records_with_include( { "record": { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": { "input": "input_a", "output": "output_a", @@ -4012,6 +4034,7 @@ async def test_search_current_user_dataset_records_with_include( { "record": { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": { "input": "input_b", "output": "output_b", @@ -4151,6 +4174,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -4167,6 +4191,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, @@ -4182,6 +4207,7 @@ async def test_search_current_user_dataset_records_with_include_vectors( { "record": { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, @@ -4245,6 +4271,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -4261,6 +4288,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, @@ -4276,6 +4304,7 @@ async def test_search_current_user_dataset_records_with_include_specific_vectors { "record": { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, @@ -4752,6 +4781,10 @@ async def test_update_dataset(self, async_client: 
"AsyncClient", db: "AsyncSessi "guidelines": guidelines, "allow_extra_metadata": allow_extra_metadata, "status": "ready", + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 1, + }, "workspace_id": str(dataset.workspace_id), "last_activity_at": dataset.last_activity_at.isoformat(), "inserted_at": dataset.inserted_at.isoformat(), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index f088cfcda9..8f78940df3 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -18,7 +18,7 @@ import pytest from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole +from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from argilla_server.search_engine import ( FloatMetadataFilter, @@ -821,6 +821,7 @@ async def test_list_current_user_dataset_records( "items": [ { "id": str(record_a.id), + "status": RecordStatus.pending, "fields": {"input": "input_a", "output": "output_a"}, "metadata": None, "dataset_id": str(dataset.id), @@ -830,6 +831,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_b.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"unit": "test"}, "dataset_id": str(dataset.id), @@ -839,6 +841,7 @@ async def test_list_current_user_dataset_records( }, { "id": str(record_c.id), + "status": RecordStatus.pending, "fields": {"input": "input_c", "output": "output_c"}, "metadata": None, "dataset_id": str(dataset.id), @@ -898,6 +901,7 @@ async def test_list_current_user_dataset_records_with_filtered_metadata_as_annot "items": [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"input": "input_b", "output": "output_b"}, "metadata": {"key1": "value1"}, "dataset_id": str(dataset.id), diff --git a/argilla-server/tests/unit/api/handlers/v1/test_records.py b/argilla-server/tests/unit/api/handlers/v1/test_records.py index ed7d9f8cc2..3c361b1666 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_records.py @@ -19,7 +19,7 @@ import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus +from argilla_server.enums import RecordStatus, ResponseStatus from argilla_server.models import Dataset, Record, Response, Suggestion, User, UserRole from argilla_server.search_engine import SearchEngine from sqlalchemy import func, select @@ -92,6 +92,7 @@ async def test_get_record(self, async_client: "AsyncClient", role: UserRole): assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -188,6 +189,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": 
"neutral"}, "metadata": { "terms-metadata-property": "c", @@ -228,6 +230,7 @@ async def test_update_record(self, async_client: "AsyncClient", mock_search_engi "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), } + mock_search_engine.index_records.assert_called_once_with(dataset, [record]) async def test_update_record_with_null_metadata( @@ -251,6 +254,7 @@ async def test_update_record_with_null_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -278,6 +282,7 @@ async def test_update_record_with_no_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -310,6 +315,7 @@ async def test_update_record_with_list_terms_metadata( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": { "terms-metadata-property": ["a", "b", "c"], @@ -339,6 +345,7 @@ async def test_update_record_with_no_suggestions( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record.external_id, @@ -1413,6 +1420,7 @@ async def test_delete_record( assert response.status_code == 200 assert response.json() == { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "metadata": None, "external_id": record.external_id, diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index ecba3232a6..c4376ca686 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,7 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder +from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( FloatMetadataFilter, @@ -263,6 +263,7 @@ async def refresh_records(records: List[Record]): for record in records: await record.awaitable_attrs.suggestions await record.awaitable_attrs.responses + await record.awaitable_attrs.responses_submitted await record.awaitable_attrs.vectors @@ -314,6 +315,7 @@ async def test_create_index_for_dataset( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -356,6 +358,7 @@ async def test_create_index_for_dataset_with_fields( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -428,6 +431,7 @@ async def test_create_index_for_dataset_with_metadata_properties( ], "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, 
"inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -475,6 +479,7 @@ async def test_create_index_for_dataset_with_questions( "dynamic": "strict", "properties": { "id": {"type": "keyword"}, + "status": {"type": "keyword"}, "inserted_at": {"type": "date_nanos"}, "updated_at": {"type": "date_nanos"}, ALL_RESPONSES_STATUSES_FIELD: {"type": "keyword"}, @@ -879,6 +884,7 @@ async def test_index_records(self, search_engine: BaseElasticAndOpenSearchEngine assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -937,6 +943,7 @@ async def test_index_records_with_suggestions( assert es_docs == [ { "id": str(records[0].id), + "status": RecordStatus.pending, "fields": records[0].fields, "inserted_at": records[0].inserted_at.isoformat(), "updated_at": records[0].updated_at.isoformat(), @@ -944,6 +951,7 @@ async def test_index_records_with_suggestions( }, { "id": str(records[1].id), + "status": RecordStatus.pending, "fields": records[1].fields, "inserted_at": records[1].inserted_at.isoformat(), "updated_at": records[1].updated_at.isoformat(), @@ -978,6 +986,7 @@ async def test_index_records_with_metadata( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), @@ -1017,6 +1026,7 @@ async def test_index_records_with_vectors( assert es_docs == [ { "id": str(record.id), + "status": RecordStatus.pending, "fields": record.fields, "inserted_at": record.inserted_at.isoformat(), "updated_at": record.updated_at.isoformat(), diff --git a/argilla/src/argilla/_models/_search.py b/argilla/src/argilla/_models/_search.py index f62dbff0b7..3c256805a0 100644 --- a/argilla/src/argilla/_models/_search.py +++ b/argilla/src/argilla/_models/_search.py @@ -17,6 +17,11 @@ from pydantic import BaseModel, Field +class RecordFilterScopeModel(BaseModel): + entity: Literal["record"] = "record" + property: Literal["status"] = "status" + + class ResponseFilterScopeModel(BaseModel): """Filter scope for filtering on a response entity.""" @@ -42,6 +47,7 @@ class MetadataFilterScopeModel(BaseModel): ScopeModel = Annotated[ Union[ + RecordFilterScopeModel, ResponseFilterScopeModel, SuggestionFilterScopeModel, MetadataFilterScopeModel, diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index adc56b5750..6ccdcee33a 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -26,6 +26,7 @@ FilterModel, AndFilterModel, QueryModel, + RecordFilterScopeModel, ) @@ -54,8 +55,9 @@ def model(self) -> FilterModel: @staticmethod def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() - if field == "status": + return RecordFilterScopeModel(property="status") + elif field == "responses.status": return ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".") From f084ab7026a13475c467aef6bb3fe430eb6c0f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dami=C3=A1n=20Pumar?= Date: Thu, 4 Jul 2024 09:36:19 +0200 Subject: [PATCH 02/36] =?UTF-8?q?=E2=9C=A8=20Remove=20unused=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repositories/RecordRepository.ts | 30 ------------------- 1 file changed, 30 deletions(-) 
diff --git a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts index 40ce2645eb..871282d9e7 100644 --- a/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/RecordRepository.ts @@ -43,7 +43,6 @@ export class RecordRepository { getRecords(criteria: RecordCriteria): Promise { return this.getRecordsByAdvanceSearch(criteria); - // return this.getRecordsByDatasetId(criteria); } async getRecord(recordId: string): Promise { @@ -186,35 +185,6 @@ export class RecordRepository { } } - private async getRecordsByDatasetId( - criteria: RecordCriteria - ): Promise { - const { datasetId, status, page } = criteria; - const { from, many } = page.server; - try { - const url = `/v1/me/datasets/${datasetId}/records`; - - const params = this.createParams(from, many, status); - - const { data } = await this.axios.get>( - url, - { - params, - } - ); - const { items: records, total } = data; - - return { - records, - total, - }; - } catch (err) { - throw { - response: RECORD_API_ERRORS.ERROR_FETCHING_RECORDS, - }; - } - } - private async getRecordsByAdvanceSearch( criteria: RecordCriteria ): Promise {
From 6df52560973eda8f1dd7b67fa282ed5db18ea5e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 4 Jul 2024 11:09:56 +0200 Subject: [PATCH 03/36] feat: improve Records `responses_submitted` relationship to be view only (#5148)
# Description Add changes to the `responses_submitted` relationship to avoid problems with the existing `responses` relationship and to avoid a warning message that SQLAlchemy was reporting. A minimal illustrative sketch of the resulting pattern is included below. Refs #5000 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] The warning is no longer shown. - [x] Tests are passing.
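As a non-authoritative illustration of the pattern (simplified table and class names, not the actual Argilla models), a `viewonly=True` relationship can sit next to a regular relationship over the same tables: the view-only side never participates in flushes, so it does not conflict with the writable relationship.

```python
# Minimal sketch (assumed simplified models, not the real Argilla schema):
# a plain relationship plus a filtered, read-only view of the same rows.
from typing import List

from sqlalchemy import ForeignKey, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Record(Base):
    __tablename__ = "records"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Owns persistence: inserts, deletes and cascades go through this one.
    responses: Mapped[List["Response"]] = relationship(back_populates="record")

    # Read-only projection: viewonly=True means it is never flushed, so it does
    # not overlap with `responses` and SQLAlchemy has nothing to warn about.
    responses_submitted: Mapped[List["Response"]] = relationship(
        viewonly=True,
        primaryjoin="and_(Record.id == Response.record_id, Response.status == 'submitted')",
    )


class Response(Base):
    __tablename__ = "responses"

    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(16), default="draft")
    record_id: Mapped[int] = mapped_column(ForeignKey("records.id"))

    record: Mapped["Record"] = relationship(back_populates="responses")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Record(responses=[Response(status="submitted"), Response(status="draft")]))
    session.commit()

    record = session.scalars(select(Record)).one()
    # Only the submitted response is visible through the view-only relationship.
    assert [r.status for r in record.responses_submitted] == ["submitted"]
```

The writable relationship keeps handling persistence while the filtered one is only used for reading, which mirrors the change applied to `Record.responses_submitted` in this patch.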
**Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/src/argilla_server/models/database.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 37bd7730c9..3230916362 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -206,8 +206,7 @@ class Record(DatabaseModel): ) responses_submitted: Mapped[List["Response"]] = relationship( back_populates="record", - cascade="all, delete-orphan", - passive_deletes=True, + viewonly=True, primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')", order_by=Response.inserted_at.asc(), )
From cf3408c7988285b3083edd28fa9c7936370283ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Thu, 4 Jul 2024 11:56:13 +0200 Subject: [PATCH 04/36] feat: change metrics to support new distribution task logic (#5140) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit
# Description This PR changes the endpoints that return the dataset progress and the current user metrics in the following way:
## `GET /datasets/:dataset_id/progress` I have changed the endpoint to support the new business logic behind the distribution task. It now responds with only `completed` and `pending` types of records, using `total` as the sum of the two. Old response without distribution task: ```json { "total": 8, "submitted": 2, "discarded": 2, "conflicting": 1, "pending": 3 } ``` New response with the changes from this PR supporting distribution task: * The `completed` attribute will have the count of all the records with `completed` status for the dataset. * The `pending` attribute will have the count of all the records with `pending` status for the dataset. * The `total` attribute will have the sum of the `completed` and `pending` attributes. ```json { "total": 5, "completed": 2, "pending": 3 } ``` @damianpumar some changes are required on the frontend to support this new endpoint structure.
## `GET /me/datasets/:dataset_id/metrics` Old response without distribution task: ```json { "records": { "count": 7 }, "responses": { "count": 4, "submitted": 1, "discarded": 2, "draft": 1 } } ``` New response with the changes from this PR supporting distribution task: * The `records` section has been removed because it is not necessary anymore. * The `count` attribute inside `responses` has been renamed to `total`. * A `pending` attribute has been added to the `responses` section. ```json { "responses": { "total": 7, "submitted": 1, "discarded": 2, "draft": 1, "pending": 3 } } ``` The logic behind these attributes is the following: * `total` is the sum of `submitted`, `discarded`, `draft` and `pending` attribute values. * `submitted` is the count of all responses belonging to the current user in the specified dataset with `submitted` status. * `discarded` is the count of all responses belonging to the current user in the specified dataset with `discarded` status.
* `draft` is the count of all responses belonging to the current user in the specified dataset with `draft` status. * `pending` is the count of all records with `pending` status for the dataset that have no responses belonging to the current user. @damianpumar some changes are required on the frontend to support this new endpoint structure as well. Closes #5139 **Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) **How Has This Been Tested** - [x] Modifying existing tests. - [x] Running the test suite with SQLite and PostgreSQL. **Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paco Aranda Co-authored-by: Damián Pumar --- .../useDatasetProgressViewModel.ts | 22 +--- argilla-frontend/translation/de.js | 5 + argilla-frontend/translation/en.js | 4 +- .../domain/entities/dataset/Metrics.test.ts | 28 ++-- .../v1/domain/entities/dataset/Metrics.ts | 16 +-- .../v1/domain/entities/dataset/Progress.ts | 4 +- .../repositories/DatasetRepository.ts | 8 +- .../repositories/MetricsRepository.ts | 12 +- .../v1/infrastructure/types/dataset.ts | 4 +- argilla-server/CHANGELOG.md | 2 + .../api/handlers/v1/datasets/datasets.py | 18 +-- .../argilla_server/api/schemas/v1/datasets.py | 12 +- .../src/argilla_server/contexts/datasets.py | 121 +++++++++++------- .../v1/datasets/test_get_dataset_progress.py | 80 ++---------- .../v1/datasets/test_update_dataset.py | 3 +- .../unit/api/handlers/v1/test_datasets.py | 50 ++++++-- 16 files changed, 168 insertions(+), 221 deletions(-)
diff --git a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts index f2b1ef6afc..149b45ac10 100644 --- a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts +++ b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts @@ -22,25 +22,11 @@ export const useDatasetProgressViewModel = ({ progressRanges.value = [ { - id: "submitted", - name: t("datasets.submitted"), + id: "completed", + name: t("datasets.completed"), color: "#0508D9", - value: progress.value.submitted, - tooltip: `${progress.value.submitted}/${progress.value.total}`, - }, - { - id: "conflicting", - name: t("datasets.conflicting"), - color: "#8893c0", - value: progress.value.conflicting, - tooltip: `${progress.value.conflicting}/${progress.value.total}`, - }, - { - id: "discarded", - name: t("datasets.discarded"), - color: "#b7b7b7", - value: progress.value.discarded, - tooltip: `${progress.value.discarded}/${progress.value.total}`, + value: progress.value.completed, + tooltip: `${progress.value.completed}/${progress.value.total}`, }, { id: "pending", diff --git a/argilla-frontend/translation/de.js b/argilla-frontend/translation/de.js index 099bd2e233..8d17eb4ac9 100644 --- a/argilla-frontend/translation/de.js +++ b/argilla-frontend/translation/de.js @@ -36,6 +36,11 @@ export default { datasetSettings: "einstellungen", userSettings: "meine einstellungen",
}, + datasets: { + left: "übrig", + completed: "Vollendet", + pending: "Ausstehend", + }, recordStatus: { pending: "Ausstehend", draft: "Entwurf", diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index af12e6df17..6ceac06d00 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -42,9 +42,7 @@ export default { }, datasets: { left: "left", - submitted: "Submitted", - conflicting: "Conflicting", - discarded: "Discarded", + completed: "Completed", pending: "Pending", }, recordStatus: { diff --git a/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts b/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts index 792450fe5b..322f480007 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Metrics.test.ts @@ -20,67 +20,67 @@ describe("Metrics", () => { describe("total", () => { it("should return the total number of records", () => { - const metrics = new Metrics(1, 0, 0, 0, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.total; - expect(result).toEqual(1); + expect(result).toEqual(15); }); }); describe("responded", () => { it("should return the number of responded records", () => { - const metrics = new Metrics(5, 5, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.responded; - expect(result).toEqual(5); + expect(result).toEqual(10); }); }); describe("pending", () => { it("should return the number of pending records", () => { - const metrics = new Metrics(5, 4, 3, 1, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.pending; - expect(result).toEqual(1); + expect(result).toEqual(5); }); }); describe("progress", () => { it("should return the progress of responded records", () => { - const metrics = new Metrics(5, 4, 3, 1, 0); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.progress; - expect(result).toEqual(0.8); + expect(result).toEqual(0.6666666666666666); }); }); describe("percentage", () => { it("should return the percentage of draft records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.draft; - expect(result).toEqual(20); + expect(result).toEqual(6.666666666666667); }); it("should return the percentage of submitted records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.submitted; - expect(result).toEqual(60); + expect(result).toEqual(26.666666666666668); }); it("should return the percentage of discarded records", () => { - const metrics = new Metrics(5, 4, 3, 1, 1); + const metrics = new Metrics(15, 4, 5, 1, 5); const result = metrics.percentage.discarded; - expect(result).toEqual(20); + expect(result).toEqual(33.333333333333336); }); }); }); diff --git a/argilla-frontend/v1/domain/entities/dataset/Metrics.ts b/argilla-frontend/v1/domain/entities/dataset/Metrics.ts index 31c80d6e08..ec1245e010 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Metrics.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Metrics.ts @@ -7,11 +7,11 @@ export class Metrics { }; constructor( - private readonly records: number, - public readonly responses: number, + public readonly total: number, public readonly submitted: number, public readonly discarded: number, - public readonly draft: number + public readonly draft: number, + public readonly pending: number ) { 
this.percentage = { pending: (this.pending * 100) / this.total, @@ -22,21 +22,13 @@ export class Metrics { } get hasMetrics() { - return this.records > 0; - } - - get total() { - return this.records; + return this.total > 0; } get responded() { return this.submitted + this.discarded + this.draft; } - get pending() { - return this.total - this.responded; - } - get progress() { return this.responded / this.total; } diff --git a/argilla-frontend/v1/domain/entities/dataset/Progress.ts b/argilla-frontend/v1/domain/entities/dataset/Progress.ts index 64c137f672..d996580c3d 100644 --- a/argilla-frontend/v1/domain/entities/dataset/Progress.ts +++ b/argilla-frontend/v1/domain/entities/dataset/Progress.ts @@ -1,9 +1,7 @@ export class Progress { constructor( public readonly total: number, - public readonly submitted: number, - public readonly discarded: number, - public readonly conflicting: number, + public readonly completed: number, public readonly pending: number ) {} } diff --git a/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts b/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts index 875935d9f0..fb82353fb8 100644 --- a/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/DatasetRepository.ts @@ -107,13 +107,7 @@ export class DatasetRepository implements IDatasetRepository { largeCache() ); - return new Progress( - data.total, - data.submitted, - data.discarded, - data.conflicting, - data.pending - ); + return new Progress(data.total, data.completed, data.pending); } catch (err) { throw { response: DATASET_API_ERRORS.ERROR_DELETING_DATASET, diff --git a/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts b/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts index 7cff90f7f9..2ddc434ef7 100644 --- a/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts +++ b/argilla-frontend/v1/infrastructure/repositories/MetricsRepository.ts @@ -3,14 +3,12 @@ import { largeCache } from "./AxiosCache"; import { Metrics } from "~/v1/domain/entities/dataset/Metrics"; interface BackendMetrics { - records: { - count: number; - }; responses: { - count: number; + total: number; submitted: number; discarded: number; draft: number; + pending: number; }; } @@ -25,11 +23,11 @@ export class MetricsRepository { ); return new Metrics( - data.records.count, - data.responses.count, + data.responses.total, data.responses.submitted, data.responses.discarded, - data.responses.draft + data.responses.draft, + data.responses.pending ); } catch { /* lint:disable:no-empty */ diff --git a/argilla-frontend/v1/infrastructure/types/dataset.ts b/argilla-frontend/v1/infrastructure/types/dataset.ts index 7b160fcbbf..e270b5495f 100644 --- a/argilla-frontend/v1/infrastructure/types/dataset.ts +++ b/argilla-frontend/v1/infrastructure/types/dataset.ts @@ -16,8 +16,6 @@ export interface BackendDatasetFeedbackTaskResponse { export interface BackendProgress { total: number; - submitted: number; - discarded: number; - conflicting: number; + completed: number; pending: number; } diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 2883fc9c6e..e466dbbded 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -24,6 +24,8 @@ These are the section headers that we use: ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. 
([#5126](https://github.com/argilla-io/argilla/pull/5126)) +- [breaking] Change `GET /datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) ### Fixed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py index 0590b41bb4..85bf7962c8 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/datasets.py @@ -147,23 +147,7 @@ async def get_current_user_dataset_metrics( await authorize(current_user, DatasetPolicy.get(dataset)) - return { - "records": { - "count": await datasets.count_records_by_dataset_id(db, dataset_id), - }, - "responses": { - "count": await datasets.count_responses_by_dataset_id_and_user_id(db, dataset_id, current_user.id), - "submitted": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.submitted - ), - "discarded": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.discarded - ), - "draft": await datasets.count_responses_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, ResponseStatus.draft - ), - }, - } + return await datasets.get_user_dataset_metrics(db, current_user.id, dataset.id) @router.get("/datasets/{dataset_id}/progress", response_model=DatasetProgress) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py index 1e1b69d836..dd9f1941f1 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/datasets.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/datasets.py @@ -70,27 +70,21 @@ class DatasetOverlapDistributionUpdate(DatasetDistributionCreate): DatasetDistributionUpdate = DatasetOverlapDistributionUpdate -class RecordMetrics(BaseModel): - count: int - - class ResponseMetrics(BaseModel): - count: int + total: int submitted: int discarded: int draft: int + pending: int class DatasetMetrics(BaseModel): - records: RecordMetrics responses: ResponseMetrics class DatasetProgress(BaseModel): total: int - submitted: int - discarded: int - conflicting: int + completed: int pending: int diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 1dbf52fc53..700dfeaefa 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -37,7 +37,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload -from argilla_server.api.schemas.v1.datasets import DatasetProgress from argilla_server.api.schemas.v1.fields import FieldCreate from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyCreate, MetadataPropertyUpdate from argilla_server.api.schemas.v1.records import ( @@ -61,7 +60,7 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.contexts import accounts, distribution -from argilla_server.enums import DatasetStatus, RecordInclude, UserRole +from argilla_server.enums import DatasetStatus, RecordInclude, UserRole, RecordStatus from 
argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( Dataset, @@ -372,39 +371,85 @@ async def _configure_query_relationships( return query -async def count_records_by_dataset_id(db: AsyncSession, dataset_id: UUID) -> int: - return (await db.execute(select(func.count(Record.id)).filter_by(dataset_id=dataset_id))).scalar_one() - - -async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> DatasetProgress: - submitted_case = case((Response.status == ResponseStatus.submitted, 1), else_=0) - discarded_case = case((Response.status == ResponseStatus.discarded, 1), else_=0) +async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id: UUID) -> dict: + responses_submitted, responses_discarded, responses_draft, responses_pending = await asyncio.gather( + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.submitted, + ), + ), + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.discarded, + ), + ), + db.execute( + select(func.count(Response.id)) + .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) + .filter( + Response.user_id == user_id, + Response.status == ResponseStatus.draft, + ), + ), + db.execute( + select(func.count(Record.id)) + .outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id)) + .filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.pending, + Response.id == None, + ), + ), + ) - submitted_clause = func.sum(submitted_case) > 0, func.sum(discarded_case) == 0 - discarded_clause = func.sum(discarded_case) > 0, func.sum(submitted_case) == 0 - conflicting_clause = func.sum(submitted_case) > 0, func.sum(discarded_case) > 0 + responses_submitted = responses_submitted.scalar_one() + responses_discarded = responses_discarded.scalar_one() + responses_draft = responses_draft.scalar_one() + responses_pending = responses_pending.scalar_one() + responses_total = responses_submitted + responses_discarded + responses_draft + responses_pending + + return { + "responses": { + "total": responses_total, + "submitted": responses_submitted, + "discarded": responses_discarded, + "draft": responses_draft, + "pending": responses_pending, + }, + } - query = select(Record.id).join(Response).filter(Record.dataset_id == dataset_id).group_by(Record.id) - total, submitted, discarded, conflicting = await asyncio.gather( - count_records_by_dataset_id(db, dataset_id), - db.execute(select(func.count("*")).select_from(query.having(*submitted_clause))), - db.execute(select(func.count("*")).select_from(query.having(*discarded_clause))), - db.execute(select(func.count("*")).select_from(query.having(*conflicting_clause))), +async def get_dataset_progress(db: AsyncSession, dataset_id: UUID) -> dict: + records_completed, records_pending = await asyncio.gather( + db.execute( + select(func.count(Record.id)).filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.completed, + ), + ), + db.execute( + select(func.count(Record.id)).filter( + Record.dataset_id == dataset_id, + Record.status == RecordStatus.pending, + ), + ), ) - submitted = submitted.scalar_one() - discarded = discarded.scalar_one() - conflicting = 
conflicting.scalar_one() - pending = total - submitted - discarded - conflicting - - return DatasetProgress( - total=total, - submitted=submitted, - discarded=discarded, - conflicting=conflicting, - pending=pending, - ) + records_completed = records_completed.scalar_one() + records_pending = records_pending.scalar_one() + records_total = records_completed + records_pending + + return { + "total": records_total, + "completed": records_completed, + "pending": records_pending, + } _EXTRA_METADATA_FLAG = "extra" @@ -901,22 +946,6 @@ async def delete_record(db: AsyncSession, search_engine: "SearchEngine", record: return record -async def count_responses_by_dataset_id_and_user_id( - db: AsyncSession, dataset_id: UUID, user_id: UUID, response_status: Optional[ResponseStatus] = None -) -> int: - expressions = [Response.user_id == user_id] - if response_status: - expressions.append(Response.status == response_status) - - return ( - await db.execute( - select(func.count(Response.id)) - .join(Record, and_(Record.id == Response.record_id, Record.dataset_id == dataset_id)) - .filter(*expressions) - ) - ).scalar_one() - - async def create_response( db: AsyncSession, search_engine: SearchEngine, record: Record, user: User, response_create: ResponseCreate ) -> Response: diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py index d3cb4e7393..6fb06a06c9 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_get_dataset_progress.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest + from uuid import UUID, uuid4 +from httpx import AsyncClient -import pytest from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import ResponseStatus, UserRole -from httpx import AsyncClient +from argilla_server.enums import UserRole, RecordStatus -from tests.factories import DatasetFactory, RecordFactory, ResponseFactory, UserFactory +from tests.factories import DatasetFactory, RecordFactory, UserFactory @pytest.mark.asyncio @@ -30,71 +31,16 @@ def url(self, dataset_id: UUID) -> str: async def test_get_dataset_progress(self, async_client: AsyncClient, owner_auth_header: dict): dataset = await DatasetFactory.create() - record_with_one_submitted_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_submitted_response) - - record_with_multiple_submitted_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch(3, record=record_with_multiple_submitted_responses) - - record_with_one_draft_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_draft_response, status=ResponseStatus.draft) - - record_with_multiple_draft_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch(3, record=record_with_multiple_draft_responses, status=ResponseStatus.draft) - - record_with_one_discarded_response = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_one_discarded_response, status=ResponseStatus.discarded) - - record_with_multiple_discarded_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create_batch( - 3, record=record_with_multiple_discarded_responses, status=ResponseStatus.discarded - ) - - record_with_mixed_responses = await RecordFactory.create(dataset=dataset) - await ResponseFactory.create(record=record_with_mixed_responses) - await ResponseFactory.create(record=record_with_mixed_responses, status=ResponseStatus.draft) - await ResponseFactory.create(record=record_with_mixed_responses, status=ResponseStatus.discarded) - - record_without_responses = await RecordFactory.create(dataset=dataset) - - other_dataset = await DatasetFactory.create() - - other_record_with_one_submitted_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_submitted_response) - - other_record_with_multiple_submitted_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch(3, record=other_record_with_multiple_submitted_responses) - - other_record_with_one_draft_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_draft_response, status=ResponseStatus.draft) - - other_record_with_multiple_draft_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch( - 3, record=other_record_with_multiple_draft_responses, status=ResponseStatus.draft - ) - - other_record_with_one_discarded_response = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_one_discarded_response, status=ResponseStatus.discarded) - - other_record_with_multiple_discarded_responses = await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create_batch( - 3, record=other_record_with_multiple_discarded_responses, status=ResponseStatus.discarded - ) - - other_record_with_mixed_responses = 
await RecordFactory.create(dataset=other_dataset) - await ResponseFactory.create(record=other_record_with_mixed_responses) - await ResponseFactory.create(record=other_record_with_mixed_responses, status=ResponseStatus.draft) - await ResponseFactory.create(record=other_record_with_mixed_responses, status=ResponseStatus.discarded) + records_completed = await RecordFactory.create_batch(3, status=RecordStatus.completed, dataset=dataset) + records_pending = await RecordFactory.create_batch(2, status=RecordStatus.pending, dataset=dataset) response = await async_client.get(self.url(dataset.id), headers=owner_auth_header) assert response.status_code == 200 assert response.json() == { - "total": 8, - "submitted": 2, - "discarded": 2, - "conflicting": 1, - "pending": 3, + "completed": 3, + "pending": 2, + "total": 5, } async def test_get_dataset_progress_with_empty_dataset(self, async_client: AsyncClient, owner_auth_header: dict): @@ -104,11 +50,9 @@ async def test_get_dataset_progress_with_empty_dataset(self, async_client: Async assert response.status_code == 200 assert response.json() == { - "total": 0, - "submitted": 0, - "discarded": 0, - "conflicting": 0, + "completed": 0, "pending": 0, + "total": 0, } @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index cdb9b06ea2..097bc0a1ec 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -15,9 +15,10 @@ from uuid import UUID import pytest -from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus from httpx import AsyncClient +from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus + from tests.factories import DatasetFactory diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index e0c9fe4d5e..9404b3850e 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -735,11 +735,12 @@ async def test_get_current_user_dataset_metrics( self, async_client: "AsyncClient", owner: User, owner_auth_header: dict ): dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) + record_a = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) + record_b = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) record_c = await RecordFactory.create(dataset=dataset) record_d = await RecordFactory.create(dataset=dataset) await RecordFactory.create_batch(3, dataset=dataset) + await RecordFactory.create_batch(2, dataset=dataset, status=RecordStatus.completed) await ResponseFactory.create(record=record_a, user=owner) await ResponseFactory.create(record=record_b, user=owner, status=ResponseStatus.discarded) await ResponseFactory.create(record=record_c, user=owner, status=ResponseStatus.discarded) @@ -758,33 +759,43 @@ async def test_get_current_user_dataset_metrics( assert response.status_code == 200 assert response.json() == { - "records": { - "count": 7, - }, "responses": { - "count": 4, + "total": 7, "submitted": 1, "discarded": 2, "draft": 1, + "pending": 3, }, } - async def test_get_current_user_dataset_metrics_without_authentication(self, 
async_client: "AsyncClient"): + async def test_get_current_user_dataset_metrics_with_empty_dataset( + self, async_client: "AsyncClient", owner_auth_header: dict + ): dataset = await DatasetFactory.create() - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics") + response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics", headers=owner_auth_header) - assert response.status_code == 401 + assert response.status_code == 200 + assert response.json() == { + "responses": { + "total": 0, + "submitted": 0, + "discarded": 0, + "draft": 0, + "pending": 0, + }, + } @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_get_current_user_dataset_metrics_as_annotator(self, async_client: "AsyncClient", role: UserRole): dataset = await DatasetFactory.create() user = await AnnotatorFactory.create(workspaces=[dataset.workspace], role=role) record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) + record_b = await RecordFactory.create(dataset=dataset, status=RecordStatus.completed) record_c = await RecordFactory.create(dataset=dataset) record_d = await RecordFactory.create(dataset=dataset) await RecordFactory.create_batch(2, dataset=dataset) + await RecordFactory.create_batch(3, dataset=dataset, status=RecordStatus.completed) await ResponseFactory.create(record=record_a, user=user) await ResponseFactory.create(record=record_b, user=user) await ResponseFactory.create(record=record_c, user=user, status=ResponseStatus.discarded) @@ -800,15 +811,28 @@ async def test_get_current_user_dataset_metrics_as_annotator(self, async_client: await ResponseFactory.create(record=other_record_c, status=ResponseStatus.discarded) response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/metrics", headers={API_KEY_HEADER_NAME: user.api_key} + f"/api/v1/me/datasets/{dataset.id}/metrics", + headers={API_KEY_HEADER_NAME: user.api_key}, ) assert response.status_code == 200 assert response.json() == { - "records": {"count": 6}, - "responses": {"count": 4, "submitted": 2, "discarded": 1, "draft": 1}, + "responses": { + "total": 6, + "submitted": 2, + "discarded": 1, + "draft": 1, + "pending": 2, + }, } + async def test_get_current_user_dataset_metrics_without_authentication(self, async_client: "AsyncClient"): + dataset = await DatasetFactory.create() + + response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/metrics") + + assert response.status_code == 401 + @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_get_current_user_dataset_metrics_restricted_user_from_different_workspace( self, async_client: "AsyncClient", role: UserRole From 267811ccf4c7c892b42726ac8a5d9f6fa92af815 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 4 Jul 2024 14:59:06 +0200 Subject: [PATCH 05/36] [REFACTOR] `argilla-server`: Remove list current user records endpoint (#5153) This PR is the first of a series of PRs for cleaning the listing-records-related endpoints. This PR removes the `GET /api/v1/me/datasets/:dataset_id/records` endpoint since the only client was the frontend application and now is using the equivalent search endpoint. 
**Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 4 + .../api/handlers/v1/datasets/records.py | 38 - .../handlers/v1/test_list_dataset_records.py | 750 ------------------ 3 files changed, 4 insertions(+), 788 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 827037a2c3..5733f80aa1 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -25,6 +25,10 @@ These are the section headers that we use: - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) +### Removed + +- Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) + ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) ### Removed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index e032aa7037..ed229437d7 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -353,44 +353,6 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord raise UnprocessableEntityError(str(e)) -@router.get("/me/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) -async def list_current_user_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), - sort_by_query_param: SortByQueryParamParsed, - include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), - offset: int = 0, - limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise(db, dataset_id, options=[selectinload(Dataset.metadata_properties)]) - - await authorize(current_user, DatasetPolicy.get(dataset)) - - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - parsed_metadata=metadata.metadata_parsed, - limit=limit, - offset=offset, - user=current_user, - response_statuses=response_statuses, - include=include, - sort_by_query_param=sort_by_query_param, - ) - - for record in records: - record.dataset = dataset - record.metadata_ = await _filter_record_metadata_for_user(record, current_user) - - return Records(items=records, total=total) - - @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py 
b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 8f78940df3..a697f88510 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -793,753 +793,3 @@ async def create_dataset_with_user_responses( ] return dataset, questions, records, responses, suggestions - - async def test_list_current_user_dataset_records( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - record_a, record_b, record_c = records - - mock_search_engine.search.return_value = SearchResponses( - total=3, - items=[ - SearchResponseItem(record_id=record_a.id, score=14.2), - SearchResponseItem(record_id=record_b.id, score=12.2), - SearchResponseItem(record_id=record_c.id, score=10.2), - ], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header) - - assert response.status_code == 200 - assert response.json() == { - "total": 3, - "items": [ - { - "id": str(record_a.id), - "status": RecordStatus.pending, - "fields": {"input": "input_a", "output": "output_a"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_a.external_id, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "status": RecordStatus.pending, - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"unit": "test"}, - "dataset_id": str(dataset.id), - "external_id": record_b.external_id, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "status": RecordStatus.pending, - "fields": {"input": "input_c", "output": "output_c"}, - "metadata": None, - "dataset_id": str(dataset.id), - "external_id": record_c.external_id, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - } - - async def test_list_current_user_dataset_records_with_filtered_metadata_as_annotator( - self, async_client: "AsyncClient", mock_search_engine: SearchEngine, owner: User - ): - workspace = await WorkspaceFactory.create() - user = await AnnotatorFactory.create() - await WorkspaceUserFactory.create(workspace_id=workspace.id, user_id=user.id) - - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - await TermsMetadataPropertyFactory.create( - name="key1", - dataset=dataset, - allowed_roles=[UserRole.admin, UserRole.annotator], - ) - await TermsMetadataPropertyFactory.create( - name="key2", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - await TermsMetadataPropertyFactory.create( - name="key3", - dataset=dataset, - allowed_roles=[UserRole.admin], - ) - - record = await RecordFactory.create( - dataset=dataset, - fields={"input": "input_b", "output": "output_b"}, - metadata_={"key1": "value1", "key2": "value2", "key3": "value3", "extra": "extra"}, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=1, - items=[SearchResponseItem(record_id=record.id, score=14.2)], - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, 
dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 200 - assert response.json() == { - "total": 1, - "items": [ - { - "id": str(record.id), - "status": RecordStatus.pending, - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"key1": "value1"}, - "dataset_id": str(dataset.id), - "external_id": record.external_id, - "inserted_at": record.inserted_at.isoformat(), - "updated_at": record.updated_at.isoformat(), - } - ], - } - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin, UserRole.owner]) - @pytest.mark.parametrize( - "includes", - [[RecordInclude.responses], [RecordInclude.suggestions], [RecordInclude.responses, RecordInclude.suggestions]], - ) - async def test_list_current_user_dataset_records_with_include( - self, async_client: "AsyncClient", role: UserRole, includes: List[RecordInclude] - ): - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - dataset, questions, records, responses, suggestions = await self.create_dataset_with_user_responses( - user, workspace - ) - record_a, record_b, record_c = records - response_a_user, response_b_user = responses[1], responses[3] - suggestion_a, suggestion_b = suggestions - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - params = [("include", include.value) for include in includes] - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", params=params, headers={API_KEY_HEADER_NAME: user.api_key} - ) - - expected = { - "total": 3, - "items": [ - { - "id": str(record_a.id), - "fields": {"input": "input_a", "output": "output_a"}, - "metadata": None, - "external_id": record_a.external_id, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"input": "input_b", "output": "output_b"}, - "metadata": {"unit": "test"}, - "external_id": record_b.external_id, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"input": "input_c", "output": "output_c"}, - "metadata": None, - "external_id": record_c.external_id, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - } - - if RecordInclude.responses in includes: - expected["items"][0]["responses"] = [ - { - "id": str(response_a_user.id), - "values": None, - "status": "discarded", - "user_id": str(user.id), - "inserted_at": response_a_user.inserted_at.isoformat(), - "updated_at": response_a_user.updated_at.isoformat(), - } - ] - expected["items"][1]["responses"] = [ - { - "id": str(response_b_user.id), - "values": { - "input_ok": {"value": "no"}, - "output_ok": {"value": "no"}, - }, - "status": "submitted", - "user_id": str(user.id), - "inserted_at": response_b_user.inserted_at.isoformat(), - "updated_at": response_b_user.updated_at.isoformat(), - }, - ] - expected["items"][2]["responses"] = [] - - if RecordInclude.suggestions in includes: - expected["items"][0]["suggestions"] = [ - { - "id": str(suggestion_a.id), - "value": "option-1", - "score": None, - "agent": None, - "type": None, - "question_id": str(questions[0].id), - } - ] - 
expected["items"][1]["suggestions"] = [ - { - "id": str(suggestion_b.id), - "value": "option-2", - "score": 0.75, - "agent": "unit-test-agent", - "type": "model", - "question_id": str(questions[0].id), - } - ] - expected["items"][2]["suggestions"] = [] - - assert response.status_code == 200 - assert response.json() == expected - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"include": RecordInclude.vectors.value}, - headers=owner_auth_header, - ) - - assert response.status_code == 200 - assert response.json() == { - "items": [ - { - "id": str(record_a.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_a.external_id, - "vectors": { - "vector-a": [1.0, 2.0, 3.0], - "vector-b": [4.0, 5.0], - }, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_b.external_id, - "vectors": { - "vector-b": [1.0, 2.0], - }, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_c.external_id, - "vectors": {}, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - "total": 3, - } - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_include_specific_vectors( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(dataset=dataset) - record_b = await RecordFactory.create(dataset=dataset) - record_c = await RecordFactory.create(dataset=dataset) - vector_settings_a = await VectorSettingsFactory.create(name="vector-a", dimensions=3, dataset=dataset) - vector_settings_b = await VectorSettingsFactory.create(name="vector-b", dimensions=2, dataset=dataset) - vector_settings_c = await VectorSettingsFactory.create(name="vector-c", dimensions=4, dataset=dataset) - - await VectorFactory.create(value=[1.0, 2.0, 3.0], vector_settings=vector_settings_a, record=record_a) - await VectorFactory.create(value=[4.0, 5.0], vector_settings=vector_settings_b, record=record_a) - await VectorFactory.create(value=[6.0, 7.0, 8.0, 9.0], vector_settings=vector_settings_c, record=record_a) - await 
VectorFactory.create(value=[1.0, 2.0], vector_settings=vector_settings_b, record=record_b) - await VectorFactory.create(value=[10.0, 11.0, 12.0, 13.0], vector_settings=vector_settings_c, record=record_b) - await VectorFactory.create(value=[14.0, 15.0, 16.0, 17.0], vector_settings=vector_settings_c, record=record_c) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"include": f"{RecordInclude.vectors.value}:{vector_settings_a.name},{vector_settings_b.name}"}, - headers=owner_auth_header, - ) - - assert response.status_code == 200 - assert response.json() == { - "items": [ - { - "id": str(record_a.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_a.external_id, - "vectors": { - "vector-a": [1.0, 2.0, 3.0], - "vector-b": [4.0, 5.0], - }, - "inserted_at": record_a.inserted_at.isoformat(), - "updated_at": record_a.updated_at.isoformat(), - }, - { - "id": str(record_b.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_b.external_id, - "vectors": { - "vector-b": [1.0, 2.0], - }, - "inserted_at": record_b.inserted_at.isoformat(), - "updated_at": record_b.updated_at.isoformat(), - }, - { - "id": str(record_c.id), - "fields": {"text": "This is a text", "sentiment": "neutral"}, - "metadata": None, - "external_id": record_c.external_id, - "vectors": {}, - "inserted_at": record_c.inserted_at.isoformat(), - "updated_at": record_c.updated_at.isoformat(), - }, - ], - "total": 3, - } - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_offset( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 2} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_limit( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"limit": 1} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_a.id)] - - @pytest.mark.skip(reason="Factory integration with search engine") - async def test_list_current_user_dataset_records_with_offset_and_limit( - self, async_client: "AsyncClient", owner_auth_header: 
dict - ): - dataset = await DatasetFactory.create() - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - record_c = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params={"offset": 1, "limit": 1} - ) - - assert response.status_code == 200 - - response_body = response.json() - assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, le=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, le=None), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_current_user_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, - user_id=owner.id, - ) - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("response_status_filter", 
["missing", "pending", "discarded", "submitted", "draft"]) - async def test_list_current_user_dataset_records_with_response_status_filter( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict, response_status_filter: str - ): - num_responses_per_status = 10 - response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}} - - dataset = await DatasetFactory.create() - # missing responses - await RecordFactory.create_batch(size=num_responses_per_status, dataset=dataset) - # discarded responses - await self.create_records_with_response(num_responses_per_status, dataset, owner, ResponseStatus.discarded) - # submitted responses - await self.create_records_with_response( - num_responses_per_status, dataset, owner, ResponseStatus.submitted, response_values - ) - # drafted responses - await self.create_records_with_response( - num_responses_per_status, dataset, owner, ResponseStatus.draft, response_values - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses", - headers=owner_auth_header, - ) - - assert response.status_code == 200 - response_json = response.json() - - assert len(response_json["items"]) == num_responses_per_status - - if response_status_filter in ["missing", "pending"]: - assert all([len(record["responses"]) == 0 for record in response_json["items"]]) - else: - assert all( - [record["responses"][0]["status"] == response_status_filter for record in response_json["items"]] - ) - - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_current_user_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, 
score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - user_id=owner.id, - ) - - async def test_list_current_user_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "inserted_at:wrong"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_current_user_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - - async def test_list_current_user_dataset_records_without_authentication(self, async_client: "AsyncClient"): - dataset = await DatasetFactory.create() - - response = await async_client.get(f"/api/v1/me/datasets/{dataset.id}/records") - - assert response.status_code == 401 - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize("role", [UserRole.admin, UserRole.annotator]) - async def test_list_current_user_dataset_records_as_restricted_user( - self, async_client: "AsyncClient", role: UserRole - ): - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - dataset = await DatasetFactory.create(workspace=workspace) - record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - record_b = await RecordFactory.create( - fields={"record_b": "value_b"}, metadata_={"unit": "test"}, dataset=dataset - ) - record_c = await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) - expected_records = [record_a, record_b, record_c] - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 200 - - response_items = response.json()["items"] - - for expected_record in expected_records: - found_items = [item for item in response_items if item["id"] == str(expected_record.id)] - assert found_items, expected_record - - assert found_items[0] == { - "id": str(expected_record.id), - "fields": expected_record.fields, - "metadata": expected_record.metadata_, - "external_id": expected_record.external_id, - "inserted_at": expected_record.inserted_at.isoformat(), - "updated_at": expected_record.updated_at.isoformat(), - } - - @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) - async def test_list_current_user_dataset_records_as_restricted_user_from_different_workspace( - self, async_client: "AsyncClient", role: UserRole - ): - dataset = await DatasetFactory.create() - workspace = await WorkspaceFactory.create() - user = await UserFactory.create(workspaces=[workspace], role=role) - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: user.api_key} - ) - - assert response.status_code == 403 - - async def test_list_current_user_dataset_records_with_nonexistent_dataset_id( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset_id = uuid4() - - await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/me/datasets/{dataset_id}/records", - headers=owner_auth_header, - ) - - assert response.status_code == 404 - assert response.json() == {"detail": f"Dataset with id `{dataset_id}` not found"} From 89f9bdee9e75f5a730bf30940ae25dae5da3e16a Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 4 Jul 2024 15:03:03 +0200 Subject: [PATCH 06/36] [BREAKING- REFACTOR] `argilla-server`: remove metadata filter query param (#5156) # Description > [!NOTE] > This PR must be merged after https://github.com/argilla-io/argilla/pull/5153 This PR removes support for filtering using metadata as a query param: - This filter is not available anymore for list endpoints - The metadata filter can be defined as part of the request body for search filters. 
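
For illustration, a minimal sketch of the change from the caller's perspective; the search route and the `filters`/`and` wrapper are assumptions, while the terms and range filter shapes mirror the scopes used in the updated tests:

```python
# Hypothetical sketch of moving a metadata filter from a query param to the
# search request body. Route path and the "filters"/"and" wrapper are
# assumptions; the filter objects follow the shapes used in the updated tests.
from httpx import AsyncClient

from argilla_server.constants import API_KEY_HEADER_NAME


async def search_with_metadata_filters(api_key: str, dataset_id: str) -> dict:
    async with AsyncClient(base_url="http://localhost:6900") as client:
        # Before this PR: metadata filters as a query param (no longer supported).
        # params={"metadata": ["terms_prop:value1,value2"]}

        # After this PR: metadata filters go in the search request body.
        response = await client.post(
            f"/api/v1/datasets/{dataset_id}/records/search",
            headers={API_KEY_HEADER_NAME: api_key},
            json={
                "filters": {
                    "and": [
                        {
                            "type": "terms",
                            "scope": {"entity": "metadata", "metadata_property": "terms_prop"},
                            "values": ["value1", "value2"],
                        },
                        {
                            "type": "range",
                            "scope": {"entity": "metadata", "metadata_property": "integer_prop"},
                            "ge": 10,
                            "le": 20,
                        },
                    ]
                }
            },
        )
        response.raise_for_status()
        return response.json()
```

On the server side the query-param parsing (`MetadataQueryParams`) and the `MetadataFilter` classes are removed, so the search engine only receives the unified `Filter`/`Order` objects built from the request body.
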
**Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - follows the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 1 + .../api/handlers/v1/datasets/records.py | 46 +---- .../argilla_server/api/schemas/v1/records.py | 12 -- .../src/argilla_server/search_engine/base.py | 78 +------- .../argilla_server/search_engine/commons.py | 29 --- .../datasets/test_search_dataset_records.py | 2 - .../unit/api/handlers/v1/test_datasets.py | 168 +++++++----------- .../handlers/v1/test_list_dataset_records.py | 119 +------------ .../tests/unit/search_engine/test_commons.py | 48 ++--- 9 files changed, 93 insertions(+), 410 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 5733f80aa1..27aada3fee 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -28,6 +28,7 @@ These are the section headers that we use: ### Removed - Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) +- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index ed229437d7..e1a2d519dc 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -27,8 +27,6 @@ Filters, FilterScope, MetadataFilterScope, - MetadataParsedQueryParam, - MetadataQueryParams, Order, RangeFilter, RecordFilterScope, @@ -51,19 +49,15 @@ ) from argilla_server.contexts import datasets, search from argilla_server.database import get_async_db -from argilla_server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder +from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponses, SortBy, - TermsMetadataFilter, UserResponseStatusFilter, get_search_engine, ) @@ -106,7 +100,6 @@ async def _filter_records_using_search_engine( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, user: Optional[User] = None, @@ -121,7 +114,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - parsed_metadata=parsed_metadata, response_statuses=response_statuses, sort_by_query_param=sort_by_query_param, ) @@ -182,7 +174,6 @@ async def 
_get_search_responses( db: "AsyncSession", search_engine: "SearchEngine", dataset: Dataset, - parsed_metadata: List[MetadataParsedQueryParam], limit: int, offset: int, search_records_query: Optional[SearchRecordsQuery] = None, @@ -228,7 +219,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - metadata_filters = await _build_metadata_filters(db, dataset, parsed_metadata) response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user) sort_by = await _build_sort_by(db, dataset, sort_by_query_param) @@ -240,7 +230,6 @@ async def _get_search_responses( "record": record, "query": text_query, "order": vector_query.order, - "metadata_filters": metadata_filters, "user_response_status_filter": response_status_filter, "max_results": limit, } @@ -253,7 +242,6 @@ async def _get_search_responses( search_params = { "dataset": dataset, "query": text_query, - "metadata_filters": metadata_filters, "user_response_status_filter": response_status_filter, "offset": offset, "limit": limit, @@ -271,32 +259,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _build_metadata_filters( - db: "AsyncSession", dataset: Dataset, parsed_metadata: List[MetadataParsedQueryParam] -) -> List["MetadataFilter"]: - try: - metadata_filters = [] - for metadata_param in parsed_metadata: - metadata_property = await MetadataProperty.get_by(db, name=metadata_param.name, dataset_id=dataset.id) - if metadata_property is None: - continue # won't fail on unknown metadata filter name - - if metadata_property.type == MetadataPropertyType.terms: - metadata_filter_class = TermsMetadataFilter - elif metadata_property.type == MetadataPropertyType.integer: - metadata_filter_class = IntegerMetadataFilter - elif metadata_property.type == MetadataPropertyType.float: - metadata_filter_class = FloatMetadataFilter - else: - raise ValueError(f"Not found filter for type {metadata_property.type}") - - metadata_filters.append(metadata_filter_class.from_string(metadata_property, metadata_param.value)) - except (UnprocessableEntityError, ValueError) as ex: - raise UnprocessableEntityError(f"Cannot parse provided metadata filters: {ex}") - - return metadata_filters - - async def _build_response_status_filter_for_search( response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None ) -> Optional[UserResponseStatusFilter]: @@ -359,7 +321,6 @@ async def list_dataset_records( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - metadata: MetadataQueryParams = Depends(), sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), @@ -375,7 +336,6 @@ async def list_dataset_records( db, search_engine, dataset=dataset, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, response_statuses=response_statuses, @@ -489,7 +449,6 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = 
Depends(parse_record_include_param), response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), @@ -515,7 +474,6 @@ async def search_current_user_dataset_records( search_engine=search_engine, dataset=dataset, search_records_query=body, - parsed_metadata=metadata.metadata_parsed, limit=limit, offset=offset, user=current_user, @@ -563,7 +521,6 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - metadata: MetadataQueryParams = Depends(), sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), @@ -584,7 +541,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - parsed_metadata=metadata.metadata_parsed, response_statuses=response_statuses, sort_by_query_param=sort_by_query_param, ) diff --git a/argilla-server/src/argilla_server/api/schemas/v1/records.py b/argilla-server/src/argilla_server/api/schemas/v1/records.py index 0cf215954a..b5ff7c3f4c 100644 --- a/argilla-server/src/argilla_server/api/schemas/v1/records.py +++ b/argilla-server/src/argilla_server/api/schemas/v1/records.py @@ -13,12 +13,9 @@ # limitations under the License. from datetime import datetime - from typing import Annotated, Any, Dict, List, Literal, Optional, Union from uuid import UUID -import fastapi - from argilla_server.api.schemas.v1.commons import UpdateSchema from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate @@ -223,15 +220,6 @@ def __init__(self, string: str): self.value: str = "".join(v).strip() -class MetadataQueryParams(BaseModel): - metadata: List[str] = Field(fastapi.Query([], pattern=r"^(?=.*[a-z0-9])[a-z0-9_-]+:(.+(,(.+))*)$")) - - @property - def metadata_parsed(self) -> List[MetadataParsedQueryParam]: - # TODO: Validate metadata fields names from query params - return [MetadataParsedQueryParam(q) for q in self.metadata] - - class VectorQuery(BaseModel): name: str record_id: Optional[UUID] = None diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index ee1dbcc386..08a4e459c8 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -15,17 +15,13 @@ from abc import ABCMeta, abstractmethod from contextlib import asynccontextmanager from typing import ( - Any, AsyncGenerator, - ClassVar, - Dict, Generic, Iterable, List, Optional, - Type, - TypeVar, Union, + TypeVar, ) from uuid import UUID @@ -38,16 +34,12 @@ SortOrder, ) from argilla_server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings -from argilla_server.pydantic_v1 import BaseModel, Field, root_validator +from argilla_server.pydantic_v1 import BaseModel, Field from argilla_server.pydantic_v1.generics import GenericModel __all__ = [ "SearchEngine", "TextQuery", - "MetadataFilter", - "TermsMetadataFilter", - "IntegerMetadataFilter", - "FloatMetadataFilter", "UserResponseStatusFilter", "SearchResponseItem", "SearchResponses", @@ -147,67 +139,6 @@ def has_pending_status(self) -> bool: return ResponseStatusFilter.pending in self.statuses or ResponseStatusFilter.missing in self.statuses -class MetadataFilter(BaseModel): - 
metadata_property: MetadataProperty - - class Config: - arbitrary_types_allowed = True - - @classmethod - @abstractmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - pass - - -class TermsMetadataFilter(MetadataFilter): - values: List[str] - - @classmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "MetadataFilter": - return cls(metadata_property=metadata_property, values=string.split(",")) - - -NT = TypeVar("NT", int, float) - - -class _RangeModel(GenericModel, Generic[NT]): - ge: Optional[NT] - le: Optional[NT] - - -class NumericMetadataFilter(GenericModel, Generic[NT], MetadataFilter): - ge: Optional[NT] = None - le: Optional[NT] = None - - _json_model: ClassVar[Type[_RangeModel]] - - @root_validator(skip_on_failure=True) - def check_bounds(cls, values: Dict[str, Any]) -> Dict[str, Any]: - ge = values.get("ge") - le = values.get("le") - - if ge is None and le is None: - raise ValueError("One of 'ge' or 'le' values must be specified") - - if ge is not None and le is not None and ge > le: - raise ValueError(f"'ge' ({ge}) must be lower or equal than 'le' ({le})") - - return values - - @classmethod - def from_string(cls, metadata_property: MetadataProperty, string: str) -> "NumericMetadataFilter": - model = cls._json_model.parse_raw(string) - return cls(metadata_property=metadata_property, ge=model.ge, le=model.le) - - -class IntegerMetadataFilter(NumericMetadataFilter[int]): - _json_model = _RangeModel[int] - - -class FloatMetadataFilter(NumericMetadataFilter[float]): - _json_model = _RangeModel[float] - - class SearchResponseItem(BaseModel): record_id: UUID score: Optional[float] @@ -236,6 +167,9 @@ class TermCount(BaseModel): values: List[TermCount] = Field(default_factory=list) +NT = TypeVar("NT", int, float) + + class NumericMetadataMetrics(GenericModel, Generic[NT]): min: Optional[NT] max: Optional[NT] @@ -350,7 +284,6 @@ async def search( sort: Optional[List[Order]] = None, # TODO: remove them and keep filter and order user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, sort_by: Optional[List[SortBy]] = None, # END TODO offset: int = 0, @@ -380,7 +313,6 @@ async def similarity_search( filter: Optional[Filter] = None, # TODO: remove them and keep filter user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index b328224f19..501bff03f3 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -38,11 +38,8 @@ AndFilter, Filter, FilterScope, - FloatMetadataFilter, FloatMetadataMetrics, - IntegerMetadataFilter, IntegerMetadataMetrics, - MetadataFilter, MetadataFilterScope, MetadataMetrics, Order, @@ -55,7 +52,6 @@ SortBy, SuggestionFilterScope, TermsFilter, - TermsMetadataFilter, TermsMetadataMetrics, TextQuery, UserResponseStatusFilter, @@ -203,25 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once the `metadata_filters` argument is removed from search and similarity_search methods -def _unify_metadata_filters_with_filter(metadata_filters: 
List[MetadataFilter], filter: Optional[Filter]) -> Filter: - filters = [] - if filter: - filters.append(filter) - - for metadata_filter in metadata_filters: - metadata_scope = MetadataFilterScope(metadata_property=metadata_filter.metadata_property.name) - if isinstance(metadata_filter, TermsMetadataFilter): - new_filter = TermsFilter(scope=metadata_scope, values=metadata_filter.values) - elif isinstance(metadata_filter, (IntegerMetadataFilter, FloatMetadataFilter)): - new_filter = RangeFilter(scope=metadata_scope, ge=metadata_filter.ge, le=metadata_filter.le) - else: - raise ValueError(f"Cannot process request for metadata filter {metadata_filter}") - filters.append(new_filter) - - return AndFilter(filters=filters) - - # This function will be moved once the response status filter is removed from search and similarity_search methods def _unify_user_response_status_filter_with_filter( user_response_status_filter: UserResponseStatusFilter, filter: Optional[Filter] = None @@ -418,15 +395,12 @@ async def similarity_search( filter: Optional[Filter] = None, # TODO: remove them and keep filter user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, ) -> SearchResponses: # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) if user_response_status_filter and user_response_status_filter.statuses: filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) # END TODO @@ -625,7 +599,6 @@ async def search( sort: Optional[List[Order]] = None, # TODO: Remove these arguments user_response_status_filter: Optional[UserResponseStatusFilter] = None, - metadata_filters: Optional[List[MetadataFilter]] = None, sort_by: Optional[List[SortBy]] = None, # END TODO offset: int = 0, @@ -635,8 +608,6 @@ async def search( # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if metadata_filters: - filter = _unify_metadata_filters_with_filter(metadata_filters, filter) if user_response_status_filter and user_response_status_filter.statuses: filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 73077c4381..253c8a9433 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -316,7 +316,6 @@ async def test_with_filter( RangeFilter(scope=SuggestionFilterScope(question=question.name, property="score"), ge=0.5), ] ), - metadata_filters=[], offset=0, limit=50, query=None, @@ -367,7 +366,6 @@ async def test_with_sort( Order(scope=ResponseFilterScope(question=question.name), order=SortOrder.asc), Order(scope=SuggestionFilterScope(question=question.name, property="score"), order=SortOrder.desc), ], - metadata_filters=[], offset=0, limit=50, query=None, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py 
b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index e0c9fe4d5e..bc75af2ce5 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -57,16 +57,16 @@ VectorSettings, ) from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponseItem, SearchResponses, SortBy, - TermsMetadataFilter, TextQuery, UserResponseStatusFilter, + AndFilter, + TermsFilter, + MetadataFilterScope, + RangeFilter, ) from sqlalchemy import func, inspect, select @@ -3650,7 +3650,6 @@ async def test_search_current_user_dataset_records( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, @@ -3691,55 +3690,85 @@ async def test_search_current_user_dataset_records( } @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), + ("property_config", "metadata_filter", "expected_filter"), [ ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), + { + "type": "terms", + "values": ["value"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value"]), ), ( {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), + { + "type": "terms", + "values": ["value1", "value2"], + "scope": {"entity": "metadata", "metadata_property": "terms_prop"}, + }, + TermsFilter(scope=MetadataFilterScope(metadata_property="terms_prop"), values=["value1", "value2"]), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), + { + "type": "range", + "ge": 10, + "le": 20, + "scope": {"entity": "metadata", "metadata_property": "integer_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=10, + le=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), + {"type": "range", "ge": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + ge=20, + ), ), ( {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(low=None, le=20), + {"type": "range", "le": 20, "scope": {"entity": "metadata", "metadata_property": "integer_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="integer_prop"), + le=20, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), + { + "type": "range", + "ge": -1.30, + "le": 23.23, + "scope": {"entity": "metadata", "metadata_property": "float_prop"}, + }, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + ge=-1.30, + le=23.23, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), + {"type": "range", "ge": 23.23, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + 
ge=23.23, + ), ), ( {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(low=None, le=11.32), + {"type": "range", "le": 11.32, "scope": {"entity": "metadata", "metadata_property": "float_prop"}}, + RangeFilter( + scope=MetadataFilterScope(metadata_property="float_prop"), + le=11.32, + ), ), ], ) @@ -3749,15 +3778,14 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine: SearchEngine, owner: User, owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, + property_config, + metadata_filter: dict, + expected_filter: Any, ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - metadata_property = await MetadataPropertyFactory.create( + await MetadataPropertyFactory.create( name=property_config["name"], settings=property_config["settings"], dataset=dataset, @@ -3771,12 +3799,9 @@ async def test_search_current_user_dataset_records_with_metadata_filter( ], ) - params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}, "filters": {"and": [metadata_filter]}} response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, headers=owner_auth_header, json=query_json, ) @@ -3785,7 +3810,7 @@ async def test_search_current_user_dataset_records_with_metadata_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], + filter=AndFilter(filters=[expected_filter]), user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, @@ -3793,62 +3818,6 @@ async def test_search_current_user_dataset_records_with_metadata_filter( user_id=owner.id, ) - @pytest.mark.parametrize( - ("property_config", "wrong_value"), - [ - ({"name": "terms_prop", "settings": {"type": "terms"}}, None), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "terms_prop:"), - ({"name": "terms_prop", "settings": {"type": "terms"}}, "wrong-value"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, None), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "integer_prop:{}"), - ({"name": "integer_prop", "settings": {"type": "integer"}}, "wrong-value"), - ({"name": "float_prop", "settings": {"type": "float"}}, None), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:"), - ({"name": "float_prop", "settings": {"type": "float"}}, "float_prop:{}"), - ({"name": "float_prop", "settings": {"type": "float"}}, "wrong-value"), - ], - ) - async def test_search_current_user_dataset_records_with_wrong_metadata_filter_values( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - wrong_value: str, - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, records, *_ = await 
self.create_dataset_with_user_responses(owner, workspace) - - await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - total=2, - ) - - params = {"metadata": [wrong_value]} - - query_json = {"query": {"text": {"q": "Hello"}}} - response = await async_client.post( - f"/api/v1/me/datasets/{dataset.id}/records/search", - params=params, - headers=owner_auth_header, - json=query_json, - ) - assert response.status_code == 422, response.json() - @pytest.mark.parametrize( "sorts", [ @@ -3915,7 +3884,6 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, @@ -4121,7 +4089,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], sort_by=None, user_response_status_filter=None, offset=0, @@ -4337,7 +4304,6 @@ async def test_search_current_user_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, @@ -4384,7 +4350,6 @@ async def test_search_current_user_dataset_records_with_record_vector( query=None, order=SimilarityOrder.most_similar, max_results=5, - metadata_filters=[], user_response_status_filter=None, ) @@ -4428,7 +4393,6 @@ async def test_search_current_user_dataset_records_with_vector_value( query=None, order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], user_response_status_filter=None, ) @@ -4477,7 +4441,6 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query( query=TextQuery(q="Test query"), order=SimilarityOrder.most_similar, max_results=10, - metadata_filters=[], user_response_status_filter=None, ) @@ -4570,7 +4533,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - metadata_filters=[], user_response_status_filter=None, offset=0, limit=5, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index a697f88510..62fcb141f6 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -12,43 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Tuple, Type, Union -from uuid import uuid4 +from typing import List, Optional, Tuple, Union import pytest +from httpx import AsyncClient + from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME -from argilla_server.enums import RecordInclude, RecordSortField, ResponseStatus, UserRole, RecordStatus +from argilla_server.enums import RecordInclude, ResponseStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, - MetadataFilter, SearchEngine, SearchResponseItem, SearchResponses, SortBy, - TermsMetadataFilter, ) -from httpx import AsyncClient - from tests.factories import ( AdminFactory, AnnotatorFactory, DatasetFactory, LabelSelectionQuestionFactory, - MetadataPropertyFactory, RecordFactory, ResponseFactory, SuggestionFactory, TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, - UserFactory, VectorFactory, VectorSettingsFactory, WorkspaceFactory, - WorkspaceUserFactory, ) @@ -398,108 +390,6 @@ async def create_records_with_response( for record in await RecordFactory.create_batch(size=num_records, dataset=dataset): await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) - @pytest.mark.parametrize( - ("property_config", "param_value", "expected_filter_class", "expected_filter_args"), - [ - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value", - TermsMetadataFilter, - dict(values=["value"]), - ), - ( - {"name": "terms_prop", "settings": {"type": "terms"}}, - "value1,value2", - TermsMetadataFilter, - dict(values=["value1", "value2"]), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 10, "le": 20}', - IntegerMetadataFilter, - dict(ge=10, le=20), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"ge": 20}', - IntegerMetadataFilter, - dict(ge=20, high=None), - ), - ( - {"name": "integer_prop", "settings": {"type": "integer"}}, - '{"le": 20}', - IntegerMetadataFilter, - dict(ge=None, le=20), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": -1.30, "le": 23.23}', - FloatMetadataFilter, - dict(ge=-1.30, le=23.23), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"ge": 23.23}', - FloatMetadataFilter, - dict(ge=23.23, high=None), - ), - ( - {"name": "float_prop", "settings": {"type": "float"}}, - '{"le": 11.32}', - FloatMetadataFilter, - dict(ge=None, le=11.32), - ), - ], - ) - async def test_list_dataset_records_with_metadata_filter( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: User, - owner_auth_header: dict, - property_config: dict, - param_value: str, - expected_filter_class: Type[MetadataFilter], - expected_filter_args: dict, - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - metadata_property = await MetadataPropertyFactory.create( - name=property_config["name"], - settings=property_config["settings"], - dataset=dataset, - ) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = {"metadata": [f"{metadata_property.name}:{param_value}"]} - 
response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - - response_json = response.json() - assert response_json["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - metadata_filters=[expected_filter_class(metadata_property=metadata_property, **expected_filter_args)], - user_response_status_filter=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=[SortBy(field=RecordSortField.inserted_at)], - ) - @pytest.mark.skip(reason="Factory integration with search engine") @pytest.mark.parametrize( "response_status_filter", ["missing", "pending", "discarded", "submitted", "draft", ["submitted", "draft"]] @@ -626,7 +516,6 @@ async def test_list_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=None, - metadata_filters=[], user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index c4376ca686..1f91c7ddd9 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -19,15 +19,15 @@ from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( - FloatMetadataFilter, - IntegerMetadataFilter, ResponseFilterScope, SortBy, SuggestionFilterScope, TermsFilter, - TermsMetadataFilter, TextQuery, UserResponseStatusFilter, + Filter, + MetadataFilterScope, + RangeFilter, ) from argilla_server.search_engine.commons import ( ALL_RESPONSES_STATUSES_FIELD, @@ -676,19 +676,19 @@ async def test_search_with_response_status_filter_with_no_user( assert result.total == expected_items @pytest.mark.parametrize( - ("metadata_filters_config", "expected_items"), + ("filter", "expected_items"), [ - ([{"name": "label", "values": ["neutral"]}], 4), - ([{"name": "label", "values": ["positive"]}], 1), - ([{"name": "label", "values": ["neutral", "positive"]}], 5), - ([{"name": "textId", "ge": 3, "le": 4}], 2), - ([{"name": "textId", "ge": 3, "le": 3}], 1), - ([{"name": "textId", "ge": 3}], 6), - ([{"name": "textId", "le": 4}], 5), - ([{"name": "seq_float", "ge": 0.0, "le": 12.03}], 3), - ([{"name": "seq_float", "ge": 0.13, "le": 0.13}], 1), - ([{"name": "seq_float", "ge": 0.0}], 7), - ([{"name": "seq_float", "le": 12.03}], 5), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral"]), 4), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["positive"]), 1), + (TermsFilter(scope=MetadataFilterScope(metadata_property="label"), values=["neutral", "positive"]), 5), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=4), 2), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3, le=3), 1), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), ge=3), 6), + (RangeFilter(scope=MetadataFilterScope(metadata_property="textId"), le=4), 5), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0, le=12.03), 3), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.13, le=0.13), 1), + 
(RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), ge=0.0), 7), + (RangeFilter(scope=MetadataFilterScope(metadata_property="seq_float"), le=12.03), 5), ], ) async def test_search_with_metadata_filter( @@ -696,24 +696,10 @@ async def test_search_with_metadata_filter( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - metadata_filters_config: List[dict], + filter: Filter, expected_items: int, ): - metadata_filters = [] - for metadata_filter_config in metadata_filters_config: - name = metadata_filter_config.pop("name") - for metadata_property in test_banking_sentiment_dataset.metadata_properties: - if name == metadata_property.name: - if metadata_property.type == MetadataPropertyType.terms: - filter_cls = TermsMetadataFilter - elif metadata_property.type == MetadataPropertyType.integer: - filter_cls = IntegerMetadataFilter - else: - filter_cls = FloatMetadataFilter - metadata_filters.append(filter_cls(metadata_property=metadata_property, **metadata_filter_config)) - break - - result = await search_engine.search(test_banking_sentiment_dataset, metadata_filters=metadata_filters) + result = await search_engine.search(test_banking_sentiment_dataset, filter=filter) assert len(result.items) == expected_items assert result.total == expected_items From 040446508bd25bba5cf8c817a89428d0c445d681 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Thu, 4 Jul 2024 16:55:36 +0200 Subject: [PATCH 07/36] [BREAKING - REFACTOR] `argilla-server`: remove user response status support (#5163) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description > [!NOTE] > This PR must be merged after https://github.com/argilla-io/argilla/pull/5156 This PR removes support for filtering records with response_status query param: - This filter is removed for listing records endpoints - The response status filter is available for search endpoints using the filter request body. **Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Francisco Calvo --- argilla-server/CHANGELOG.md | 1 + .../api/handlers/v1/datasets/records.py | 12 ------ .../src/argilla_server/search_engine/base.py | 4 -- .../argilla_server/search_engine/commons.py | 25 ------------ .../datasets/test_search_dataset_records.py | 2 - .../unit/api/handlers/v1/test_datasets.py | 38 +++++++++++-------- .../handlers/v1/test_list_dataset_records.py | 1 - .../tests/unit/search_engine/test_commons.py | 24 ++++++------ 8 files changed, 37 insertions(+), 70 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 27aada3fee..05a8e7e583 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -28,6 +28,7 @@ These are the section headers that we use: ### Removed - Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. 
([#5153](https://github.com/argilla-io/argilla/pull/5153)) +- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) - [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index e1a2d519dc..4f68c1125f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -103,7 +103,6 @@ async def _filter_records_using_search_engine( limit: int, offset: int, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, include: Optional[RecordIncludeParam] = None, sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: @@ -114,7 +113,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - response_statuses=response_statuses, sort_by_query_param=sort_by_query_param, ) @@ -178,7 +176,6 @@ async def _get_search_responses( offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - response_statuses: Optional[List[ResponseStatusFilter]] = None, sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -219,7 +216,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - response_status_filter = await _build_response_status_filter_for_search(response_statuses, user=user) sort_by = await _build_sort_by(db, dataset, sort_by_query_param) if vector_query and vector_settings: @@ -230,7 +226,6 @@ async def _get_search_responses( "record": record, "query": text_query, "order": vector_query.order, - "user_response_status_filter": response_status_filter, "max_results": limit, } @@ -242,7 +237,6 @@ async def _get_search_responses( search_params = { "dataset": dataset, "query": text_query, - "user_response_status_filter": response_status_filter, "offset": offset, "limit": limit, "sort_by": sort_by, @@ -323,7 +317,6 @@ async def list_dataset_records( dataset_id: UUID, sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -338,7 +331,6 @@ async def list_dataset_records( dataset=dataset, limit=limit, offset=offset, - response_statuses=response_statuses, include=include, sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) @@ -451,7 +443,6 @@ async def search_current_user_dataset_records( body: SearchRecordsQuery, sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = 
Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -477,7 +468,6 @@ async def search_current_user_dataset_records( limit=limit, offset=offset, user=current_user, - response_statuses=response_statuses, sort_by_query_param=sort_by_query_param, ) @@ -523,7 +513,6 @@ async def search_dataset_records( body: SearchRecordsQuery, sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), - response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), @@ -541,7 +530,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - response_statuses=response_statuses, sort_by_query_param=sort_by_query_param, ) diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 08a4e459c8..687c51bad3 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -283,7 +283,6 @@ async def search( filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, # TODO: remove them and keep filter and order - user_response_status_filter: Optional[UserResponseStatusFilter] = None, sort_by: Optional[List[SortBy]] = None, # END TODO offset: int = 0, @@ -311,9 +310,6 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = SimilarityOrder.most_similar, threshold: Optional[float] = None, diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 501bff03f3..0b7606c642 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -199,19 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once the response status filter is removed from search and similarity_search methods -def _unify_user_response_status_filter_with_filter( - user_response_status_filter: UserResponseStatusFilter, filter: Optional[Filter] = None -) -> Filter: - scope = ResponseFilterScope(user=user_response_status_filter.user, property="status") - response_filter = TermsFilter(scope=scope, values=[status.value for status in user_response_status_filter.statuses]) - - if filter: - return AndFilter(filters=[filter, response_filter]) - else: - return response_filter - - # This function will be moved once the `sort_by` argument is removed from search and similarity_search methods def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: if order: @@ -393,18 +380,10 @@ async def similarity_search( record: Optional[Record] = None, query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, - # TODO: remove them and keep filter - user_response_status_filter: Optional[UserResponseStatusFilter] = None, - # END TODO max_results: int = 100, order: SimilarityOrder = 
SimilarityOrder.most_similar, threshold: Optional[float] = None, ) -> SearchResponses: - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - # END TODO - if bool(value) == bool(record): raise ValueError("Must provide either vector value or record to compute the similarity search") @@ -598,7 +577,6 @@ async def search( filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, # TODO: Remove these arguments - user_response_status_filter: Optional[UserResponseStatusFilter] = None, sort_by: Optional[List[SortBy]] = None, # END TODO offset: int = 0, @@ -608,9 +586,6 @@ async def search( # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if user_response_status_filter and user_response_status_filter.statuses: - filter = _unify_user_response_status_filter_with_filter(user_response_status_filter, filter) - if sort_by: sort = _unify_sort_by_with_order(sort_by, sort) # END TODO diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 253c8a9433..5229f53cf9 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -320,7 +320,6 @@ async def test_with_filter( limit=50, query=None, sort_by=None, - user_response_status_filter=None, ) async def test_with_sort( @@ -370,7 +369,6 @@ async def test_with_sort( limit=50, query=None, sort_by=None, - user_response_status_filter=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index bc75af2ce5..1d9ddcf22c 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -19,6 +19,8 @@ from uuid import UUID, uuid4 import pytest +from sqlalchemy import func, inspect, select + from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.api.schemas.v1.datasets import DATASET_GUIDELINES_MAX_LENGTH, DATASET_NAME_MAX_LENGTH from argilla_server.api.schemas.v1.fields import FIELD_CREATE_NAME_MAX_LENGTH, FIELD_CREATE_TITLE_MAX_LENGTH @@ -62,14 +64,12 @@ SearchResponses, SortBy, TextQuery, - UserResponseStatusFilter, AndFilter, TermsFilter, MetadataFilterScope, RangeFilter, + ResponseFilterScope, ) -from sqlalchemy import func, inspect, select - from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -80,7 +80,6 @@ LabelSelectionQuestionFactory, MetadataPropertyFactory, MultiLabelSelectionQuestionFactory, - OwnerFactory, QuestionFactory, RatingQuestionFactory, RecordFactory, @@ -3650,7 +3649,6 @@ async def test_search_current_user_dataset_records( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort_by=None, @@ -3811,7 +3809,6 @@ async def 
test_search_current_user_dataset_records_with_metadata_filter( dataset=dataset, query=TextQuery(q="Hello", field="input"), filter=AndFilter(filters=[expected_filter]), - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort_by=None, @@ -3884,7 +3881,6 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort_by=expected_sorts_by, @@ -4090,7 +4086,6 @@ async def test_search_current_user_dataset_records_with_include( dataset=dataset, query=TextQuery(q="Hello", field="input"), sort_by=None, - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4293,18 +4288,35 @@ async def test_search_current_user_dataset_records_with_response_status_filter( dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) mock_search_engine.search.return_value = SearchResponses(items=[]) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "filters": { + "and": [ + { + "type": "terms", + "scope": {"entity": "response", "property": "status"}, + "values": [ResponseStatus.submitted], + } + ] + }, + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=owner_auth_header, json=query_json, - params={"response_status": ResponseStatus.submitted.value}, ) mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), + filter=AndFilter( + filters=[ + TermsFilter( + scope=ResponseFilterScope(property="status", user=owner), + values=[ResponseStatusFilter.submitted], + ) + ] + ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort_by=None, @@ -4350,7 +4362,6 @@ async def test_search_current_user_dataset_records_with_record_vector( query=None, order=SimilarityOrder.most_similar, max_results=5, - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value( @@ -4393,7 +4404,6 @@ async def test_search_current_user_dataset_records_with_vector_value( query=None, order=SimilarityOrder.most_similar, max_results=10, - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_vector_value_and_query( @@ -4441,7 +4451,6 @@ async def test_search_current_user_dataset_records_with_vector_value_and_query( query=TextQuery(q="Test query"), order=SimilarityOrder.most_similar, max_results=10, - user_response_status_filter=None, ) async def test_search_current_user_dataset_records_with_wrong_vector( @@ -4533,7 +4542,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - user_response_status_filter=None, offset=0, limit=5, sort_by=None, diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 62fcb141f6..db2605c3de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -516,7 +516,6 @@ async def 
test_list_dataset_records_with_sort_by( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=None, - user_response_status_filter=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, sort_by=expected_sorts_by, diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 1f91c7ddd9..5ae8241927 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -595,7 +595,7 @@ async def test_search_with_response_status_filter( result = await search_engine.search( test_banking_sentiment_dataset, query=TextQuery(q="payment"), - user_response_status_filter=UserResponseStatusFilter(user=user, statuses=statuses), + filter=TermsFilter(scope=ResponseFilterScope(property="status"), values=statuses), ) assert len(result.items) == expected_items assert result.total == expected_items @@ -669,7 +669,7 @@ async def test_search_with_response_status_filter_with_no_user( result = await search_engine.search( test_banking_sentiment_dataset, - user_response_status_filter=UserResponseStatusFilter(statuses=statuses, user=None), + filter=TermsFilter(ResponseFilterScope(property="status"), values=statuses), ) assert len(result.items) == expected_items @@ -734,7 +734,7 @@ async def test_search_with_response_status_filter_does_not_affect_the_result_sco results = await search_engine.search( test_banking_sentiment_dataset, query=TextQuery(q="payment"), - user_response_status_filter=UserResponseStatusFilter(user=user, statuses=all_statuses), + filter=TermsFilter(scope=ResponseFilterScope(property="status", user=user), values=all_statuses), ) assert len(no_filter_results.items) == len(results.items) @@ -1334,32 +1334,34 @@ async def test_similarity_search_by_vector_value_with_order( assert responses.items[0].record_id != selected_record.id @pytest.mark.parametrize( - "user_response_status_filter", + "statuses", [ - None, - UserResponseStatusFilter(statuses=[ResponseStatusFilter.missing, ResponseStatusFilter.draft]), + [], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft], ], ) - async def test_similarity_search_by_record_and_user_response_filter( + async def test_similarity_search_by_record_and_response_status_filter( self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset_with_vectors: Dataset, - user_response_status_filter: UserResponseStatusFilter, + statuses: List[ResponseStatusFilter], ): selected_record: Record = test_banking_sentiment_dataset_with_vectors.records[0] vector_settings: VectorSettings = test_banking_sentiment_dataset_with_vectors.vectors_settings[0] - if user_response_status_filter: + scope = ResponseFilterScope(property="status") + + if statuses: test_user = await UserFactory.create() - user_response_status_filter.user = test_user + scope.user = test_user responses = await search_engine.similarity_search( dataset=test_banking_sentiment_dataset_with_vectors, vector_settings=vector_settings, record=selected_record, max_results=1, - user_response_status_filter=user_response_status_filter, + filter=TermsFilter(scope=scope, values=statuses), ) assert responses.total == 1 From 20d4ab810915fd3838a898c405352f8d88cc363c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 4 Jul 2024 17:34:16 +0200 Subject: [PATCH 08/36] refactor: Remove sort_by argument --- .../src/argilla_server/search_engine/base.py | 3 -- .../argilla_server/search_engine/commons.py | 33 +++---------------- 2 
files changed, 4 insertions(+), 32 deletions(-) diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 687c51bad3..db5bc87e2a 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -282,9 +282,6 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: remove them and keep filter and order - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, ) -> SearchResponses: diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 0b7606c642..5b9d5e66bc 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -199,23 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once the `sort_by` argument is removed from search and similarity_search methods -def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: - if order: - return order - - new_order = [] - for sort in sort_by: - if isinstance(sort.field, MetadataProperty): - scope = MetadataFilterScope(metadata_property=sort.field.name) - else: - scope = RecordFilterScope(property=sort.field) - - new_order.append(Order(scope=scope, order=sort.order)) - - return new_order - - def is_response_status_scope(scope: FilterScope) -> bool: return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None @@ -327,14 +310,14 @@ async def update_record_response(self, response: Response): es_responses = self._map_record_responses_to_es([response]) - await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}}) + await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}}) async def delete_record_response(self, response: Response): record = response.record index_name = await self._get_dataset_index(record.dataset) await self._update_document_request( - index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)} + index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)} ) async def update_record_suggestion(self, suggestion: Suggestion): @@ -344,7 +327,7 @@ async def update_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"doc": {"suggestions": es_suggestions}}, ) @@ -353,7 +336,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, ) @@ -576,19 +559,11 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: Remove these arguments - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, user_id: Optional[str] = None, ) -> SearchResponses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html - - # TODO: This block will be moved (maybe 
to contexts/search.py), and only filter and order arguments will be kept - if sort_by: - sort = _unify_sort_by_with_order(sort_by, sort) - # END TODO index = await self._get_dataset_index(dataset) text_query = self._build_text_query(dataset, text=query) From 5f4e5b0efcb872f3d5523c58486e18c0d8eabbf4 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 4 Jul 2024 17:37:27 +0200 Subject: [PATCH 09/36] [breaking] refactor: Remove sort_by query param --- .../api/handlers/v1/datasets/records.py | 72 +------------------ 1 file changed, 1 insertion(+), 71 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 4f68c1125f..065295612f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -19,7 +19,6 @@ from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from typing_extensions import Annotated import argilla_server.search_engine as search_engine from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized @@ -52,12 +51,11 @@ from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE -from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings +from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, SearchEngine, SearchResponses, - SortBy, UserResponseStatusFilter, get_search_engine, ) @@ -70,25 +68,6 @@ LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"} DELETE_DATASET_RECORDS_LIMIT = 100 -_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField) -_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder) -_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P(?=.*[a-z0-9])[a-z0-9_-]+)$") - -SortByQueryParamParsed = Annotated[ - Dict[str, str], - Depends( - parse_query_param( - name="sort_by", - description=( - "The field used to sort the records. 
Expected format is `field` or `field:{asc,desc}`, where `field`" - " can be 'inserted_at', 'updated_at' or the name of a metadata property" - ), - max_values_per_key=1, - group_keys_without_values=False, - ) - ), -] - parse_record_include_param = parse_query_param( name="include", help="Relationships to include in the response", model=RecordIncludeParam ) @@ -104,7 +83,6 @@ async def _filter_records_using_search_engine( offset: int, user: Optional[User] = None, include: Optional[RecordIncludeParam] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: search_responses = await _get_search_responses( db=db, @@ -113,7 +91,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - sort_by_query_param=sort_by_query_param, ) record_ids = [response.record_id for response in search_responses.items] @@ -176,7 +153,6 @@ async def _get_search_responses( offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -216,8 +192,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - sort_by = await _build_sort_by(db, dataset, sort_by_query_param) - if vector_query and vector_settings: similarity_search_params = { "dataset": dataset, @@ -239,7 +213,6 @@ async def _get_search_responses( "query": text_query, "offset": offset, "limit": limit, - "sort_by": sort_by, } if user is not None: @@ -265,43 +238,6 @@ async def _build_response_status_filter_for_search( return user_response_status_filter -async def _build_sort_by( - db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None -) -> Union[List[SortBy], None]: - if sort_by_query_param is None: - return None - - sorts_by = [] - for sort_field, sort_order in sort_by_query_param.items(): - if sort_field in _RECORD_SORT_FIELD_VALUES: - field = sort_field - elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None: - metadata_property_name = match.group("name") - metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id) - if not metadata_property: - raise UnprocessableEntityError( - f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in " - f"dataset with '{dataset.id}'." - ) - - field = metadata_property - else: - valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES) - raise UnprocessableEntityError( - f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. 
It must be either" - f" {valid_sort_fields} or `metadata.metadata-property-name`" - ) - - if sort_order is not None and sort_order not in _VALID_SORT_VALUES: - raise UnprocessableEntityError( - f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.", - ) - - sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value)) - - return sorts_by - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -315,7 +251,6 @@ async def list_dataset_records( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -332,7 +267,6 @@ async def list_dataset_records( limit=limit, offset=offset, include=include, - sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) return Records(items=records, total=total) @@ -441,7 +375,6 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -468,7 +401,6 @@ async def search_current_user_dataset_records( limit=limit, offset=offset, user=current_user, - sort_by_query_param=sort_by_query_param, ) record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { @@ -511,7 +443,6 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -530,7 +461,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - sort_by_query_param=sort_by_query_param, ) record_id_score_map = { From c8853921e73b67e7c6160a346d6dbcd8c2b7974b Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 4 Jul 2024 17:38:03 +0200 Subject: [PATCH 10/36] tests: Adapt tests --- .../datasets/test_search_dataset_records.py | 2 - .../unit/api/handlers/v1/test_datasets.py | 111 ++++++++-------- .../handlers/v1/test_list_dataset_records.py | 121 ------------------ .../tests/unit/search_engine/test_commons.py | 32 +++-- 4 files changed, 78 insertions(+), 188 deletions(-) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 5229f53cf9..5e3c6653de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -319,7 +319,6 @@ async def test_with_filter( offset=0, limit=50, query=None, - sort_by=None, ) async def test_with_sort( @@ -368,7 +367,6 @@ async def test_with_sort( offset=0, limit=50, 
query=None, - sort_by=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 1d9ddcf22c..f84154b4c4 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -14,7 +14,7 @@ import math import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type from unittest.mock import ANY, MagicMock from uuid import UUID, uuid4 @@ -43,6 +43,7 @@ ResponseStatusFilter, SimilarityOrder, RecordStatus, + SortOrder, ) from argilla_server.models import ( Dataset, @@ -62,13 +63,14 @@ SearchEngine, SearchResponseItem, SearchResponses, - SortBy, TextQuery, AndFilter, TermsFilter, MetadataFilterScope, RangeFilter, ResponseFilterScope, + Order, + RecordFilterScope, ) from tests.factories import ( AdminFactory, @@ -3651,7 +3653,6 @@ async def test_search_current_user_dataset_records( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -3811,31 +3812,42 @@ async def test_search_current_user_dataset_records_with_metadata_filter( filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) @pytest.mark.parametrize( - "sorts", + "sort,expected_sort", [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"}], + [Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.asc)], + ), + ( + [ + {"scope": {"entity": "record", 
"property": "updated_at"}, "order": "desc"}, + {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "desc"}, + ], + [ + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), + Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.desc), + ], + ), ], ) async def test_search_current_user_dataset_records_with_sort_by( @@ -3844,16 +3856,15 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine: SearchEngine, owner: "User", owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], + sort: List[dict], + expected_sort: List[Order], ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) + for order in expected_sort: + if isinstance(order.scope, MetadataFilterScope): + await TermsMetadataPropertyFactory.create(name=order.scope.metadata_property, dataset=dataset) mock_search_engine.search.return_value = SearchResponses( total=2, @@ -3863,15 +3874,13 @@ async def test_search_current_user_dataset_records_with_sort_by( ], ) - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": sort, } - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} - response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=query_params, headers=owner_auth_header, json=query_json, ) @@ -3883,7 +3892,7 @@ async def test_search_current_user_dataset_records_with_sort_by( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, + sort=expected_sort, user_id=owner.id, ) @@ -3893,18 +3902,17 @@ async def test_search_current_user_dataset_records_with_sort_by_with_wrong_sort_ workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "record", "property": "wrong_property"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." 
- } async def test_search_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict @@ -3912,17 +3920,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_non_existen workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "metadata", "metadata_property": "missing"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "metadata.i-do-not-exist:asc"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." + "detail": f"MetadataProperty not found filtering by name=missing, dataset_id={dataset.id}" } async def test_search_current_user_dataset_records_with_sort_by_with_invalid_field( @@ -3931,19 +3941,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_invalid_fie workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [ + {"scope": {"entity": "wrong", "property": "wrong"}, "order": "asc"}, + ], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "not-valid"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } @pytest.mark.parametrize( "includes", @@ -4085,7 +4095,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - sort_by=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4319,7 +4328,6 @@ async def test_search_current_user_dataset_records_with_response_status_filter( ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -4544,7 +4552,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( query=TextQuery(q="Hello", field="input"), offset=0, limit=5, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index db2605c3de..4f989e5399 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -17,16 +17,9 @@ import pytest from httpx import AsyncClient -from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import RecordInclude, ResponseStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace -from argilla_server.search_engine import ( - SearchEngine, - SearchResponseItem, - SearchResponses, - SortBy, -) from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -35,7 +28,6 @@ RecordFactory, ResponseFactory, SuggestionFactory, - TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, VectorFactory, @@ -453,119 +445,6 @@ async def test_list_dataset_records_with_response_status_filter( ] ) - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await 
TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - ) - - async def test_list_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 5ae8241927..c893366b58 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,14 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus +from argilla_server.enums import ( + MetadataPropertyType, + QuestionType, + ResponseStatusFilter, + SimilarityOrder, + RecordStatus, + SortOrder, +) from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( ResponseFilterScope, @@ -28,6 +35,8 @@ Filter, MetadataFilterScope, RangeFilter, + Order, + RecordFilterScope, ) from argilla_server.search_engine.commons import ( ALL_RESPONSES_STATUSES_FIELD, @@ -820,12 +829,12 @@ async def test_search_with_pagination( assert all_results.items[offset : offset + limit] == results.items @pytest.mark.parametrize( - ("sort_by"), + ("sort_order"), [ - SortBy(field="inserted_at"), - SortBy(field="updated_at"), - SortBy(field="inserted_at", order="desc"), - SortBy(field="updated_at", order="desc"), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), ], ) async def test_search_with_sort_by( @@ -833,18 +842,15 @@ async def test_search_with_sort_by( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - sort_by: SortBy, + sort_order: Order, ): def _local_sort_by(record: Record) -> Any: - if isinstance(sort_by.field, str): - return getattr(record, sort_by.field) - return record.metadata_[sort_by.field.name] + return getattr(record, sort_order.scope.property) - results = await search_engine.search(test_banking_sentiment_dataset, sort_by=[sort_by]) + results = await search_engine.search(test_banking_sentiment_dataset, sort=[sort_order]) records = test_banking_sentiment_dataset.records - if sort_by: - records = sorted(records, key=_local_sort_by, reverse=sort_by.order == "desc") + records = sorted(records, key=_local_sort_by, reverse=sort_order.order == "desc") assert [item.record_id for item in results.items] == [record.id for record in records] From 28b2998adee3c86fbbc5cf32316492422ec72a86 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 5 Jul 2024 09:30:20 +0200 Subject: [PATCH 11/36] chore: Update changelog --- argilla-server/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 05a8e7e583..07621c8dad 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -30,6 +30,7 @@ These are the section headers that we use: - Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) - [breaking] Removed support for `response_status` query param. 
([#5163](https://github.com/argilla-io/argilla/pull/5163)) - [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `sort_by` query param. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) From 209d64d9335e1772798fa48d39f58c66ba7e1a3e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 5 Jul 2024 11:37:54 +0200 Subject: [PATCH 12/36] feat: Define new repositories --- .../argilla_server/repositories/__init__.py | 18 ++++++ .../argilla_server/repositories/datasets.py | 29 +++++++++ .../argilla_server/repositories/records.py | 63 +++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 argilla-server/src/argilla_server/repositories/__init__.py create mode 100644 argilla-server/src/argilla_server/repositories/datasets.py create mode 100644 argilla-server/src/argilla_server/repositories/records.py diff --git a/argilla-server/src/argilla_server/repositories/__init__.py b/argilla-server/src/argilla_server/repositories/__init__.py new file mode 100644 index 0000000000..98424d94a6 --- /dev/null +++ b/argilla-server/src/argilla_server/repositories/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argilla_server.repositories.datasets import DatasetsRepository +from argilla_server.repositories.records import RecordsRepository + +__all__ = ["DatasetsRepository", "RecordsRepository"] diff --git a/argilla-server/src/argilla_server/repositories/datasets.py b/argilla-server/src/argilla_server/repositories/datasets.py new file mode 100644 index 0000000000..46ac90e2fe --- /dev/null +++ b/argilla-server/src/argilla_server/repositories/datasets.py @@ -0,0 +1,29 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from uuid import UUID + +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.database import get_async_db +from argilla_server.models import Dataset + + +class DatasetsRepository: + def __init__(self, db: AsyncSession = Depends(get_async_db)): + self.db = db + + async def get(self, dataset_id: UUID) -> Dataset: + return await Dataset.get_or_raise(db=self.db, id=dataset_id) diff --git a/argilla-server/src/argilla_server/repositories/records.py b/argilla-server/src/argilla_server/repositories/records.py new file mode 100644 index 0000000000..ea93d94028 --- /dev/null +++ b/argilla-server/src/argilla_server/repositories/records.py @@ -0,0 +1,63 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List, Tuple, Sequence +from uuid import UUID + +from fastapi import Depends +from sqlalchemy import select, and_, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload, contains_eager + +from argilla_server.database import get_async_db +from argilla_server.models import Record, VectorSettings, Vector + + +class RecordsRepository: + def __init__( + self, + db: AsyncSession = Depends(get_async_db), + ): + self.db = db + + async def list_by_dataset_id( + self, + dataset_id: UUID, + offset: int, + limit: int, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, + ) -> Tuple[Sequence[Record], int]: + query = select(Record).filter_by(dataset_id=dataset_id) + + if with_responses: + query = query.options(selectinload(Record.responses)) + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + records = (await self.db.scalars(query.offset(offset).limit(limit).order_by(Record.inserted_at))).unique().all() + + total = await self.db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) + + return records, total From a350b0c46075dd718d20f2abd05d609df684a54e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 5 Jul 2024 11:39:04 +0200 Subject: [PATCH 13/36] chore: Rewrite list endpoint using repositories --- .../api/handlers/v1/datasets/records.py | 75 +++++-------------- 1 file changed, 20 insertions(+), 55 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 065295612f..b6416ac4f1 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ 
b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Query, Security, status @@ -48,15 +47,15 @@ ) from argilla_server.contexts import datasets, search from argilla_server.database import get_async_db -from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder +from argilla_server.enums import RecordSortField from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE from argilla_server.models import Dataset, Field, Record, User, VectorSettings +from argilla_server.repositories import DatasetsRepository, RecordsRepository from argilla_server.search_engine import ( AndFilter, SearchEngine, SearchResponses, - UserResponseStatusFilter, get_search_engine, ) from argilla_server.security import auth @@ -75,35 +74,6 @@ router = APIRouter() -async def _filter_records_using_search_engine( - db: "AsyncSession", - search_engine: "SearchEngine", - dataset: Dataset, - limit: int, - offset: int, - user: Optional[User] = None, - include: Optional[RecordIncludeParam] = None, -) -> Tuple[List[Record], int]: - search_responses = await _get_search_responses( - db=db, - search_engine=search_engine, - dataset=dataset, - limit=limit, - offset=offset, - user=user, - ) - - record_ids = [response.record_id for response in search_responses.items] - user_id = user.id if user else None - - return ( - await datasets.get_records_by_ids( - db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include - ), - search_responses.total, - ) - - def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope: if isinstance(scope, RecordFilterScope): return search_engine.RecordFilterScope(property=scope.property) @@ -226,18 +196,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _build_response_status_filter_for_search( - response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None -) -> Optional[UserResponseStatusFilter]: - user_response_status_filter = None - - if response_statuses: - # TODO(@frascuchon): user response and status responses should be split into different filter types - user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses) - - return user_response_status_filter - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -248,25 +206,32 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), + datasets_repository: DatasetsRepository = Depends(), + records_repository: RecordsRepository = Depends(), dataset_id: UUID, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, 
le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), ): - dataset = await Dataset.get_or_raise(db, dataset_id) - + dataset = await datasets_repository.get(dataset_id) await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) - records, total = await _filter_records_using_search_engine( - db, - search_engine, - dataset=dataset, - limit=limit, + include_args = ( + dict( + with_responses=include.with_responses, + with_suggestions=include.with_suggestions, + with_vectors=include.with_all_vectors or include.vectors, + ) + if include + else {} + ) + + records, total = await records_repository.list_by_dataset_id( + dataset_id=dataset.id, offset=offset, - include=include, + limit=limit, + **include_args, ) return Records(items=records, total=total) From 3537941e97a24260f179b3904db6e35de83f0b23 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 5 Jul 2024 11:39:55 +0200 Subject: [PATCH 14/36] tests: Enable skip tests for list dataset records --- .../handlers/v1/test_list_dataset_records.py | 150 ++++++++---------- 1 file changed, 67 insertions(+), 83 deletions(-) diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 4f989e5399..80758804b8 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -38,7 +38,6 @@ @pytest.mark.asyncio class TestSuiteListDatasetRecords: - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -58,25 +57,31 @@ async def test_list_dataset_records(self, async_client: "AsyncClient", owner_aut "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"record_a": "value_a"}, "metadata": None, "external_id": record_a.external_id, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"record_b": "value_b"}, "metadata": {"unit": "test"}, "external_id": record_b.external_id, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"record_c": "value_c"}, "metadata": None, "external_id": record_c.external_id, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -188,7 +193,6 @@ async def test_list_dataset_records_with_include( assert response.status_code == 200 - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_vectors( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -214,6 +218,7 @@ async def test_list_dataset_records_with_include_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -221,26 +226,31 @@ async def test_list_dataset_records_with_include_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": 
record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -248,7 +258,6 @@ async def test_list_dataset_records_with_include_vectors( "total": 3, } - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_include_specific_vectors( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -278,6 +287,7 @@ async def test_list_dataset_records_with_include_specific_vectors( "items": [ { "id": str(record_a.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_a.external_id, @@ -285,26 +295,31 @@ async def test_list_dataset_records_with_include_specific_vectors( "vector-a": [1.0, 2.0, 3.0], "vector-b": [4.0, 5.0], }, + "status": "pending", "inserted_at": record_a.inserted_at.isoformat(), "updated_at": record_a.updated_at.isoformat(), }, { "id": str(record_b.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_b.external_id, "vectors": { "vector-b": [1.0, 2.0], }, + "status": "pending", "inserted_at": record_b.inserted_at.isoformat(), "updated_at": record_b.updated_at.isoformat(), }, { "id": str(record_c.id), + "dataset_id": str(dataset.id), "fields": {"text": "This is a text", "sentiment": "neutral"}, "metadata": None, "external_id": record_c.external_id, "vectors": {}, + "status": "pending", "inserted_at": record_c.inserted_at.isoformat(), "updated_at": record_c.updated_at.isoformat(), }, @@ -312,7 +327,6 @@ async def test_list_dataset_records_with_include_specific_vectors( "total": 3, } - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -331,7 +345,6 @@ async def test_list_dataset_records_with_offset(self, async_client: "AsyncClient response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) @@ -350,7 +363,6 @@ async def test_list_dataset_records_with_limit(self, async_client: "AsyncClient" response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_a.id)] - @pytest.mark.skip(reason="Factory integration with search engine") async def test_list_dataset_records_with_offset_and_limit( self, async_client: "AsyncClient", owner_auth_header: dict ): @@ -371,80 +383,6 @@ async 
def test_list_dataset_records_with_offset_and_limit( response_body = response.json() assert [item["id"] for item in response_body["items"]] == [str(record_c.id)] - async def create_records_with_response( - self, - num_records: int, - dataset: Dataset, - user: User, - response_status: ResponseStatus, - response_values: Optional[dict] = None, - ): - for record in await RecordFactory.create_batch(size=num_records, dataset=dataset): - await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) - - @pytest.mark.skip(reason="Factory integration with search engine") - @pytest.mark.parametrize( - "response_status_filter", ["missing", "pending", "discarded", "submitted", "draft", ["submitted", "draft"]] - ) - async def test_list_dataset_records_with_response_status_filter( - self, - async_client: "AsyncClient", - owner: "User", - owner_auth_header: dict, - response_status_filter: Union[str, List[str]], - ): - num_records_per_response_status = 10 - response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}} - - dataset = await DatasetFactory.create() - # missing responses - await RecordFactory.create_batch(size=num_records_per_response_status, dataset=dataset) - # discarded responses - await self.create_records_with_response( - num_records_per_response_status, dataset, owner, ResponseStatus.discarded - ) - # submitted responses - await self.create_records_with_response( - num_records_per_response_status, dataset, owner, ResponseStatus.submitted, response_values - ) - # drafted responses - await self.create_records_with_response( - num_records_per_response_status, dataset, owner, ResponseStatus.draft, response_values - ) - - other_dataset = await DatasetFactory.create() - await RecordFactory.create_batch(size=2, dataset=other_dataset) - - response_status_filter = ( - [response_status_filter] if isinstance(response_status_filter, str) else response_status_filter - ) - response_status_filter_url = [ - f"response_status={response_status}" for response_status in response_status_filter - ] - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records?{'&'.join(response_status_filter_url)}&include=responses", - headers=owner_auth_header, - ) - - assert response.status_code == 200 - response_json = response.json() - - assert len(response_json["items"]) == (num_records_per_response_status * len(response_status_filter)) - - if "missing" in response_status_filter: - assert ( - len([record for record in response_json["items"] if len(record["responses"]) == 0]) - >= num_records_per_response_status - ) - assert all( - [ - record["responses"][0]["status"] in response_status_filter - for record in response_json["items"] - if len(record["responses"]) > 0 - ] - ) - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() @@ -457,9 +395,9 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): admin = await AdminFactory.create(workspaces=[workspace]) dataset = await DatasetFactory.create(workspace=workspace) - await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) - await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) - await RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) + record_a = await RecordFactory.create(fields={"record_a": "value_a"}, dataset=dataset) + record_b = await RecordFactory.create(fields={"record_b": "value_b"}, dataset=dataset) + record_c = await 
RecordFactory.create(fields={"record_c": "value_c"}, dataset=dataset) other_dataset = await DatasetFactory.create() await RecordFactory.create_batch(size=2, dataset=other_dataset) @@ -468,6 +406,41 @@ async def test_list_dataset_records_as_admin(self, async_client: "AsyncClient"): f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key} ) assert response.status_code == 200 + assert response.json() == { + "total": 3, + "items": [ + { + "id": str(record_a.id), + "dataset_id": str(dataset.id), + "fields": {"record_a": "value_a"}, + "metadata": None, + "external_id": record_a.external_id, + "status": "pending", + "inserted_at": record_a.inserted_at.isoformat(), + "updated_at": record_a.updated_at.isoformat(), + }, + { + "id": str(record_b.id), + "dataset_id": str(dataset.id), + "fields": {"record_b": "value_b"}, + "metadata": None, + "external_id": record_b.external_id, + "status": "pending", + "inserted_at": record_b.inserted_at.isoformat(), + "updated_at": record_b.updated_at.isoformat(), + }, + { + "id": str(record_c.id), + "dataset_id": str(dataset.id), + "fields": {"record_c": "value_c"}, + "metadata": None, + "external_id": record_c.external_id, + "status": "pending", + "inserted_at": record_c.inserted_at.isoformat(), + "updated_at": record_c.updated_at.isoformat(), + }, + ], + } async def test_list_dataset_records_as_annotator(self, async_client: "AsyncClient"): workspace = await WorkspaceFactory.create() @@ -560,3 +533,14 @@ async def create_dataset_with_user_responses( ] return dataset, questions, records, responses, suggestions + + async def create_records_with_response( + self, + num_records: int, + dataset: Dataset, + user: User, + response_status: ResponseStatus, + response_values: Optional[dict] = None, + ): + for record in await RecordFactory.create_batch(size=num_records, dataset=dataset): + await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) From 808c837ce812de045691620c4c97aa387458d937 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 8 Jul 2024 17:09:40 +0200 Subject: [PATCH 15/36] [ENHANCEMENT]: `argilla-server`: allow update distribution for non annotated datasets (#5171) # Description This PR changes the current validator when updating the distribution task to allow updating the distribution task settings for datasets with records without ANY response. 
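
For illustration only (not part of this diff), a minimal client-side sketch of the new behaviour: updating `distribution` on a published dataset succeeds while its records have no responses, and is rejected with a 422 once any response exists. The base URL, API-key header name and dataset id below are assumptions; the endpoint path is assumed to be the dataset update handler.

```python
# Hypothetical sketch: exercise the relaxed distribution-update validation.
import asyncio

import httpx

API_URL = "http://localhost:6900"                # assumed local Argilla server
HEADERS = {"X-Argilla-Api-Key": "owner.apikey"}  # assumed auth header name/value


async def update_distribution(dataset_id: str) -> None:
    async with httpx.AsyncClient(base_url=API_URL, headers=HEADERS) as client:
        response = await client.patch(
            f"/api/v1/datasets/{dataset_id}",  # assumed dataset update endpoint
            json={"distribution": {"strategy": "overlap", "min_submitted": 4}},
        )
        if response.status_code == 200:
            # Allowed: no record of the dataset has responses yet.
            print("updated:", response.json())
        else:
            # Rejected once responses exist:
            # "Distribution settings cannot be modified for a dataset with records including responses"
            print(response.status_code, response.json()["detail"])


asyncio.run(update_distribution("00000000-0000-0000-0000-000000000000"))
```
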
cc @nataliaElv **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla_server/models/database.py | 11 ++++-- .../src/argilla_server/validators/datasets.py | 10 +++--- .../v1/datasets/test_update_dataset.py | 36 +++++++++++++++---- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 3230916362..6b9580dbb5 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -17,12 +17,12 @@ from typing import Any, List, Optional, Union from uuid import UUID -from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql +from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text from sqlalchemy import Enum as SAEnum from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( @@ -361,6 +361,13 @@ class Dataset(DatabaseModel): __table_args__ = (UniqueConstraint("name", "workspace_id", name="dataset_name_workspace_id_uq"),) + @property + async def responses_count(self) -> int: + # TODO: This should be moved to proper repository + return await async_object_session(self).scalar( + select(func.count(Response.id)).join(Record).where(Record.dataset_id == self.id) + ) + @property def is_draft(self): return self.status == DatasetStatus.draft diff --git a/argilla-server/src/argilla_server/validators/datasets.py b/argilla-server/src/argilla_server/validators/datasets.py index aae2a5fc83..eb52576d41 100644 --- a/argilla-server/src/argilla_server/validators/datasets.py +++ b/argilla-server/src/argilla_server/validators/datasets.py @@ -40,9 +40,11 @@ async def _validate_name_is_not_duplicated(cls, db: AsyncSession, name: str, wor class DatasetUpdateValidator: @classmethod async def validate(cls, db: AsyncSession, dataset: Dataset, dataset_attrs: dict) -> None: - cls._validate_distribution(dataset, dataset_attrs) + await cls._validate_distribution(dataset, dataset_attrs) @classmethod - def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: - if dataset.is_ready and dataset_attrs.get("distribution") is not None: - raise UnprocessableEntityError(f"Distribution settings cannot be modified for a published dataset") + async def _validate_distribution(cls, dataset: Dataset, dataset_attrs: dict) -> None: + if dataset_attrs.get("distribution") is not None and (await dataset.responses_count) > 0: + raise UnprocessableEntityError( + "Distribution settings cannot be modified for a dataset with records including responses" + ) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py 
b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py index 097bc0a1ec..ea732d0536 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_update_dataset.py @@ -18,8 +18,7 @@ from httpx import AsyncClient from argilla_server.enums import DatasetDistributionStrategy, DatasetStatus - -from tests.factories import DatasetFactory +from tests.factories import DatasetFactory, RecordFactory, ResponseFactory @pytest.mark.asyncio @@ -96,7 +95,7 @@ async def test_update_dataset_without_distribution_for_published_dataset( "min_submitted": 1, } - async def test_update_dataset_distribution_for_published_dataset( + async def test_update_dataset_distribution_for_published_dataset_without_responses( self, async_client: AsyncClient, owner_auth_header: dict ): dataset = await DatasetFactory.create(status=DatasetStatus.ready) @@ -112,12 +111,37 @@ async def test_update_dataset_distribution_for_published_dataset( }, ) - assert response.status_code == 422 - assert response.json() == {"detail": "Distribution settings cannot be modified for a published dataset"} + assert response.status_code == 200 assert dataset.distribution == { "strategy": DatasetDistributionStrategy.overlap, - "min_submitted": 1, + "min_submitted": 4, + } + + async def test_update_dataset_distribution_for_dataset_with_responses( + self, async_client: AsyncClient, owner_auth_header: dict + ): + dataset = await DatasetFactory.create(status=DatasetStatus.ready) + records = await RecordFactory.create_batch(10, dataset=dataset) + + for record in records: + await ResponseFactory.create(record=record) + + response = await async_client.patch( + self.url(dataset.id), + headers=owner_auth_header, + json={ + "distribution": { + "strategy": DatasetDistributionStrategy.overlap, + "min_submitted": 4, + }, + }, + ) + + assert response.status_code == 422 + + assert response.json() == { + "detail": "Distribution settings cannot be modified for a dataset with records including responses" } async def test_update_dataset_distribution_with_invalid_strategy( From ba417dc671f4f46861c8a1c162f583a1c01f2d43 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Mon, 8 Jul 2024 18:25:27 +0200 Subject: [PATCH 16/36] [BREAKING - REFACTOR] `argilla-server`: remove `sort_by` query param (#5166) # Description This PR removes support of `sort_by` query param for list/search records endpoints. 
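
For context, a minimal client-side sketch of the replacement request shape (the base URL, auth header and dataset id are illustrative assumptions, not part of this patch): sorting is no longer a `sort_by` query param but a `sort` list inside the search body, mirroring the scopes used by the search engine filters.

```python
# Hypothetical sketch: sort search results via the request body instead of ?sort_by=...
import asyncio

import httpx

API_URL = "http://localhost:6900"                # assumed local Argilla server
HEADERS = {"X-Argilla-Api-Key": "owner.apikey"}  # assumed auth header name/value


async def search_sorted(dataset_id: str) -> None:
    # Previously: ...?sort_by=inserted_at:desc (now removed).
    # Now: sorting is expressed as scoped orders in the search body.
    body = {
        "query": {"text": {"q": "Hello", "field": "input"}},
        "sort": [
            {"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"},
            {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"},
        ],
    }
    async with httpx.AsyncClient(base_url=API_URL, headers=HEADERS) as client:
        response = await client.post(f"/api/v1/datasets/{dataset_id}/records/search", json=body)
        response.raise_for_status()
        print("total matches:", response.json()["total"])


asyncio.run(search_sorted("00000000-0000-0000-0000-000000000000"))
```
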
**Type of change** - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 5 +- .../api/handlers/v1/datasets/records.py | 72 +---------- .../src/argilla_server/search_engine/base.py | 3 - .../argilla_server/search_engine/commons.py | 33 +---- .../datasets/test_search_dataset_records.py | 2 - .../unit/api/handlers/v1/test_datasets.py | 111 ++++++++-------- .../handlers/v1/test_list_dataset_records.py | 121 ------------------ .../tests/unit/search_engine/test_commons.py | 32 +++-- 8 files changed, 86 insertions(+), 293 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 05a8e7e583..b3f1483986 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -28,8 +28,9 @@ These are the section headers that we use: ### Removed - Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) -- [breaking] Removed support for `response_status` query param. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) -- [breaking] Removed support for `metadata` query param. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) +- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) +- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. 
([#5166](https://github.com/argilla-io/argilla/pull/5166)) ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 4f68c1125f..065295612f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -19,7 +19,6 @@ from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from typing_extensions import Annotated import argilla_server.search_engine as search_engine from argilla_server.api.policies.v1 import DatasetPolicy, RecordPolicy, authorize, is_authorized @@ -52,12 +51,11 @@ from argilla_server.enums import RecordSortField, ResponseStatusFilter, SortOrder from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE -from argilla_server.models import Dataset, Field, MetadataProperty, Record, User, VectorSettings +from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, SearchEngine, SearchResponses, - SortBy, UserResponseStatusFilter, get_search_engine, ) @@ -70,25 +68,6 @@ LIST_DATASET_RECORDS_DEFAULT_SORT_BY = {RecordSortField.inserted_at.value: "asc"} DELETE_DATASET_RECORDS_LIMIT = 100 -_RECORD_SORT_FIELD_VALUES = tuple(field.value for field in RecordSortField) -_VALID_SORT_VALUES = tuple(sort.value for sort in SortOrder) -_METADATA_PROPERTY_SORT_BY_REGEX = re.compile(r"^metadata\.(?P(?=.*[a-z0-9])[a-z0-9_-]+)$") - -SortByQueryParamParsed = Annotated[ - Dict[str, str], - Depends( - parse_query_param( - name="sort_by", - description=( - "The field used to sort the records. 
Expected format is `field` or `field:{asc,desc}`, where `field`" - " can be 'inserted_at', 'updated_at' or the name of a metadata property" - ), - max_values_per_key=1, - group_keys_without_values=False, - ) - ), -] - parse_record_include_param = parse_query_param( name="include", help="Relationships to include in the response", model=RecordIncludeParam ) @@ -104,7 +83,6 @@ async def _filter_records_using_search_engine( offset: int, user: Optional[User] = None, include: Optional[RecordIncludeParam] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> Tuple[List[Record], int]: search_responses = await _get_search_responses( db=db, @@ -113,7 +91,6 @@ async def _filter_records_using_search_engine( limit=limit, offset=offset, user=user, - sort_by_query_param=sort_by_query_param, ) record_ids = [response.record_id for response in search_responses.items] @@ -176,7 +153,6 @@ async def _get_search_responses( offset: int, search_records_query: Optional[SearchRecordsQuery] = None, user: Optional[User] = None, - sort_by_query_param: Optional[Dict[str, str]] = None, ) -> "SearchResponses": search_records_query = search_records_query or SearchRecordsQuery() @@ -216,8 +192,6 @@ async def _get_search_responses( if text_query and text_query.field and not await Field.get_by(db, name=text_query.field, dataset_id=dataset.id): raise UnprocessableEntityError(f"Field `{text_query.field}` not found in dataset `{dataset.id}`.") - sort_by = await _build_sort_by(db, dataset, sort_by_query_param) - if vector_query and vector_settings: similarity_search_params = { "dataset": dataset, @@ -239,7 +213,6 @@ async def _get_search_responses( "query": text_query, "offset": offset, "limit": limit, - "sort_by": sort_by, } if user is not None: @@ -265,43 +238,6 @@ async def _build_response_status_filter_for_search( return user_response_status_filter -async def _build_sort_by( - db: "AsyncSession", dataset: Dataset, sort_by_query_param: Optional[Dict[str, str]] = None -) -> Union[List[SortBy], None]: - if sort_by_query_param is None: - return None - - sorts_by = [] - for sort_field, sort_order in sort_by_query_param.items(): - if sort_field in _RECORD_SORT_FIELD_VALUES: - field = sort_field - elif (match := _METADATA_PROPERTY_SORT_BY_REGEX.match(sort_field)) is not None: - metadata_property_name = match.group("name") - metadata_property = await MetadataProperty.get_by(db, name=metadata_property_name, dataset_id=dataset.id) - if not metadata_property: - raise UnprocessableEntityError( - f"Provided metadata property in 'sort_by' query param '{metadata_property_name}' not found in " - f"dataset with '{dataset.id}'." - ) - - field = metadata_property - else: - valid_sort_fields = ", ".join(f"'{sort_field}'" for sort_field in _RECORD_SORT_FIELD_VALUES) - raise UnprocessableEntityError( - f"Provided sort field in 'sort_by' query param '{sort_field}' is not valid. 
It must be either" - f" {valid_sort_fields} or `metadata.metadata-property-name`" - ) - - if sort_order is not None and sort_order not in _VALID_SORT_VALUES: - raise UnprocessableEntityError( - f"Provided sort order in 'sort_by' query param '{sort_order}' for field '{sort_field}' is not valid.", - ) - - sorts_by.append(SortBy(field=field, order=sort_order or SortOrder.asc.value)) - - return sorts_by - - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset_id) @@ -315,7 +251,6 @@ async def list_dataset_records( db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -332,7 +267,6 @@ async def list_dataset_records( limit=limit, offset=offset, include=include, - sort_by_query_param=sort_by_query_param or LIST_DATASET_RECORDS_DEFAULT_SORT_BY, ) return Records(items=records, total=total) @@ -441,7 +375,6 @@ async def search_current_user_dataset_records( telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -468,7 +401,6 @@ async def search_current_user_dataset_records( limit=limit, offset=offset, user=current_user, - sort_by_query_param=sort_by_query_param, ) record_id_score_map: Dict[UUID, Dict[str, Union[float, SearchRecord, None]]] = { @@ -511,7 +443,6 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), dataset_id: UUID, body: SearchRecordsQuery, - sort_by_query_param: SortByQueryParamParsed, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), @@ -530,7 +461,6 @@ async def search_dataset_records( search_records_query=body, limit=limit, offset=offset, - sort_by_query_param=sort_by_query_param, ) record_id_score_map = { diff --git a/argilla-server/src/argilla_server/search_engine/base.py b/argilla-server/src/argilla_server/search_engine/base.py index 687c51bad3..db5bc87e2a 100644 --- a/argilla-server/src/argilla_server/search_engine/base.py +++ b/argilla-server/src/argilla_server/search_engine/base.py @@ -282,9 +282,6 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: remove them and keep filter and order - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, ) -> SearchResponses: diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 0b7606c642..5b9d5e66bc 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -199,23 +199,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) -# This function will be moved once 
the `sort_by` argument is removed from search and similarity_search methods -def _unify_sort_by_with_order(sort_by: List[SortBy], order: List[Order]) -> List[Order]: - if order: - return order - - new_order = [] - for sort in sort_by: - if isinstance(sort.field, MetadataProperty): - scope = MetadataFilterScope(metadata_property=sort.field.name) - else: - scope = RecordFilterScope(property=sort.field) - - new_order.append(Order(scope=scope, order=sort.order)) - - return new_order - - def is_response_status_scope(scope: FilterScope) -> bool: return isinstance(scope, ResponseFilterScope) and scope.property == "status" and scope.question is None @@ -327,14 +310,14 @@ async def update_record_response(self, response: Response): es_responses = self._map_record_responses_to_es([response]) - await self._update_document_request(index_name, id=record.id, body={"doc": {"responses": es_responses}}) + await self._update_document_request(index_name, id=str(record.id), body={"doc": {"responses": es_responses}}) async def delete_record_response(self, response: Response): record = response.record index_name = await self._get_dataset_index(record.dataset) await self._update_document_request( - index_name, id=record.id, body={"script": es_script_for_delete_user_response(response.user)} + index_name, id=str(record.id), body={"script": es_script_for_delete_user_response(response.user)} ) async def update_record_suggestion(self, suggestion: Suggestion): @@ -344,7 +327,7 @@ async def update_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"doc": {"suggestions": es_suggestions}}, ) @@ -353,7 +336,7 @@ async def delete_record_suggestion(self, suggestion: Suggestion): await self._update_document_request( index_name, - id=suggestion.record_id, + id=str(suggestion.record_id), body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, ) @@ -576,19 +559,11 @@ async def search( query: Optional[Union[TextQuery, str]] = None, filter: Optional[Filter] = None, sort: Optional[List[Order]] = None, - # TODO: Remove these arguments - sort_by: Optional[List[SortBy]] = None, - # END TODO offset: int = 0, limit: int = 100, user_id: Optional[str] = None, ) -> SearchResponses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html - - # TODO: This block will be moved (maybe to contexts/search.py), and only filter and order arguments will be kept - if sort_by: - sort = _unify_sort_by_with_order(sort_by, sort) - # END TODO index = await self._get_dataset_index(dataset) text_query = self._build_text_query(dataset, text=query) diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py index 5229f53cf9..5e3c6653de 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/test_search_dataset_records.py @@ -319,7 +319,6 @@ async def test_with_filter( offset=0, limit=50, query=None, - sort_by=None, ) async def test_with_sort( @@ -368,7 +367,6 @@ async def test_with_sort( offset=0, limit=50, query=None, - sort_by=None, ) async def test_with_invalid_filter(self, async_client: AsyncClient, owner_auth_header: dict): diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 
1d9ddcf22c..f84154b4c4 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -14,7 +14,7 @@ import math import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type from unittest.mock import ANY, MagicMock from uuid import UUID, uuid4 @@ -43,6 +43,7 @@ ResponseStatusFilter, SimilarityOrder, RecordStatus, + SortOrder, ) from argilla_server.models import ( Dataset, @@ -62,13 +63,14 @@ SearchEngine, SearchResponseItem, SearchResponses, - SortBy, TextQuery, AndFilter, TermsFilter, MetadataFilterScope, RangeFilter, ResponseFilterScope, + Order, + RecordFilterScope, ) from tests.factories import ( AdminFactory, @@ -3651,7 +3653,6 @@ async def test_search_current_user_dataset_records( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -3811,31 +3812,42 @@ async def test_search_current_user_dataset_records_with_metadata_filter( filter=AndFilter(filters=[expected_filter]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) @pytest.mark.parametrize( - "sorts", + "sort,expected_sort", [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "inserted_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "asc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc)], + ), + ( + [{"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}], + [Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc)], + ), + ( + [{"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "asc"}], + [Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.asc)], + ), + ( + [ + {"scope": {"entity": "record", "property": "updated_at"}, "order": "desc"}, + {"scope": {"entity": "metadata", "metadata_property": "terms-metadata-property"}, "order": "desc"}, + ], + [ + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), + 
Order(scope=MetadataFilterScope(metadata_property="terms-metadata-property"), order=SortOrder.desc), + ], + ), ], ) async def test_search_current_user_dataset_records_with_sort_by( @@ -3844,16 +3856,15 @@ async def test_search_current_user_dataset_records_with_sort_by( mock_search_engine: SearchEngine, owner: "User", owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], + sort: List[dict], + expected_sort: List[Order], ): workspace = await WorkspaceFactory.create() dataset, _, records, *_ = await self.create_dataset_with_user_responses(owner, workspace) - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) + for order in expected_sort: + if isinstance(order.scope, MetadataFilterScope): + await TermsMetadataPropertyFactory.create(name=order.scope.metadata_property, dataset=dataset) mock_search_engine.search.return_value = SearchResponses( total=2, @@ -3863,15 +3874,13 @@ async def test_search_current_user_dataset_records_with_sort_by( ], ) - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": sort, } - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} - response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params=query_params, headers=owner_auth_header, json=query_json, ) @@ -3883,7 +3892,7 @@ async def test_search_current_user_dataset_records_with_sort_by( query=TextQuery(q="Hello", field="input"), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, + sort=expected_sort, user_id=owner.id, ) @@ -3893,18 +3902,17 @@ async def test_search_current_user_dataset_records_with_sort_by_with_wrong_sort_ workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "record", "property": "wrong_property"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." 
- } async def test_search_current_user_dataset_records_with_sort_by_with_non_existent_metadata_property( self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict @@ -3912,17 +3920,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_non_existen workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [{"scope": {"entity": "metadata", "metadata_property": "missing"}, "order": "asc"}], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "metadata.i-do-not-exist:asc"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." + "detail": f"MetadataProperty not found filtering by name=missing, dataset_id={dataset.id}" } async def test_search_current_user_dataset_records_with_sort_by_with_invalid_field( @@ -3931,19 +3941,19 @@ async def test_search_current_user_dataset_records_with_sort_by_with_invalid_fie workspace = await WorkspaceFactory.create() dataset, *_ = await self.create_dataset_with_user_responses(owner, workspace) - query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + query_json = { + "query": {"text": {"q": "Hello", "field": "input"}}, + "sort": [ + {"scope": {"entity": "wrong", "property": "wrong"}, "order": "asc"}, + ], + } response = await async_client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", - params={"sort_by": "not-valid"}, headers=owner_auth_header, json=query_json, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } @pytest.mark.parametrize( "includes", @@ -4085,7 +4095,6 @@ async def test_search_current_user_dataset_records_with_include( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=TextQuery(q="Hello", field="input"), - sort_by=None, offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, user_id=owner.id, @@ -4319,7 +4328,6 @@ async def test_search_current_user_dataset_records_with_response_status_filter( ), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 @@ -4544,7 +4552,6 @@ async def test_search_current_user_dataset_records_with_offset_and_limit( query=TextQuery(q="Hello", field="input"), offset=0, limit=5, - sort_by=None, user_id=owner.id, ) assert response.status_code == 200 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index db2605c3de..4f989e5399 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -17,16 +17,9 @@ import pytest from httpx import AsyncClient -from argilla_server.api.handlers.v1.datasets.records import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla_server.constants import API_KEY_HEADER_NAME from argilla_server.enums import RecordInclude, ResponseStatus from argilla_server.models import Dataset, Question, Record, Response, Suggestion, User, Workspace -from argilla_server.search_engine import ( - SearchEngine, - SearchResponseItem, - SearchResponses, - SortBy, -) from tests.factories import ( AdminFactory, AnnotatorFactory, @@ -35,7 +28,6 @@ RecordFactory, ResponseFactory, SuggestionFactory, - TermsMetadataPropertyFactory, TextFieldFactory, TextQuestionFactory, VectorFactory, @@ -453,119 +445,6 @@ async def test_list_dataset_records_with_response_status_filter( ] ) - @pytest.mark.parametrize( - "sorts", - [ - [("inserted_at", None)], - [("inserted_at", "asc")], - [("inserted_at", "desc")], - [("updated_at", None)], - [("updated_at", "asc")], - [("updated_at", "desc")], - [("metadata.terms-metadata-property", None)], - [("metadata.terms-metadata-property", "asc")], - [("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "desc")], - [("inserted_at", "desc"), ("updated_at", "asc")], - [("inserted_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "asc"), ("updated_at", "desc"), ("metadata.terms-metadata-property", "asc")], - [("inserted_at", "desc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - [("inserted_at", "asc"), ("updated_at", "asc"), ("metadata.terms-metadata-property", "desc")], - ], - ) - async def test_list_dataset_records_with_sort_by( - self, - async_client: "AsyncClient", - mock_search_engine: SearchEngine, - owner: "User", - owner_auth_header: dict, - sorts: List[Tuple[str, Union[str, None]]], - ): - workspace = await WorkspaceFactory.create() - dataset, _, records, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - expected_sorts_by = [] - for field, order in sorts: - if field not in ("inserted_at", "updated_at"): - field = await 
TermsMetadataPropertyFactory.create(name=field.split(".")[-1], dataset=dataset) - expected_sorts_by.append(SortBy(field=field, order=order or "asc")) - - mock_search_engine.search.return_value = SearchResponses( - total=2, - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ], - ) - - query_params = { - "sort_by": [f"{field}:{order}" if order is not None else f"{field}:asc" for field, order in sorts] - } - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params=query_params, - headers=owner_auth_header, - ) - assert response.status_code == 200 - assert response.json()["total"] == 2 - - mock_search_engine.search.assert_called_once_with( - dataset=dataset, - query=None, - offset=0, - limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, - sort_by=expected_sorts_by, - ) - - async def test_list_dataset_records_with_sort_by_with_wrong_sort_order_value( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", params={"sort_by": "inserted_at:wrong"}, headers=owner_auth_header - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort order in 'sort_by' query param 'wrong' for field 'inserted_at' is not valid." - } - - async def test_list_dataset_records_with_sort_by_with_non_existent_metadata_property( - self, async_client: "AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "metadata.i-do-not-exist:asc"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": f"Provided metadata property in 'sort_by' query param 'i-do-not-exist' not found in dataset with '{dataset.id}'." - } - - async def test_list_dataset_records_with_sort_by_with_invalid_field( - self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict - ): - workspace = await WorkspaceFactory.create() - dataset, _, _, _, _ = await self.create_dataset_with_user_responses(owner, workspace) - - response = await async_client.get( - f"/api/v1/datasets/{dataset.id}/records", - params={"sort_by": "not-valid"}, - headers=owner_auth_header, - ) - assert response.status_code == 422 - assert response.json() == { - "detail": "Provided sort field in 'sort_by' query param 'not-valid' is not valid. 
" - "It must be either 'inserted_at', 'updated_at' or `metadata.metadata-property-name`" - } - async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 5ae8241927..c893366b58 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -16,7 +16,14 @@ import pytest import pytest_asyncio -from argilla_server.enums import MetadataPropertyType, QuestionType, ResponseStatusFilter, SimilarityOrder, RecordStatus +from argilla_server.enums import ( + MetadataPropertyType, + QuestionType, + ResponseStatusFilter, + SimilarityOrder, + RecordStatus, + SortOrder, +) from argilla_server.models import Dataset, Question, Record, User, VectorSettings from argilla_server.search_engine import ( ResponseFilterScope, @@ -28,6 +35,8 @@ Filter, MetadataFilterScope, RangeFilter, + Order, + RecordFilterScope, ) from argilla_server.search_engine.commons import ( ALL_RESPONSES_STATUSES_FIELD, @@ -820,12 +829,12 @@ async def test_search_with_pagination( assert all_results.items[offset : offset + limit] == results.items @pytest.mark.parametrize( - ("sort_by"), + ("sort_order"), [ - SortBy(field="inserted_at"), - SortBy(field="updated_at"), - SortBy(field="inserted_at", order="desc"), - SortBy(field="updated_at", order="desc"), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.asc), + Order(scope=RecordFilterScope(property="inserted_at"), order=SortOrder.desc), + Order(scope=RecordFilterScope(property="updated_at"), order=SortOrder.desc), ], ) async def test_search_with_sort_by( @@ -833,18 +842,15 @@ async def test_search_with_sort_by( search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - sort_by: SortBy, + sort_order: Order, ): def _local_sort_by(record: Record) -> Any: - if isinstance(sort_by.field, str): - return getattr(record, sort_by.field) - return record.metadata_[sort_by.field.name] + return getattr(record, sort_order.scope.property) - results = await search_engine.search(test_banking_sentiment_dataset, sort_by=[sort_by]) + results = await search_engine.search(test_banking_sentiment_dataset, sort=[sort_order]) records = test_banking_sentiment_dataset.records - if sort_by: - records = sorted(records, key=_local_sort_by, reverse=sort_by.order == "desc") + records = sorted(records, key=_local_sort_by, reverse=sort_order.order == "desc") assert [item.record_id for item in results.items] == [record.id for record in records] From f241e41acde8110046d8f4667863896ce7d0543e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 9 Jul 2024 11:18:14 +0200 Subject: [PATCH 17/36] fix: wrong filter naming after merge from develop --- argilla/src/argilla/records/_search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/argilla/src/argilla/records/_search.py b/argilla/src/argilla/records/_search.py index dfa2f4c99c..07a972bde9 100644 --- a/argilla/src/argilla/records/_search.py +++ b/argilla/src/argilla/records/_search.py @@ -56,7 +56,7 @@ def _extract_filter_scope(field: str) -> ScopeModel: field = field.strip() if field == "status": return RecordFilterScopeModel(property="status") - elif field == "responses.status": + elif field == "response.status": return 
ResponseFilterScopeModel(property="status") elif "metadata" in field: _, md_property = field.split(".") From bec0b0d0bf2c2d5e2f5f3de3d28df316b497e950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 12 Jul 2024 11:13:36 +0200 Subject: [PATCH 18/36] feat: add session helper with serializable isolation level (#5165) # Description This PR adds a new `get_serializable_async_db` helper function that returns a session using the `SERIALIZABLE` isolation level. This session can be used in handlers where we require that specific isolation level. As an example, I have added this session helper to the handler that deletes responses, and PostgreSQL shows the following received queries: ```sql 2024-07-04 17:09:40.417 CEST [83566] LOG: statement: BEGIN ISOLATION LEVEL READ COMMITTED; 2024-07-04 17:09:40.418 CEST [83566] LOG: execute __asyncpg_stmt_e__: SELECT users.first_name, users.last_name, users.username, users.role, users.api_key, users.password_hash, users.id, users.inserted_at, users.updated_at FROM users WHERE users.api_key = $1::VARCHAR 2024-07-04 17:09:40.418 CEST [83566] DETAIL: parameters: $1 = 'argilla.apikey' 2024-07-04 17:09:40.422 CEST [83566] LOG: execute __asyncpg_stmt_12__: SELECT users_1.id AS users_1_id, workspaces.name AS workspaces_name, workspaces.id AS workspaces_id, workspaces.inserted_at AS workspaces_inserted_at, workspaces.updated_at AS workspaces_updated_at FROM users AS users_1 JOIN workspaces_users AS workspaces_users_1 ON users_1.id = workspaces_users_1.user_id JOIN workspaces ON workspaces.id = workspaces_users_1.workspace_id WHERE users_1.id IN ($1::UUID) ORDER BY workspaces_users_1.inserted_at ASC 2024-07-04 17:09:40.422 CEST [83566] DETAIL: parameters: $1 = 'ed2d570f-cc9f-4d53-a433-74aa7a286a52' 2024-07-04 17:09:40.426 CEST [83566] LOG: execute __asyncpg_stmt_13__: SELECT users.first_name, users.last_name, users.username, users.role, users.api_key, users.password_hash, users.id, users.inserted_at, users.updated_at FROM users WHERE users.username = $1::VARCHAR 2024-07-04 17:09:40.426 CEST [83566] DETAIL: parameters: $1 = 'argilla' 2024-07-04 17:09:40.428 CEST [83566] LOG: execute __asyncpg_stmt_12__: SELECT users_1.id AS users_1_id, workspaces.name AS workspaces_name, workspaces.id AS workspaces_id, workspaces.inserted_at AS workspaces_inserted_at, workspaces.updated_at AS workspaces_updated_at FROM users AS users_1 JOIN workspaces_users AS workspaces_users_1 ON users_1.id = workspaces_users_1.user_id JOIN workspaces ON workspaces.id = workspaces_users_1.workspace_id WHERE users_1.id IN ($1::UUID) ORDER BY workspaces_users_1.inserted_at ASC 2024-07-04 17:09:40.428 CEST [83566] DETAIL: parameters: $1 = 'ed2d570f-cc9f-4d53-a433-74aa7a286a52' 2024-07-04 17:09:40.430 CEST [83563] LOG: statement: BEGIN ISOLATION LEVEL SERIALIZABLE; 2024-07-04 17:09:40.430 CEST [83563] LOG: execute __asyncpg_stmt_14__: SELECT responses.values, responses.status, responses.record_id, responses.user_id, responses.id, responses.inserted_at, responses.updated_at FROM responses WHERE responses.id = $1::UUID 2024-07-04 17:09:40.430 CEST [83563] DETAIL: parameters: $1 = 'fdea95a0-ee9a-43ea-b093-2e13f2473c19' 2024-07-04 17:09:40.431 CEST [83566] LOG: statement: ROLLBACK; 2024-07-04 17:09:40.432 CEST [83563] LOG: statement: ROLLBACK; ``` We can clearly see that there are two nested transactions: 1. The main one to get the current user using the default `get_async_db` helper. 2.
A nested one using `get_serializable_async_db` (and setting `SERIALIZABLE` isolation level) trying to find the response by id. The response id used is fake so the transaction ends there and the deletion is not done. ## Missing things on this PR - [x] Fix some failing tests. - [ ] Tests are passing but still not changing the isolation level to `SERIALIZABLE`. - [ ] Check that this works as expected and does not affect SQLite. - [ ] Check that this works as expected with PostgreSQL (no concurrency errors). Closes #5155 **Type of change** - New feature (non-breaking change which adds functionality) - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] Manually seeing PostgreSQL logs. **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../argilla_server/api/handlers/v1/records.py | 4 +-- .../api/handlers/v1/responses.py | 6 ++--- .../src/argilla_server/contexts/datasets.py | 1 + argilla-server/src/argilla_server/database.py | 26 ++++++++++++++----- .../responses/upsert_responses_in_bulk.py | 6 +++-- argilla-server/tests/unit/conftest.py | 16 +++++++++--- 6 files changed, 42 insertions(+), 17 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/records.py b/argilla-server/src/argilla_server/api/handlers/v1/records.py index 3778921ee2..23398e93be 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/records.py @@ -26,7 +26,7 @@ from argilla_server.api.schemas.v1.suggestions import Suggestion as SuggestionSchema from argilla_server.api.schemas.v1.suggestions import SuggestionCreate, Suggestions from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_async_db, get_serializable_async_db from argilla_server.errors.future.base_errors import NotFoundError, UnprocessableEntityError from argilla_server.models import Dataset, Question, Record, Suggestion, User from argilla_server.search_engine import SearchEngine, get_search_engine @@ -88,7 +88,7 @@ async def update_record( @router.post("/records/{record_id}/responses", status_code=status.HTTP_201_CREATED, response_model=Response) async def create_record_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine: SearchEngine = Depends(get_search_engine), record_id: UUID, response_create: ResponseCreate, diff --git a/argilla-server/src/argilla_server/api/handlers/v1/responses.py b/argilla-server/src/argilla_server/api/handlers/v1/responses.py index ddc389563a..95e468351f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/responses.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/responses.py @@ -28,7 +28,7 @@ ResponseUpdate, ) from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_serializable_async_db from argilla_server.models import Dataset, Record, Response, User from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.security import auth @@ -55,7 +55,7 @@ async def 
create_current_user_responses_bulk( @router.put("/responses/{response_id}", response_model=ResponseSchema) async def update_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine: SearchEngine = Depends(get_search_engine), response_id: UUID, response_update: ResponseUpdate, @@ -77,7 +77,7 @@ async def update_response( @router.delete("/responses/{response_id}", response_model=ResponseSchema) async def delete_response( *, - db: AsyncSession = Depends(get_async_db), + db: AsyncSession = Depends(get_serializable_async_db), search_engine=Depends(get_search_engine), response_id: UUID, current_user: User = Security(auth.get_current_user), diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 700dfeaefa..4d5a5f89fe 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -967,6 +967,7 @@ async def create_response( ) await db.flush([response]) + await _load_users_from_responses([response]) await _touch_dataset_last_activity_at(db, record.dataset) await search_engine.update_record_response(response) await db.refresh(record, attribute_names=[Record.responses_submitted.key]) diff --git a/argilla-server/src/argilla_server/database.py b/argilla-server/src/argilla_server/database.py index e0bc4c4c95..eaaf27079b 100644 --- a/argilla-server/src/argilla_server/database.py +++ b/argilla-server/src/argilla_server/database.py @@ -14,19 +14,17 @@ import os from collections import OrderedDict from sqlite3 import Connection as SQLite3Connection -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, AsyncGenerator, Optional from sqlalchemy import event, make_url from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.engine.interfaces import IsolationLevel +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine, AsyncSession from sqlalchemy.dialects.sqlite.aiosqlite import AsyncAdapt_aiosqlite_connection import argilla_server from argilla_server.settings import settings -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncSession - ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla_server.__file__), "alembic.ini")) TAGGED_REVISIONS = OrderedDict( @@ -55,9 +53,23 @@ def set_sqlite_pragma(dbapi_connection, connection_record): AsyncSessionLocal = async_sessionmaker(autocommit=False, expire_on_commit=False, bind=async_engine) -async def get_async_db() -> Generator["AsyncSession", None, None]: +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + async for db in _get_async_db(): + yield db + + +async def get_serializable_async_db() -> AsyncGenerator[AsyncSession, None]: + async for db in _get_async_db(isolation_level="SERIALIZABLE"): + yield db + + +async def _get_async_db(isolation_level: Optional[IsolationLevel] = None) -> AsyncGenerator[AsyncSession, None]: + db: AsyncSession = AsyncSessionLocal() + + if isolation_level is not None: + await db.connection(execution_options={"isolation_level": isolation_level}) + try: - db: "AsyncSession" = AsyncSessionLocal() yield db finally: await db.close() diff --git a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py index 547dd7e68b..520194e46a 100644 --- 
a/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py +++ b/argilla-server/src/argilla_server/use_cases/responses/upsert_responses_in_bulk.py @@ -20,7 +20,7 @@ from argilla_server.api.policies.v1 import RecordPolicy, authorize from argilla_server.api.schemas.v1.responses import Response, ResponseBulk, ResponseBulkError, ResponseUpsert from argilla_server.contexts import datasets -from argilla_server.database import get_async_db +from argilla_server.database import get_serializable_async_db from argilla_server.errors import future as errors from argilla_server.models import User from argilla_server.search_engine import SearchEngine, get_search_engine @@ -55,6 +55,8 @@ async def execute(self, responses: List[ResponseUpsert], user: User) -> List[Res class UpsertResponsesInBulkUseCaseFactory: def __call__( - self, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine) + self, + db: AsyncSession = Depends(get_serializable_async_db), + search_engine: SearchEngine = Depends(get_search_engine), ): return UpsertResponsesInBulkUseCase(db, search_engine) diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index fe3479ea6d..a1dac6fbc5 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -14,14 +14,15 @@ import contextlib import uuid -from typing import TYPE_CHECKING, Dict, Generator +from typing import TYPE_CHECKING, Dict, Generator, Optional import pytest import pytest_asyncio +from sqlalchemy.engine.interfaces import IsolationLevel from argilla_server import telemetry from argilla_server.api.routes import api_v1 from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY -from argilla_server.database import get_async_db +from argilla_server.database import get_async_db, get_serializable_async_db from argilla_server.models import User, UserRole, Workspace from argilla_server.search_engine import SearchEngine, get_search_engine from argilla_server.settings import settings @@ -78,10 +79,18 @@ async def async_client( ) -> Generator["AsyncClient", None, None]: from argilla_server import app - async def override_get_async_db(): + async def override_get_async_db(isolation_level: Optional[IsolationLevel] = None): session = TestSession() + + if isolation_level is not None: + await session.connection(execution_options={"isolation_level": isolation_level}) + yield session + async def override_get_serializable_async_db(): + async for session in override_get_async_db(isolation_level="SERIALIZABLE"): + yield session + async def override_get_search_engine(): yield mock_search_engine @@ -89,6 +98,7 @@ async def override_get_search_engine(): for api in [api_v1]: api.dependency_overrides[get_async_db] = override_get_async_db + api.dependency_overrides[get_serializable_async_db] = override_get_serializable_async_db api.dependency_overrides[get_search_engine] = override_get_search_engine async with AsyncClient(app=app, base_url="http://testserver") as async_client: From 85e847f68b1a5612d283ba39375d0631474ae639 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 12 Jul 2024 11:46:25 +0200 Subject: [PATCH 19/36] [REFACTOR] `argilla-server`: remove deprecated records endpoint (#5206) # Description This PR removes deprecated endpoints working with records to avoid creating records without a proper record status computation.
The affected endpoints are: `POST /api/v1/datasets/:dataset_id/records` `PATCH /api/v1/datasets/:dataset_id/records` **Type of change** - Bug fix (non-breaking change which fixes an issue) - New feature (non-breaking change which adds functionality) - Breaking change (fix or feature that would cause existing functionality to not work as expected) - Refactor (change restructuring the codebase without changing functionality) - Improvement (change adding some improvement to an existing functionality) - Documentation update **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 9 +- .../api/handlers/v1/datasets/records.py | 67 ----- .../src/argilla_server/contexts/datasets.py | 155 +---------- .../test_create_dataset_records_in_bulk.py} | 6 +- .../test_update_dataset_records_in_bulk.py} | 8 +- .../unit/api/handlers/v1/test_datasets.py | 262 ++++++------------ 6 files changed, 111 insertions(+), 396 deletions(-) rename argilla-server/tests/unit/api/handlers/v1/datasets/records/{test_create_dataset_records.py => records_bulk/test_create_dataset_records_in_bulk.py} (98%) rename argilla-server/tests/unit/api/handlers/v1/datasets/records/{test_update_dataset_records.py => records_bulk/test_update_dataset_records_in_bulk.py} (97%) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 0338a4e503..64651f3ad3 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -24,13 +24,18 @@ These are the section headers that we use: ### Changed - Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) -- [breaking] Change `GET /datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) -- [breaking] Change `GET /me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) +- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) ### Fixed - Fixed SQLite connection settings not working correctly due to a outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149)) +### Removed + +- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) +- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. 
([#5206](https://github.com/argilla-io/argilla/pull/5206)) + ## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) ### Changed diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index e032aa7037..8cc5ee2538 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -34,8 +34,6 @@ RecordFilterScope, RecordIncludeParam, Records, - RecordsCreate, - RecordsUpdate, SearchRecord, SearchRecordsQuery, SearchRecordsResult, @@ -424,71 +422,6 @@ async def list_dataset_records( return Records(items=records, total=total) -@router.post( - "/datasets/{dataset_id}/records", - status_code=status.HTTP_204_NO_CONTENT, - deprecated=True, - description="Deprecated in favor of POST /datasets/{dataset_id}/records/bulk", -) -async def create_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), - dataset_id: UUID, - records_create: RecordsCreate, - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise( - db, - dataset_id, - options=[ - selectinload(Dataset.fields), - selectinload(Dataset.questions), - selectinload(Dataset.metadata_properties), - selectinload(Dataset.vectors_settings), - ], - ) - - await authorize(current_user, DatasetPolicy.create_records(dataset)) - - await datasets.create_records(db, search_engine, dataset, records_create) - - telemetry_client.track_data(action="DatasetRecordsCreated", data={"records": len(records_create.items)}) - - -@router.patch( - "/datasets/{dataset_id}/records", - status_code=status.HTTP_204_NO_CONTENT, - deprecated=True, - description="Deprecated in favor of PUT /datasets/{dataset_id}/records/bulk", -) -async def update_dataset_records( - *, - db: AsyncSession = Depends(get_async_db), - search_engine: SearchEngine = Depends(get_search_engine), - telemetry_client: TelemetryClient = Depends(get_telemetry_client), - dataset_id: UUID, - records_update: RecordsUpdate, - current_user: User = Security(auth.get_current_user), -): - dataset = await Dataset.get_or_raise( - db, - dataset_id, - options=[ - selectinload(Dataset.fields), - selectinload(Dataset.questions), - selectinload(Dataset.metadata_properties), - ], - ) - - await authorize(current_user, DatasetPolicy.update_records(dataset)) - - await datasets.update_records(db, search_engine, dataset, records_update) - - telemetry_client.track_data(action="DatasetRecordsUpdated", data={"records": len(records_update.items)}) - - @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) async def delete_dataset_records( *, diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 4d5a5f89fe..b95fcde5e1 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -33,7 +33,7 @@ import sqlalchemy from fastapi.encoders import jsonable_encoder -from sqlalchemy import Select, and_, case, func, select +from sqlalchemy import Select, and_, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload, selectinload @@ -42,8 +42,6 @@ from argilla_server.api.schemas.v1.records import ( RecordCreate, 
RecordIncludeParam, - RecordsCreate, - RecordsUpdate, RecordUpdateWithId, ) from argilla_server.api.schemas.v1.responses import ( @@ -60,7 +58,7 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.contexts import accounts, distribution -from argilla_server.enums import DatasetStatus, RecordInclude, UserRole, RecordStatus +from argilla_server.enums import DatasetStatus, UserRole, RecordStatus from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( Dataset, @@ -74,7 +72,6 @@ User, Vector, VectorSettings, - Workspace, ) from argilla_server.models.suggestions import SuggestionCreateWithRecordId from argilla_server.search_engine import SearchEngine @@ -87,9 +84,6 @@ from argilla_server.validators.suggestions import SuggestionCreateValidator if TYPE_CHECKING: - from argilla_server.api.schemas.v1.datasets import ( - DatasetUpdate, - ) from argilla_server.api.schemas.v1.fields import FieldUpdate from argilla_server.api.schemas.v1.records import RecordUpdate from argilla_server.api.schemas.v1.suggestions import SuggestionCreate @@ -231,7 +225,8 @@ async def create_metadata_property( ) -> MetadataProperty: if await MetadataProperty.get_by(db, name=metadata_property_create.name, dataset_id=dataset.id): raise NotUniqueError( - f"Metadata property with name `{metadata_property_create.name}` already exists for dataset with id `{dataset.id}`" + f"Metadata property with name `{metadata_property_create.name}` already exists " + f"for dataset with id `{dataset.id}`" ) async with db.begin_nested(): @@ -292,7 +287,8 @@ async def create_vector_settings( if await VectorSettings.get_by(db, name=vector_settings_create.name, dataset_id=dataset.id): raise NotUniqueError( - f"Vector settings with name `{vector_settings_create.name}` already exists for dataset with id `{dataset.id}`" + f"Vector settings with name `{vector_settings_create.name}` already exists " + f"for dataset with id `{dataset.id}`" ) async with db.begin_nested(): @@ -403,7 +399,7 @@ async def get_user_dataset_metrics(db: AsyncSession, user_id: UUID, dataset_id: .filter( Record.dataset_id == dataset_id, Record.status == RecordStatus.pending, - Response.id == None, + Response.id == None, # noqa ), ), ) @@ -549,57 +545,6 @@ async def _build_record( ) -async def create_records( - db: AsyncSession, search_engine: SearchEngine, dataset: Dataset, records_create: RecordsCreate -): - if not dataset.is_ready: - raise UnprocessableEntityError("Records cannot be created for a non published dataset") - - records = [] - - caches = { - "users_ids_cache": set(), - "questions_cache": {}, - "metadata_properties_cache": {}, - "vectors_settings_cache": {}, - } - - for record_i, record_create in enumerate(records_create.items): - try: - record = await _build_record(db, dataset, record_create, caches) - - record.responses = await _build_record_responses( - db, record, record_create.responses, caches["users_ids_cache"] - ) - - record.suggestions = await _build_record_suggestions( - db, record, record_create.suggestions, caches["questions_cache"] - ) - - record.vectors = await _build_record_vectors( - db, - dataset, - record_create.vectors, - build_vector_func=lambda value, vector_settings_id: Vector( - value=value, vector_settings_id=vector_settings_id - ), - cache=caches["vectors_settings_cache"], - ) - - except (UnprocessableEntityError, ValueError) as e: - raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e - - 
records.append(record) - - async with db.begin_nested(): - db.add_all(records) - await db.flush(records) - await _preload_records_relationships_before_index(db, records) - await search_engine.index_records(dataset, records) - - await db.commit() - - async def _load_users_from_responses(responses: Union[Response, Iterable[Response]]) -> None: if isinstance(responses, Response): responses = [responses] @@ -808,92 +753,6 @@ async def preload_records_relationships_before_validate(db: AsyncSession, record ) -async def update_records( - db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_update: "RecordsUpdate" -) -> None: - records_ids = [record_update.id for record_update in records_update.items] - - if len(records_ids) != len(set(records_ids)): - raise UnprocessableEntityError("Found duplicate records IDs") - - existing_records_ids = await _exists_records_with_ids(db, dataset_id=dataset.id, records_ids=records_ids) - non_existing_records_ids = set(records_ids) - set(existing_records_ids) - - if len(non_existing_records_ids) > 0: - sorted_non_existing_records_ids = sorted(non_existing_records_ids, key=lambda x: records_ids.index(x)) - records_str = ", ".join([str(record_id) for record_id in sorted_non_existing_records_ids]) - raise UnprocessableEntityError(f"Found records that do not exist: {records_str}") - - # Lists to store the records that will be updated in the database or in the search engine - records_update_objects: List[Dict[str, Any]] = [] - records_search_engine_update: List[UUID] = [] - records_delete_suggestions: List[UUID] = [] - - # Cache dictionaries to avoid querying the database multiple times - caches = { - "metadata_properties": {}, - "questions": {}, - "vector_settings": {}, - } - - existing_records = await get_records_by_ids(db, records_ids=records_ids, dataset_id=dataset.id) - - suggestions = [] - upsert_vectors = [] - for record_i, (record_update, record) in enumerate(zip(records_update.items, existing_records)): - try: - params, record_suggestions, record_vectors, needs_search_engine_update, caches = await _build_record_update( - db, record, record_update, caches - ) - - if record_suggestions is not None: - suggestions.extend(record_suggestions) - records_delete_suggestions.append(record_update.id) - - upsert_vectors.extend(record_vectors) - - if needs_search_engine_update: - records_search_engine_update.append(record_update.id) - - # Only update the record if there are params to update - if len(params) > 1: - records_update_objects.append(params) - except (UnprocessableEntityError, ValueError) as e: - raise UnprocessableEntityError(f"Record at position {record_i} is not valid because {e}") from e - - async with db.begin_nested(): - if records_delete_suggestions: - params = [Suggestion.record_id.in_(records_delete_suggestions)] - await Suggestion.delete_many(db, params=params, autocommit=False) - - if suggestions: - db.add_all(suggestions) - - if upsert_vectors: - await Vector.upsert_many( - db, - objects=upsert_vectors, - constraints=[Vector.record_id, Vector.vector_settings_id], - autocommit=False, - ) - - if records_update_objects: - await Record.update_many(db, records_update_objects, autocommit=False) - - if records_search_engine_update: - records = await get_records_by_ids( - db, - dataset_id=dataset.id, - records_ids=records_search_engine_update, - include=RecordIncludeParam(keys=[RecordInclude.vectors], vectors=None), - ) - await dataset.awaitable_attrs.vectors_settings - await _preload_records_relationships_before_index(db, records) - 
await search_engine.index_records(dataset, records) - - await db.commit() - - async def delete_records( db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_ids: List[UUID] ) -> None: diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py similarity index 98% rename from argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py index c13bc9b6cb..7110e9ce62 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_create_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_create_dataset_records_in_bulk.py @@ -34,9 +34,9 @@ @pytest.mark.asyncio -class TestCreateDatasetRecords: +class TestCreateDatasetRecordsInBulk: def url(self, dataset_id: UUID) -> str: - return f"/api/v1/datasets/{dataset_id}/records" + return f"/api/v1/datasets/{dataset_id}/records/bulk" async def test_create_dataset_records( self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict @@ -209,7 +209,7 @@ async def test_create_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Response.id)))).scalar_one() == 1 diff --git a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py similarity index 97% rename from argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py rename to argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py index cf9fa909e9..ffb7a24dc7 100644 --- a/argilla-server/tests/unit/api/handlers/v1/datasets/records/test_update_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/datasets/records/records_bulk/test_update_dataset_records_in_bulk.py @@ -35,9 +35,9 @@ @pytest.mark.asyncio -class TestUpdateDatasetRecords: +class TestUpdateDatasetRecordsInBulk: def url(self, dataset_id: UUID) -> str: - return f"/api/v1/datasets/{dataset_id}/records" + return f"/api/v1/datasets/{dataset_id}/records/bulk" async def test_update_dataset_records( self, async_client: AsyncClient, db: AsyncSession, owner: User, owner_auth_header: dict @@ -121,7 +121,7 @@ async def test_update_dataset_records( dataset=dataset, ) - response = await async_client.patch( + response = await async_client.put( self.url(dataset.id), headers=owner_auth_header, json={ @@ -180,7 +180,7 @@ async def test_update_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 200 assert (await db.execute(select(func.count(Record.id)))).scalar_one() == 1 assert (await db.execute(select(func.count(Suggestion.id)))).scalar_one() == 6 diff --git a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py index 9404b3850e..557cb4de70 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_datasets.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_datasets.py @@ -1807,10 +1807,10 @@ async def test_create_dataset_records( } response 
= await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 5 assert (await db.execute(select(func.count(Response.id)))).scalar() == 4 assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 3 @@ -1872,13 +1872,13 @@ async def test_create_dataset_records_with_response_for_multiple_users( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) await db.refresh(annotator) await db.refresh(owner) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 2 assert (await db.execute(select(func.count(Response.id)).where(Response.user_id == annotator.id))).scalar() == 2 assert (await db.execute(select(func.count(Response.id)).where(Response.user_id == owner.id))).scalar() == 1 @@ -1912,7 +1912,7 @@ async def test_create_dataset_records_with_response_for_unknown_user( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422, response.json() @@ -1950,7 +1950,7 @@ async def test_create_dataset_records_with_duplicated_response_for_an_user( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422, response.json() @@ -1986,7 +1986,7 @@ async def test_create_dataset_records_with_not_valid_suggestion( question = await TextFieldFactory.create(name="input", dataset=dataset) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"question_id": str(question.id), **payload}, ) @@ -2020,13 +2020,10 @@ async def test_create_dataset_records_with_missing_required_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because missing required value for field: 'output'" - } assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 async def test_create_dataset_records_with_wrong_value_field( @@ -2054,7 +2051,7 @@ async def test_create_dataset_records_with_wrong_value_field( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2092,13 +2089,10 @@ async def test_create_dataset_records_with_extra_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", 
headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because found fields values for non configured fields: ['output']" - } assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 @pytest.mark.parametrize( @@ -2120,10 +2114,10 @@ async def test_create_dataset_records_with_optional_fields( records_json = {"items": [record_json]} response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() await db.refresh(dataset, attribute_names=["records"]) assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 @@ -2145,7 +2139,7 @@ async def test_create_dataset_records_with_wrong_optional_fields( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 assert response.json() == { @@ -2203,10 +2197,10 @@ async def test_create_dataset_records_with_metadata_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 record = (await db.execute(select(Record))).scalar() assert record.metadata_ == {"metadata-property": value} @@ -2242,7 +2236,7 @@ async def test_create_dataset_records_with_metadata_nan_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2280,14 +2274,10 @@ async def test_create_dataset_records_with_not_valid_metadata_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert ( - "Record at position 0 is not valid because metadata is not valid: 'metadata-property' metadata property validation failed" - in response.json()["detail"] - ) async def test_create_dataset_records_with_extra_metadata_allowed( self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict @@ -2307,10 +2297,10 @@ async def test_create_dataset_records_with_extra_metadata_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 record = (await db.execute(select(Record))).scalar() assert record.metadata_ == {"terms-metadata": "a", "extra": {"this": {"is": "extra metadata"}}} @@ -2332,15 +2322,10 @@ async def test_create_dataset_records_with_extra_metadata_not_allowed( } response = await async_client.post( - 
f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert ( - "Record at position 0 is not valid because metadata is not valid: 'not-defined-metadata-property' metadata" - f" property does not exists for dataset '{dataset.id}' and extra metadata is not allowed for this dataset" - == response.json()["detail"] - ) @pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin]) async def test_create_dataset_records_with_vectors( @@ -2356,7 +2341,7 @@ async def test_create_dataset_records_with_vectors( vector_settings_b = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -2376,7 +2361,7 @@ async def test_create_dataset_records_with_vectors( }, ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Vector.id)))).scalar() == 3 vector_a, vector_b, vector_c = (await db.execute(select(Vector))).scalars().all() @@ -2407,7 +2392,7 @@ async def test_create_dataset_records_with_invalid_vector( vector_settings = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2420,10 +2405,6 @@ async def test_create_dataset_records_with_invalid_vector( ) assert response.status_code == 422 - assert response.json()["detail"] == ( - f"Record at position 0 is not valid because vector with name={vector_settings.name} is not valid: " - f"vector must have {vector_settings.dimensions} elements, got 1 elements" - ) async def test_create_dataset_records_with_non_existent_vector_settings( self, async_client: "AsyncClient", owner_auth_header: dict @@ -2433,7 +2414,7 @@ async def test_create_dataset_records_with_non_existent_vector_settings( await TextQuestionFactory.create(name="text_ok", dataset=dataset) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2446,10 +2427,6 @@ async def test_create_dataset_records_with_non_existent_vector_settings( ) assert response.status_code == 422 - assert response.json()["detail"] == ( - "Record at position 0 is not valid because vector with name=missing_vector is not valid: " - f"vector with name=missing_vector does not exist for dataset_id={str(dataset.id)}" - ) async def test_create_dataset_records_with_vector_settings_id_from_another_dataset( self, async_client: "AsyncClient", owner_auth_header: dict @@ -2462,7 +2439,7 @@ async def test_create_dataset_records_with_vector_settings_id_from_another_datas vector_settings = await VectorSettingsFactory.create(dimensions=5) response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -2475,10 +2452,6 @@ async def test_create_dataset_records_with_vector_settings_id_from_another_datas ) assert response.status_code == 422 - assert response.json()["detail"] == ( - f"Record at position 0 is not valid because vector with name={vector_settings.name} is not valid: " - f"vector with 
name={vector_settings.name} does not exist for dataset_id={dataset.id}" - ) async def test_create_dataset_records_with_index_error( self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", owner_auth_header: dict @@ -2494,7 +2467,7 @@ async def test_create_dataset_records_with_index_error( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2517,7 +2490,7 @@ async def test_create_dataset_records_without_authentication(self, async_client: ], } - response = await async_client.post(f"/api/v1/datasets/{dataset.id}/records", json=records_json) + response = await async_client.post(f"/api/v1/datasets/{dataset.id}/records/bulk", json=records_json) assert response.status_code == 401 assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 @@ -2589,10 +2562,12 @@ async def test_create_dataset_records_as_admin( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key}, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers={API_KEY_HEADER_NAME: admin.api_key}, + json=records_json, ) - assert response.status_code == 204, response.json() + assert response.status_code == 201, response.json() assert (await db.execute(select(func.count(Record.id)))).scalar() == 5 assert (await db.execute(select(func.count(Response.id)))).scalar() == 4 @@ -2623,7 +2598,7 @@ async def test_create_dataset_records_as_annotator(self, async_client: "AsyncCli } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=records_json, ) @@ -2652,7 +2627,9 @@ async def test_create_dataset_records_as_admin_from_another_workspace(self, asyn } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers={API_KEY_HEADER_NAME: admin.api_key}, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", + headers={API_KEY_HEADER_NAME: admin.api_key}, + json=records_json, ) assert response.status_code == 403 @@ -2683,10 +2660,10 @@ async def test_create_dataset_records_with_submitted_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert (await db.execute(select(func.count(Response.id)))).scalar() == 1 @@ -2714,7 +2691,7 @@ async def test_create_dataset_records_with_submitted_response_without_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2751,10 +2728,10 @@ async def test_create_dataset_records_with_discarded_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert 
response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert ( await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.discarded)) @@ -2790,10 +2767,10 @@ async def test_create_dataset_records_with_draft_response( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 assert ( await db.execute(select(func.count(Response.id)).filter(Response.status == ResponseStatus.draft)) @@ -2823,7 +2800,7 @@ async def test_create_dataset_records_with_invalid_response_status( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2859,10 +2836,10 @@ async def test_create_dataset_records_with_discarded_response_without_values( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) - assert response.status_code == 204 + assert response.status_code == 201 assert (await db.execute(select(func.count(Response.id)))).scalar() == 1 assert (await db.execute(select(func.count(Record.id)))).scalar() == 1 @@ -2877,11 +2854,10 @@ async def test_create_dataset_records_with_non_published_dataset( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 - assert response.json() == {"detail": "Records cannot be created for a non published dataset"} assert (await db.execute(select(func.count(Record.id)))).scalar() == 0 assert (await db.execute(select(func.count(Response.id)))).scalar() == 0 @@ -2900,7 +2876,7 @@ async def test_create_dataset_records_with_less_items_than_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2922,7 +2898,7 @@ async def test_create_dataset_records_with_more_items_than_allowed( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2942,7 +2918,7 @@ async def test_create_dataset_records_with_invalid_records( } response = await async_client.post( - f"/api/v1/datasets/{dataset.id}/records", headers=owner_auth_header, json=records_json + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json=records_json ) assert response.status_code == 422 @@ -2957,7 +2933,7 @@ async def test_create_dataset_records_with_nonexistent_dataset_id( await DatasetFactory.create() response = await async_client.post( - f"/api/v1/datasets/{dataset_id}/records", + f"/api/v1/datasets/{dataset_id}/records/bulk", 
headers=owner_auth_header, json={ "items": [ @@ -2977,7 +2953,7 @@ async def test_create_dataset_records_with_nonexistent_dataset_id( async def test_update_dataset_records( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", role: UserRole ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) user = await UserFactory.create(workspaces=[dataset.workspace], role=role) await TermsMetadataPropertyFactory.create(name="terms-metadata-property", dataset=dataset) await IntegerMetadataPropertyFactory.create(name="integer-metadata-property", dataset=dataset) @@ -2988,8 +2964,8 @@ async def test_update_dataset_records( metadata_={"terms-metadata-property": "z", "integer-metadata-property": 1, "float-metadata-property": 1.0}, ) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3027,7 +3003,7 @@ async def test_update_dataset_records( }, ) - assert response.status_code == 204 + assert response.status_code == 200, response.json() # Record 0 assert records[0].metadata_ == { @@ -3060,13 +3036,12 @@ async def test_update_dataset_records( "float-metadata-property": 1.0, } - # it should be called only with the first three records (metadata was updated for them) - mock_search_engine.index_records.assert_called_once_with(dataset, records[:3]) + mock_search_engine.index_records.assert_called_once_with(dataset, records[:4]) async def test_update_dataset_records_with_suggestions( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", owner_auth_header: dict ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) question_0 = await TextQuestionFactory.create(dataset=dataset) question_1 = await TextQuestionFactory.create(dataset=dataset) question_2 = await TextQuestionFactory.create(dataset=dataset) @@ -3093,8 +3068,8 @@ async def test_update_dataset_records_with_suggestions( await SuggestionFactory.create(question=question_2, record=records[2], value="suggestion 2 3"), ] - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3142,21 +3117,17 @@ async def test_update_dataset_records_with_suggestions( }, ) - assert response.status_code == 204 + assert response.status_code == 200 # Record 0 await records[0].awaitable_attrs.suggestions assert records[0].suggestions[0].value == "suggestion updated 0 1" assert records[0].suggestions[1].value == "suggestion updated 0 2" assert records[0].suggestions[2].value == "suggestion updated 0 3" - for suggestion in suggestions_records_0: - assert inspect(suggestion).deleted # Record 1 await records[1].awaitable_attrs.suggestions assert records[1].suggestions[0].value == "suggestion updated 1 1" - for suggestion in suggestions_records_1: - assert inspect(suggestion).deleted # Record 2 for suggestion in suggestions_records_2: @@ -3168,39 +3139,12 @@ async def test_update_dataset_records_with_suggestions( assert records[3].suggestions[1].value == "suggestion updated 3 2" assert records[3].suggestions[2].value == "suggestion updated 3 3" - mock_search_engine.index_records.assert_not_called() - - async def test_update_dataset_records_with_empty_list_of_suggestions( - self, async_client: 
"AsyncClient", owner_auth_header: dict - ): - dataset = await DatasetFactory.create() - question_0 = await TextQuestionFactory.create(dataset=dataset) - question_1 = await TextQuestionFactory.create(dataset=dataset) - question_2 = await TextQuestionFactory.create(dataset=dataset) - record = await RecordFactory.create(dataset=dataset) - - suggestions_records_0 = [ - await SuggestionFactory.create(question=question_0, record=record, value="suggestion 0 1"), - await SuggestionFactory.create(question=question_1, record=record, value="suggestion 0 2"), - await SuggestionFactory.create(question=question_2, record=record, value="suggestion 0 3"), - ] - - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", - headers=owner_auth_header, - json={"items": [{"id": str(record.id), "suggestions": []}]}, - ) - - assert response.status_code == 204 - - assert await record.awaitable_attrs.suggestions == [] - for suggestion in suggestions_records_0: - assert inspect(suggestion).deleted + mock_search_engine.index_records.assert_called_once() async def test_update_dataset_records_with_vectors( self, async_client: "AsyncClient", mock_search_engine: "SearchEngine", owner_auth_header: dict ): - dataset = await DatasetFactory.create() + dataset = await DatasetFactory.create(status=DatasetStatus.ready) vector_settings_0 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) vector_settings_1 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) vector_settings_2 = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) @@ -3216,8 +3160,8 @@ async def test_update_dataset_records_with_vectors( await VectorFactory.create(vector_settings=vector_settings_1, record=records[1], value=[4, 4, 4, 4, 4]) await VectorFactory.create(vector_settings=vector_settings_2, record=records[1], value=[5, 5, 5, 5, 5]) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3247,7 +3191,7 @@ async def test_update_dataset_records_with_vectors( }, ) - assert response.status_code == 204 + assert response.status_code == 200 # Record 0 await records[0].awaitable_attrs.vectors @@ -3276,8 +3220,8 @@ async def test_update_dataset_records_with_invalid_metadata( await TermsMetadataPropertyFactory.create(dataset=dataset, name="terms") records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3298,10 +3242,6 @@ async def test_update_dataset_records_with_invalid_metadata( ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 1 is not valid because metadata is not valid: 'terms' metadata property " - "validation failed because 'i was not declared' is not an allowed term." 
- } async def test_update_dataset_records_with_metadata_nan_value( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3311,8 +3251,8 @@ async def test_update_dataset_records_with_metadata_nan_value( await FloatMetadataPropertyFactory.create(dataset=dataset, name="float") records = await RecordFactory.create_batch(3, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3341,8 +3281,8 @@ async def test_update_dataset_records_with_invalid_suggestions( question = await LabelSelectionQuestionFactory.create(dataset=dataset) records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3356,9 +3296,6 @@ async def test_update_dataset_records_with_invalid_suggestions( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because suggestion for question_id={question.id} is not valid: 'option-a' is not a valid label for label selection question.\nValid labels are: ['option1', 'option2', 'option3']" - } async def test_update_dataset_records_with_invalid_vectors( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3367,8 +3304,8 @@ async def test_update_dataset_records_with_invalid_vectors( vector_settings = await VectorSettingsFactory.create(dataset=dataset, dimensions=5) records = await RecordFactory.create_batch(5, dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3378,18 +3315,14 @@ async def test_update_dataset_records_with_invalid_vectors( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because vector with name={vector_settings.name} is not " - "valid: vector must have 5 elements, got 6 elements" - } async def test_update_dataset_records_with_nonexistent_dataset_id( self, async_client: "AsyncClient", owner_auth_header: dict ): dataset_id = uuid4() - response = await async_client.patch( - f"/api/v1/datasets/{dataset_id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset_id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3413,16 +3346,13 @@ async def test_update_dataset_records_with_nonexistent_records( records.append({"id": str(record.id), "metadata": {"i exists": True}}) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"items": records}, ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Found records that do not exist: {records[0]['id']}, {records[1]['id']}, {records[2]['id']}" - } async def test_update_dataset_records_with_nonexistent_question_id( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3432,8 +3362,8 @@ async def test_update_dataset_records_with_nonexistent_question_id( question_id = str(uuid4()) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + 
f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3443,10 +3373,6 @@ async def test_update_dataset_records_with_nonexistent_question_id( ) assert response.status_code == 422 - assert response.json() == { - "detail": f"Record at position 0 is not valid because suggestion for question_id={question_id} is not " - f"valid: question_id={question_id} does not exist" - } async def test_update_dataset_records_with_nonexistent_vector_settings_name( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3454,17 +3380,13 @@ async def test_update_dataset_records_with_nonexistent_vector_settings_name( dataset = await DatasetFactory.create() record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={"items": [{"id": str(record.id), "vectors": {"i-do-not-exist": [1, 2, 3, 4]}}]}, ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because vector with name=i-do-not-exist is not valid: vector " - f"with name=i-do-not-exist does not exist for dataset_id={dataset.id}" - } async def test_update_dataset_records_with_duplicate_records_ids( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3472,8 +3394,8 @@ async def test_update_dataset_records_with_duplicate_records_ids( dataset = await DatasetFactory.create() record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3484,7 +3406,6 @@ async def test_update_dataset_records_with_duplicate_records_ids( ) assert response.status_code == 422 - assert response.json() == {"detail": "Found duplicate records IDs"} async def test_update_dataset_records_with_duplicate_suggestions_question_ids( self, async_client: "AsyncClient", owner_auth_header: dict @@ -3493,8 +3414,8 @@ async def test_update_dataset_records_with_duplicate_suggestions_question_ids( question = await TextQuestionFactory.create(dataset=dataset) record = await RecordFactory.create(dataset=dataset) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers=owner_auth_header, json={ "items": [ @@ -3510,16 +3431,13 @@ async def test_update_dataset_records_with_duplicate_suggestions_question_ids( ) assert response.status_code == 422 - assert response.json() == { - "detail": "Record at position 0 is not valid because found duplicate suggestions question IDs" - } async def test_update_dataset_records_as_admin_from_another_workspace(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() user = await UserFactory.create(role=UserRole.admin) - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3536,8 +3454,8 @@ async def test_update_dataset_records_as_annotator(self, async_client: "AsyncCli dataset = await DatasetFactory.create() user = await UserFactory.create(role=UserRole.annotator, workspaces=[dataset.workspace]) - response = await async_client.patch( - 
f"/api/v1/datasets/{dataset.id}/records", + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", headers={API_KEY_HEADER_NAME: user.api_key}, json={ "items": [ @@ -3553,8 +3471,8 @@ async def test_update_dataset_records_as_annotator(self, async_client: "AsyncCli async def test_update_dataset_records_without_authentication(self, async_client: "AsyncClient"): dataset = await DatasetFactory.create() - response = await async_client.patch( - f"/api/v1/datasets/{dataset.id}/records", json={"items": [{"id": str(uuid4())}]} + response = await async_client.put( + f"/api/v1/datasets/{dataset.id}/records/bulk", json={"items": [{"id": str(uuid4())}]} ) assert response.status_code == 401 From c219764e41600c1b9a14f80a1a27aa0e43e129cc Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 12 Jul 2024 16:56:26 +0200 Subject: [PATCH 20/36] [ENHANCEMENT] `argilla`: add record `status` property (#5184) # Description This PR adds the record status as a read-only property in the `Record` resource class. Closes https://github.com/argilla-io/argilla/issues/5141 **Type of change** - Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm My changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- .../src/argilla/_models/_record/_record.py | 4 ++-- argilla/src/argilla/records/_resource.py | 22 ++++++++++++++----- .../tests/integration/test_list_records.py | 13 +++++++++++ argilla/tests/unit/test_io/test_generic.py | 1 + .../tests/unit/test_io/test_hf_datasets.py | 1 + .../tests/unit/test_resources/test_records.py | 10 +++++++++ 6 files changed, 44 insertions(+), 7 deletions(-) diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py index 38a4996c96..09f2e42272 100644 --- a/argilla/src/argilla/_models/_record/_record.py +++ b/argilla/src/argilla/_models/_record/_record.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Literal from pydantic import Field, field_serializer, field_validator @@ -30,12 +30,12 @@ class RecordModel(ResourceModel): """Schema for the records of a `Dataset`""" + status: Literal["pending", "completed"] = "pending" fields: Optional[Dict[str, FieldValue]] = None metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict) vectors: Optional[List[VectorModel]] = Field(default_factory=list) responses: Optional[List[UserResponseModel]] = Field(default_factory=list) suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple) - external_id: Optional[Any] = None @field_serializer("external_id", when_used="unless-none") diff --git a/argilla/src/argilla/records/_resource.py b/argilla/src/argilla/records/_resource.py index 53c1321b4b..27ec8be113 100644 --- a/argilla/src/argilla/records/_resource.py +++ b/argilla/src/argilla/records/_resource.py @@ -103,7 +103,7 @@ def __init__( def __repr__(self) -> str: return ( - f"Record(id={self.id},fields={self.fields},metadata={self.metadata}," + f"Record(id={self.id},status={self.status},fields={self.fields},metadata={self.metadata}," f"suggestions={self.suggestions},responses={self.responses})" ) @@ -147,6 +147,10 @@ def metadata(self) -> "RecordMetadata": def vectors(self) -> "RecordVectors": return self.__vectors + @property + def status(self) -> str: + return self._model.status + @property def _server_id(self) -> Optional[UUID]: return self._model.id @@ -164,6 +168,7 @@ def api_model(self) -> RecordModel: vectors=self.vectors.api_models(), responses=self.responses.api_models(), suggestions=self.suggestions.api_models(), + status=self.status, ) def serialize(self) -> Dict[str, Any]: @@ -185,6 +190,7 @@ def to_dict(self) -> Dict[str, Dict]: """ id = str(self.id) if self.id else None server_id = str(self._model.id) if self._model.id else None + status = self.status fields = self.fields.to_dict() metadata = self.metadata.to_dict() suggestions = self.suggestions.to_dict() @@ -198,6 +204,7 @@ def to_dict(self) -> Dict[str, Dict]: "suggestions": suggestions, "responses": responses, "vectors": vectors, + "status": status, "_server_id": server_id, } @@ -245,7 +252,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": Returns: A Record object. """ - return cls( + instance = cls( id=model.external_id, fields=model.fields, metadata={meta.name: meta.value for meta in model.metadata}, @@ -257,10 +264,15 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record": for response in UserResponse.from_model(response_model, dataset=dataset) ], suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions], - _dataset=dataset, - _server_id=model.id, ) + # set private attributes + instance._dataset = dataset + instance._model.id = model.id + instance._model.status = model.status + + return instance + class RecordFields(dict): """This is a container class for the fields of a Record. 
@@ -335,7 +347,7 @@ def to_dict(self) -> Dict[str, List[Dict]]: response_dict = defaultdict(list) for response in self.__responses: response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)}) - return response_dict + return dict(response_dict) def api_models(self) -> List[UserResponseModel]: """Returns a list of ResponseModel objects.""" diff --git a/argilla/tests/integration/test_list_records.py b/argilla/tests/integration/test_list_records.py index ec51124be5..58407bc273 100644 --- a/argilla/tests/integration/test_list_records.py +++ b/argilla/tests/integration/test_list_records.py @@ -65,6 +65,19 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset): records = list(dataset.records(start_offset=1)) assert len(records) == 1 + assert [record.to_dict() for record in records] == [ + { + "_server_id": str(records[0]._server_id), + "fields": {"text": "The record text field"}, + "id": "2", + "status": "pending", + "metadata": {}, + "responses": {}, + "suggestions": {}, + "vectors": {}, + } + ] + def test_list_records_with_responses(client: Argilla, dataset: Dataset): dataset.records.log( diff --git a/argilla/tests/unit/test_io/test_generic.py b/argilla/tests/unit/test_io/test_generic.py index 446693f5b5..374ee20eed 100644 --- a/argilla/tests/unit/test_io/test_generic.py +++ b/argilla/tests/unit/test_io/test_generic.py @@ -41,6 +41,7 @@ def test_to_list_flatten(self): assert records_list == [ { "id": str(record.id), + "status": "pending", "_server_id": None, "field": "The field", "key": "value", diff --git a/argilla/tests/unit/test_io/test_hf_datasets.py b/argilla/tests/unit/test_io/test_hf_datasets.py index f13ab04ef4..99e43d8caf 100644 --- a/argilla/tests/unit/test_io/test_hf_datasets.py +++ b/argilla/tests/unit/test_io/test_hf_datasets.py @@ -46,6 +46,7 @@ def test_to_datasets_with_partial_values_in_records(self): ds = HFDatasetsIO.to_datasets(records) assert ds.features == { + "status": Value(dtype="string", id=None), "_server_id": Value(dtype="null", id=None), "a": Value(dtype="string", id=None), "b": Value(dtype="string", id=None), diff --git a/argilla/tests/unit/test_resources/test_records.py b/argilla/tests/unit/test_resources/test_records.py index 09759430c7..b04adb0203 100644 --- a/argilla/tests/unit/test_resources/test_records.py +++ b/argilla/tests/unit/test_resources/test_records.py @@ -14,6 +14,8 @@ import uuid +import pytest + from argilla import Record, Suggestion, Response from argilla._models import MetadataModel @@ -31,6 +33,7 @@ def test_record_repr(self): ) assert ( record.__repr__() == f"Record(id={record_id}," + "status=pending," "fields={'name': 'John', 'age': '30'}," "metadata={'key': 'value'}," "suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}}," @@ -62,3 +65,10 @@ def test_update_record_vectors(self): record.vectors["new-vector"] = [1.0, 2.0, 3.0] assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]} + + def test_prevent_update_record(self): + record = Record(fields={"name": "John"}) + assert record.status == "pending" + + with pytest.raises(AttributeError): + record.status = "completed" From 9c9aa26005b109ee439dc67e07760ae42d16d32c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:06:21 +0000 Subject: [PATCH 21/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
argilla-server/src/argilla_server/search_engine/commons.py | 2 -- .../tests/unit/api/handlers/v1/test_list_dataset_records.py | 1 - 2 files changed, 3 deletions(-) diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 04106a8f52..6cce9be64b 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -200,7 +200,6 @@ def es_path_for_vector_settings(vector_settings: VectorSettings) -> str: return str(vector_settings.id) - def es_path_for_question_response(question_name: str) -> str: return f"{question_name}" @@ -542,7 +541,6 @@ def _map_record_metadata_to_es( return search_engine_metadata - def _map_record_responses_to_es(self, responses: List[Response]) -> List[dict]: return [self._map_record_response_to_es(response) for response in responses] diff --git a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py index 3ed1910664..3dc3546d29 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_list_dataset_records.py @@ -607,4 +607,3 @@ async def create_dataset_with_user_responses( ] return dataset, questions, records, responses, suggestions - From 11ef1681734527aaa8f4362efc1b5a38a1a1e597 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Fri, 19 Jul 2024 08:55:55 +0200 Subject: [PATCH 22/36] Update argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts --- .../datasets/dataset-progress/useDatasetProgressViewModel.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts index fc823fccb1..9dbb5f1ef1 100644 --- a/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts +++ b/argilla-frontend/components/features/datasets/dataset-progress/useDatasetProgressViewModel.ts @@ -36,7 +36,6 @@ export const useDatasetProgressViewModel = ({ }, ]; - isLoaded.value = true; }); From 7356451c082c4143e892eca1443e3bdc801e8391 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 14:57:35 +0200 Subject: [PATCH 23/36] chore: Remove repositories --- .../argilla_server/repositories/__init__.py | 18 ------ .../argilla_server/repositories/datasets.py | 29 --------- .../argilla_server/repositories/records.py | 63 ------------------- 3 files changed, 110 deletions(-) delete mode 100644 argilla-server/src/argilla_server/repositories/__init__.py delete mode 100644 argilla-server/src/argilla_server/repositories/datasets.py delete mode 100644 argilla-server/src/argilla_server/repositories/records.py diff --git a/argilla-server/src/argilla_server/repositories/__init__.py b/argilla-server/src/argilla_server/repositories/__init__.py deleted file mode 100644 index 98424d94a6..0000000000 --- a/argilla-server/src/argilla_server/repositories/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from argilla_server.repositories.datasets import DatasetsRepository -from argilla_server.repositories.records import RecordsRepository - -__all__ = ["DatasetsRepository", "RecordsRepository"] diff --git a/argilla-server/src/argilla_server/repositories/datasets.py b/argilla-server/src/argilla_server/repositories/datasets.py deleted file mode 100644 index 46ac90e2fe..0000000000 --- a/argilla-server/src/argilla_server/repositories/datasets.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from uuid import UUID - -from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession - -from argilla_server.database import get_async_db -from argilla_server.models import Dataset - - -class DatasetsRepository: - def __init__(self, db: AsyncSession = Depends(get_async_db)): - self.db = db - - async def get(self, dataset_id: UUID) -> Dataset: - return await Dataset.get_or_raise(db=self.db, id=dataset_id) diff --git a/argilla-server/src/argilla_server/repositories/records.py b/argilla-server/src/argilla_server/repositories/records.py deleted file mode 100644 index ea93d94028..0000000000 --- a/argilla-server/src/argilla_server/repositories/records.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Union, List, Tuple, Sequence -from uuid import UUID - -from fastapi import Depends -from sqlalchemy import select, and_, func -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload, contains_eager - -from argilla_server.database import get_async_db -from argilla_server.models import Record, VectorSettings, Vector - - -class RecordsRepository: - def __init__( - self, - db: AsyncSession = Depends(get_async_db), - ): - self.db = db - - async def list_by_dataset_id( - self, - dataset_id: UUID, - offset: int, - limit: int, - with_responses: bool = False, - with_suggestions: bool = False, - with_vectors: Union[bool, List[str]] = False, - ) -> Tuple[Sequence[Record], int]: - query = select(Record).filter_by(dataset_id=dataset_id) - - if with_responses: - query = query.options(selectinload(Record.responses)) - if with_suggestions: - query = query.options(selectinload(Record.suggestions)) - if with_vectors is True: - query = query.options(selectinload(Record.vectors)) - elif isinstance(with_vectors, list): - subquery = select(VectorSettings.id).filter( - and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) - ) - query = query.outerjoin( - Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) - ).options(contains_eager(Record.vectors)) - - records = (await self.db.scalars(query.offset(offset).limit(limit).order_by(Record.inserted_at))).unique().all() - - total = await self.db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) - - return records, total From 39e6bd7290af2a1c23029cf665d130d3a90228de Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 14:58:41 +0200 Subject: [PATCH 24/36] refactor: Moving logic to contexts --- .../src/argilla_server/contexts/datasets.py | 7 ++++ .../src/argilla_server/contexts/records.py | 40 +++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 2ca212855b..77dd28b5b9 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -59,6 +59,7 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.contexts import accounts, distribution +from argilla_server.database import get_async_db from argilla_server.enums import DatasetStatus, UserRole, RecordStatus from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError from argilla_server.models import ( @@ -114,6 +115,12 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() +async def get_or_raise(dataset_id: UUID) -> Dataset: + """Get a dataset by ID or raise a NotFoundError""" + async for db in get_async_db(): + return await Dataset.get_or_raise(db, id=dataset_id) + + async def create_dataset(db: AsyncSession, dataset_attrs: dict): dataset = Dataset( name=dataset_attrs["name"], diff --git a/argilla-server/src/argilla_server/contexts/records.py b/argilla-server/src/argilla_server/contexts/records.py index c2b0f20bb9..10748df4c7 100644 --- a/argilla-server/src/argilla_server/contexts/records.py +++ b/argilla-server/src/argilla_server/contexts/records.py @@ -12,14 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Sequence +from typing import Dict, Sequence, Union, List, Tuple from uuid import UUID -from sqlalchemy import select +from sqlalchemy import select, and_, func from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import selectinload, contains_eager -from argilla_server.models import Dataset, Record +from argilla_server.database import get_async_db +from argilla_server.models import Dataset, Record, VectorSettings, Vector + + +async def list_records_by_dataset_id( + dataset_id: UUID, + offset: int, + limit: int, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, +) -> Tuple[Sequence[Record], int]: + query = select(Record).filter_by(dataset_id=dataset_id) + + if with_responses: + query = query.options(selectinload(Record.responses)) + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + async for db in get_async_db(): + records = (await db.scalars(query.offset(offset).limit(limit).order_by(Record.inserted_at))).unique().all() + total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) + + return records, total async def list_dataset_records_by_ids( From e4eb17f6df222b2ac9f08dff0b4ef30abadad4bc Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 14:59:01 +0200 Subject: [PATCH 25/36] refactor: using contexts --- .../api/handlers/v1/datasets/records.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index dd519f994c..72310120b2 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from fastapi import APIRouter, Depends, Query, Security, status +from fastapi import APIRouter, Depends, Query, Security, status, Path from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -43,19 +43,16 @@ SearchSuggestionsOptions, SuggestionFilterScope, ) -from argilla_server.contexts import datasets, search +from argilla_server.contexts import datasets, search, records from argilla_server.database import get_async_db from argilla_server.enums import RecordSortField from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE from argilla_server.models import Dataset, Field, Record, User, VectorSettings -from argilla_server.repositories import DatasetsRepository, RecordsRepository -from argilla_server.models import Dataset, Field, Record, User, VectorSettings from argilla_server.search_engine import ( AndFilter, SearchEngine, SearchResponses, - UserResponseStatusFilter, get_search_engine, ) from argilla_server.security import auth @@ -80,7 +77,7 @@ def 
_to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> elif isinstance(scope, MetadataFilterScope): return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property) elif isinstance(scope, SuggestionFilterScope): - return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property) + return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property)) elif isinstance(scope, ResponseFilterScope): return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user) else: @@ -203,18 +200,19 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord raise UnprocessableEntityError(str(e)) +async def get_dataset_or_raise(dataset_id: UUID = Path) -> Dataset: + return await datasets.get_or_raise(dataset_id) + + @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, - datasets_repository: DatasetsRepository = Depends(), - records_repository: RecordsRepository = Depends(), - dataset_id: UUID, + dataset: Dataset = Depends(get_dataset_or_raise), include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), ): - dataset = await datasets_repository.get(dataset_id) await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) include_args = ( @@ -227,14 +225,14 @@ async def list_dataset_records( else {} ) - records, total = await records_repository.list_by_dataset_id( + dataset_records, total = await records.list_records_by_dataset_id( dataset_id=dataset.id, offset=offset, limit=limit, **include_args, ) - return Records(items=records, total=total) + return Records(items=dataset_records, total=total) @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) From 6326a541fcbe0c2b52bc5111c70df994a42a700c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 14:59:23 +0200 Subject: [PATCH 26/36] tests: Mock db for contexts --- argilla-server/tests/unit/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/argilla-server/tests/unit/conftest.py b/argilla-server/tests/unit/conftest.py index f6e4d7cc2f..558b192463 100644 --- a/argilla-server/tests/unit/conftest.py +++ b/argilla-server/tests/unit/conftest.py @@ -22,7 +22,7 @@ from opensearchpy import OpenSearch from argilla_server import telemetry -from argilla_server.contexts import distribution +from argilla_server.contexts import distribution, datasets, records from argilla_server.api.routes import api_v1 from argilla_server.constants import API_KEY_HEADER_NAME, DEFAULT_API_KEY from argilla_server.database import get_async_db @@ -92,6 +92,8 @@ async def override_get_search_engine(): yield mock_search_engine mocker.patch.object(distribution, "_get_async_db", override_get_async_db) + mocker.patch.object(datasets, "get_async_db", override_get_async_db) + mocker.patch.object(records, "get_async_db", override_get_async_db) api_v1.dependency_overrides.update( { From 3ce1f84fa085e0b4704bd922555e99d0a810a78e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 15:05:29 +0200 Subject: [PATCH 27/36] refactor: Reusing depends get_dataset --- .../argilla_server/api/handlers/v1/datasets/records.py | 8 ++------ 1 file changed, 2 
insertions(+), 6 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 72310120b2..5cf5bd8240 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -238,14 +238,12 @@ async def list_dataset_records( @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) async def delete_dataset_records( *, + dataset: Dataset = Depends(get_dataset_or_raise), db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), - dataset_id: UUID, current_user: User = Security(auth.get_current_user), ids: str = Query(..., description="A comma separated list with the IDs of the records to be removed"), ): - dataset = await Dataset.get_or_raise(db, dataset_id) - await authorize(current_user, DatasetPolicy.delete_records(dataset)) record_ids = parse_uuids(ids) @@ -392,12 +390,10 @@ async def search_dataset_records( ) async def list_dataset_records_search_suggestions_options( *, + dataset: Dataset = Depends(get_dataset_or_raise), db: AsyncSession = Depends(get_async_db), - dataset_id: UUID, current_user: User = Security(auth.get_current_user), ): - dataset = await Dataset.get_or_raise(db, dataset_id) - await authorize(current_user, DatasetPolicy.search_records(dataset)) suggestion_agents_by_question = await search.get_dataset_suggestion_agents_by_question(db, dataset.id) From 82e306e7b865076453fe8430c03a19e7a175fbd4 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 31 Jul 2024 15:31:12 +0200 Subject: [PATCH 28/36] refactor: Moving query builder to models --- .../src/argilla_server/contexts/records.py | 38 +++++------- .../src/argilla_server/models/database.py | 58 ++++++++++++++++++- 2 files changed, 72 insertions(+), 24 deletions(-) diff --git a/argilla-server/src/argilla_server/contexts/records.py b/argilla-server/src/argilla_server/contexts/records.py index 10748df4c7..233b0d9442 100644 --- a/argilla-server/src/argilla_server/contexts/records.py +++ b/argilla-server/src/argilla_server/contexts/records.py @@ -31,25 +31,18 @@ async def list_records_by_dataset_id( with_suggestions: bool = False, with_vectors: Union[bool, List[str]] = False, ) -> Tuple[Sequence[Record], int]: - query = select(Record).filter_by(dataset_id=dataset_id) - - if with_responses: - query = query.options(selectinload(Record.responses)) - if with_suggestions: - query = query.options(selectinload(Record.suggestions)) - if with_vectors is True: - query = query.options(selectinload(Record.vectors)) - elif isinstance(with_vectors, list): - subquery = select(VectorSettings.id).filter( - and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) - ) - query = query.outerjoin( - Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) - ).options(contains_eager(Record.vectors)) + query = Record.Select.by_dataset_id( + dataset_id=dataset_id, + offset=offset, + limit=limit, + with_responses=with_responses, + with_suggestions=with_suggestions, + with_vectors=with_vectors, + ) async for db in get_async_db(): - records = (await db.scalars(query.offset(offset).limit(limit).order_by(Record.inserted_at))).unique().all() - total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) + records = (await db.scalars(query)).unique().all() + total = await 
db.scalar(Record.Select.count(dataset_id=dataset_id)) return records, total @@ -57,19 +50,20 @@ async def list_records_by_dataset_id( async def list_dataset_records_by_ids( db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID] ) -> Sequence[Record]: - query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id) - return (await db.execute(query)).unique().scalars().all() + query = Record.Select.by_dataset_id(dataset_id=dataset_id).where(Record.id.in_(record_ids)) + return (await db.scalars(query)).unique().all() async def list_dataset_records_by_external_ids( db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str] ) -> Sequence[Record]: query = ( - select(Record) - .filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id) + Record.Select.by_dataset_id(dataset_id=dataset_id) + .where(Record.external_id.in_(external_ids)) .options(selectinload(Record.dataset)) ) - return (await db.execute(query)).unique().scalars().all() + + return (await db.scalars(query)).unique().all() async def fetch_records_by_ids_as_dict( diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 6b9580dbb5..14e799daa2 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -17,12 +17,25 @@ from typing import Any, List, Optional, Union from uuid import UUID -from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text +from sqlalchemy import ( + JSON, + ForeignKey, + String, + Text, + UniqueConstraint, + and_, + sql, + select, + func, + text, + Select, + ColumnExpressionArgument, +) from sqlalchemy import Enum as SAEnum from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property +from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property, selectinload, contains_eager from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( @@ -236,6 +249,47 @@ def __repr__(self): f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) + class Select: + @classmethod + def count(cls, **filters) -> Select: + return select(func.count(Record.id)).filter_by(**filters) + + @classmethod + def by_dataset_id( + cls, + dataset_id: UUID, + offset: Optional[int] = None, + limit: Optional[int] = None, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, + ) -> Select: + query = select(Record).filter_by(dataset_id=dataset_id) + + if with_responses: + query = query.options(selectinload(Record.responses)) + + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + if offset is not None: + query = query.offset(offset) + + if limit is not None: + query = query.limit(limit) + + return query.order_by(Record.inserted_at) + class 
Question(DatabaseModel): __tablename__ = "questions" From 7205842a26b217be00f43ca014b4e05a98633152 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 1 Aug 2024 18:09:27 +0200 Subject: [PATCH 29/36] chore: Update CHANGELOG --- argilla-server/CHANGELOG.md | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index b34c054563..ed657f711a 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -40,12 +40,6 @@ These are the section headers that we use: - [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) - Change search index mapping for responses (reindex is required). ([#5228](https://github.com/argilla-io/argilla/pull/5228)) -### Changed - -- Change `responses` table to delete rows on cascade when a user is deleted. ([#5126](https://github.com/argilla-io/argilla/pull/5126)) -- [breaking] Change `GET /api/v1/datasets/:dataset_id/progress` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) -- [breaking] Change `GET /api/v1/me/datasets/:dataset_id/metrics` endpoint to support new dataset distribution task. ([#5140](https://github.com/argilla-io/argilla/pull/5140)) - ### Fixed - Fixed SQLite connection settings not working correctly due to an outdated conditional. ([#5149](https://github.com/argilla-io/argilla/pull/5149)) @@ -54,15 +48,6 @@ These are the section headers that we use: ### Removed -- [breaking] Remove deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) -- [breaking] Remove deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) -- [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. ([#5153](https://github.com/argilla-io/argilla/pull/5153)) -- [breaking] Removed support for `response_status` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5163](https://github.com/argilla-io/argilla/pull/5163)) -- [breaking] Removed support for `metadata` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5156](https://github.com/argilla-io/argilla/pull/5156)) -- [breaking] Removed support for `sort_by` query param for endpoints `POST /api/v1/me/datasets/:dataset_id/records/search` and `POST /api/v1/datasets/:dataset_id/records/search`. ([#5166](https://github.com/argilla-io/argilla/pull/5166)) - -## [2.0.0rc1](https://github.com/argilla-io/argilla/compare/v1.29.0...v2.0.0rc1) - - [breaking] Removed deprecated endpoint `POST /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) - [breaking] Removed deprecated endpoint `PATCH /api/v1/datasets/:dataset_id/records`. ([#5206](https://github.com/argilla-io/argilla/pull/5206)) - [breaking] Removed `GET /api/v1/me/datasets/:dataset_id/records` endpoint. 
([#5153](https://github.com/argilla-io/argilla/pull/5153)) From 59d05c5a82ca8f5b5a6fd38ae9099d4be886044d Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 1 Aug 2024 18:15:02 +0200 Subject: [PATCH 30/36] chore: Change order --- .../src/argilla_server/api/handlers/v1/datasets/records.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 5cf5bd8240..0d379ed1b1 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -207,11 +207,11 @@ async def get_dataset_or_raise(dataset_id: UUID = Path) -> Dataset: @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, - dataset: Dataset = Depends(get_dataset_or_raise), include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), + dataset: Dataset = Depends(get_dataset_or_raise), ): await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) @@ -238,10 +238,10 @@ async def list_dataset_records( @router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT) async def delete_dataset_records( *, - dataset: Dataset = Depends(get_dataset_or_raise), db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), current_user: User = Security(auth.get_current_user), + dataset: Dataset = Depends(get_dataset_or_raise), ids: str = Query(..., description="A comma separated list with the IDs of the records to be removed"), ): await authorize(current_user, DatasetPolicy.delete_records(dataset)) @@ -390,9 +390,9 @@ async def search_dataset_records( ) async def list_dataset_records_search_suggestions_options( *, - dataset: Dataset = Depends(get_dataset_or_raise), db: AsyncSession = Depends(get_async_db), current_user: User = Security(auth.get_current_user), + dataset: Dataset = Depends(get_dataset_or_raise), ): await authorize(current_user, DatasetPolicy.search_records(dataset)) From d8aa03e5c3907f06bf92de0b9cca6471a0446fec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:21:51 +0000 Subject: [PATCH 31/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../src/argilla_server/api/handlers/v1/datasets/records.py | 1 - 1 file changed, 1 deletion(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 5ec2c6533e..c99c9ae64f 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -193,7 +193,6 @@ async def _get_search_responses( return await search_engine.search(**search_params) - async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): try: await search.validate_search_records_query(db, query, dataset) From 4b257532c2be0eb0b8218b21356756e6450c00bb Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 9 Sep 2024 14:48:02 +0200 Subject: 
[PATCH 32/36] chore: Apply PR comments --- .../api/handlers/v1/datasets/records.py | 24 +++++---- .../src/argilla_server/contexts/datasets.py | 6 --- .../src/argilla_server/contexts/records.py | 53 +++++++++++++++---- .../src/argilla_server/models/database.py | 41 -------------- 4 files changed, 57 insertions(+), 67 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 4684ecdd0c..1b2d9f8d85 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID -from fastapi import APIRouter, Depends, Query, Security, status, Path +from fastapi import APIRouter, Depends, Query, Security, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -45,7 +45,7 @@ ) from argilla_server.contexts import datasets, search, records from argilla_server.database import get_async_db -from argilla_server.enums import RecordSortField, ResponseStatusFilter +from argilla_server.enums import RecordSortField from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE from argilla_server.models import Dataset, Field, Record, User, VectorSettings @@ -193,26 +193,24 @@ async def _get_search_responses( return await search_engine.search(**search_params) -async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset_id: UUID): +async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset): try: await search.validate_search_records_query(db, query, dataset) except (ValueError, NotFoundError) as e: raise UnprocessableEntityError(str(e)) -async def get_dataset_or_raise(dataset_id: UUID = Path) -> Dataset: - return await datasets.get_or_raise(dataset_id) - - @router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True) async def list_dataset_records( *, + db: AsyncSession = Depends(get_async_db), + dataset_id: UUID, include: Optional[RecordIncludeParam] = Depends(parse_record_include_param), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE), current_user: User = Security(auth.get_current_user), - dataset: Dataset = Depends(get_dataset_or_raise), ): + dataset = await Dataset.get_or_raise(db, dataset_id) await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset)) include_args = ( @@ -226,6 +224,7 @@ async def list_dataset_records( ) dataset_records, total = await records.list_records_by_dataset_id( + db=db, dataset_id=dataset.id, offset=offset, limit=limit, @@ -240,10 +239,11 @@ async def delete_dataset_records( *, db: AsyncSession = Depends(get_async_db), search_engine: SearchEngine = Depends(get_search_engine), + dataset_id: UUID, current_user: User = Security(auth.get_current_user), - dataset: Dataset = Depends(get_dataset_or_raise), ids: str = Query(..., description="A comma separated list with the IDs of the records to be removed"), ): + dataset = await Dataset.get_or_raise(db, 
dataset_id) await authorize(current_user, DatasetPolicy.delete_records(dataset)) record_ids = parse_uuids(ids) @@ -391,9 +391,11 @@ async def search_dataset_records( async def list_dataset_records_search_suggestions_options( *, db: AsyncSession = Depends(get_async_db), + dataset_id: UUID, current_user: User = Security(auth.get_current_user), - dataset: Dataset = Depends(get_dataset_or_raise), ): + dataset = await Dataset.get_or_raise(db, dataset_id) + await authorize(current_user, DatasetPolicy.search_records(dataset)) suggestion_agents_by_question = await search.get_dataset_suggestion_agents_by_question(db, dataset.id) diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 7fd2e22b2f..e9f653251e 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -117,12 +117,6 @@ async def list_datasets_by_workspace_id(db: AsyncSession, workspace_id: UUID) -> return result.scalars().all() -async def get_or_raise(dataset_id: UUID) -> Dataset: - """Get a dataset by ID or raise a NotFoundError""" - async for db in get_async_db(): - return await Dataset.get_or_raise(db, id=dataset_id) - - async def create_dataset(db: AsyncSession, dataset_attrs: dict): dataset = Dataset( name=dataset_attrs["name"], diff --git a/argilla-server/src/argilla_server/contexts/records.py b/argilla-server/src/argilla_server/contexts/records.py index 233b0d9442..01a4c763a5 100644 --- a/argilla-server/src/argilla_server/contexts/records.py +++ b/argilla-server/src/argilla_server/contexts/records.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Sequence, Union, List, Tuple +from typing import Dict, Sequence, Union, List, Tuple, Optional from uuid import UUID -from sqlalchemy import select, and_, func +from sqlalchemy import select, and_, func, Select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload, contains_eager @@ -24,6 +24,7 @@ async def list_records_by_dataset_id( + db: AsyncSession, dataset_id: UUID, offset: int, limit: int, @@ -31,7 +32,7 @@ async def list_records_by_dataset_id( with_suggestions: bool = False, with_vectors: Union[bool, List[str]] = False, ) -> Tuple[Sequence[Record], int]: - query = Record.Select.by_dataset_id( + query = _record_by_dataset_id_query( dataset_id=dataset_id, offset=offset, limit=limit, @@ -40,17 +41,16 @@ async def list_records_by_dataset_id( with_vectors=with_vectors, ) - async for db in get_async_db(): - records = (await db.scalars(query)).unique().all() - total = await db.scalar(Record.Select.count(dataset_id=dataset_id)) + records = (await db.scalars(query)).unique().all() + total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id)) - return records, total + return records, total async def list_dataset_records_by_ids( db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID] ) -> Sequence[Record]: - query = Record.Select.by_dataset_id(dataset_id=dataset_id).where(Record.id.in_(record_ids)) + query = _record_by_dataset_id_query(dataset_id).where(Record.id.in_(record_ids)) return (await db.scalars(query)).unique().all() @@ -58,7 +58,7 @@ async def list_dataset_records_by_external_ids( db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str] ) -> Sequence[Record]: query = ( - Record.Select.by_dataset_id(dataset_id=dataset_id) + 
_record_by_dataset_id_query(dataset_id) .where(Record.external_id.in_(external_ids)) .options(selectinload(Record.dataset)) ) @@ -78,3 +78,38 @@ async def fetch_records_by_external_ids_as_dict( ) -> Dict[str, Record]: records_by_external_ids = await list_dataset_records_by_external_ids(db, dataset.id, external_ids) return {record.external_id: record for record in records_by_external_ids} + + +def _record_by_dataset_id_query( + dataset_id, + offset: Optional[int] = None, + limit: Optional[int] = None, + with_responses: bool = False, + with_suggestions: bool = False, + with_vectors: Union[bool, List[str]] = False, +) -> Select: + query = select(Record).filter_by(dataset_id=dataset_id) + + if with_responses: + query = query.options(selectinload(Record.responses)) + + if with_suggestions: + query = query.options(selectinload(Record.suggestions)) + + if with_vectors is True: + query = query.options(selectinload(Record.vectors)) + elif isinstance(with_vectors, list): + subquery = select(VectorSettings.id).filter( + and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) + ) + query = query.outerjoin( + Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) + ).options(contains_eager(Record.vectors)) + + if offset is not None: + query = query.offset(offset) + + if limit is not None: + query = query.limit(limit) + + return query.order_by(Record.inserted_at) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 15d52d1297..adc3d21e01 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -258,47 +258,6 @@ def __repr__(self): f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})" ) - class Select: - @classmethod - def count(cls, **filters) -> Select: - return select(func.count(Record.id)).filter_by(**filters) - - @classmethod - def by_dataset_id( - cls, - dataset_id: UUID, - offset: Optional[int] = None, - limit: Optional[int] = None, - with_responses: bool = False, - with_suggestions: bool = False, - with_vectors: Union[bool, List[str]] = False, - ) -> Select: - query = select(Record).filter_by(dataset_id=dataset_id) - - if with_responses: - query = query.options(selectinload(Record.responses)) - - if with_suggestions: - query = query.options(selectinload(Record.suggestions)) - - if with_vectors is True: - query = query.options(selectinload(Record.vectors)) - elif isinstance(with_vectors, list): - subquery = select(VectorSettings.id).filter( - and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors)) - ) - query = query.outerjoin( - Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery)) - ).options(contains_eager(Record.vectors)) - - if offset is not None: - query = query.offset(offset) - - if limit is not None: - query = query.limit(limit) - - return query.order_by(Record.inserted_at) - class Question(DatabaseModel): __tablename__ = "questions" From 31390ab021cae7bc2a7afc108eb8f725c3a8ae8c Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 9 Sep 2024 14:50:50 +0200 Subject: [PATCH 33/36] chore: Revert newline --- .../src/argilla_server/api/handlers/v1/datasets/records.py | 1 + 1 file changed, 1 insertion(+) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index 1b2d9f8d85..b26979bce5 100644 --- 
a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -244,6 +244,7 @@ async def delete_dataset_records( ids: str = Query(..., description="A comma separated list with the IDs of the records to be removed"), ): dataset = await Dataset.get_or_raise(db, dataset_id) + await authorize(current_user, DatasetPolicy.delete_records(dataset)) record_ids = parse_uuids(ids) From b00f404e222334b1c8d8937a4c05388e7d14dadd Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 17 Sep 2024 10:17:16 +0200 Subject: [PATCH 34/36] chore: Apply suggestions --- .../argilla_server/api/handlers/v1/datasets/records.py | 2 +- argilla-server/src/argilla_server/contexts/records.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py index b26979bce5..5c394a77a8 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/datasets/records.py @@ -223,7 +223,7 @@ async def list_dataset_records( else {} ) - dataset_records, total = await records.list_records_by_dataset_id( + dataset_records, total = await records.list_dataset_records( db=db, dataset_id=dataset.id, offset=offset, diff --git a/argilla-server/src/argilla_server/contexts/records.py b/argilla-server/src/argilla_server/contexts/records.py index 01a4c763a5..0764b3152f 100644 --- a/argilla-server/src/argilla_server/contexts/records.py +++ b/argilla-server/src/argilla_server/contexts/records.py @@ -23,7 +23,7 @@ from argilla_server.models import Dataset, Record, VectorSettings, Vector -async def list_records_by_dataset_id( +async def list_dataset_records( db: AsyncSession, dataset_id: UUID, offset: int, @@ -50,7 +50,7 @@ async def list_records_by_dataset_id( async def list_dataset_records_by_ids( db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID] ) -> Sequence[Record]: - query = _record_by_dataset_id_query(dataset_id).where(Record.id.in_(record_ids)) + query = select(Record).where(and_(Record.id.in_(record_ids), Record.dataset_id == dataset_id)) return (await db.scalars(query)).unique().all() @@ -58,8 +58,8 @@ async def list_dataset_records_by_external_ids( db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str] ) -> Sequence[Record]: query = ( - _record_by_dataset_id_query(dataset_id) - .where(Record.external_id.in_(external_ids)) + select(Record) + .where(and_(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)) .options(selectinload(Record.dataset)) ) From a9e679527ce196b64e564c4537bf7432a65fab19 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 17 Sep 2024 10:20:02 +0200 Subject: [PATCH 35/36] revert code changes --- argilla-server/src/argilla_server/models/database.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/argilla-server/src/argilla_server/models/database.py b/argilla-server/src/argilla_server/models/database.py index 99db6ee872..00b960ad1b 100644 --- a/argilla-server/src/argilla_server/models/database.py +++ b/argilla-server/src/argilla_server/models/database.py @@ -17,6 +17,7 @@ from typing import Any, List, Optional, Union from uuid import UUID +from sqlalchemy import Enum as SAEnum from sqlalchemy import ( JSON, ForeignKey, @@ -25,17 +26,11 @@ UniqueConstraint, and_, sql, - select, - func, - text, - Select, - ColumnExpressionArgument, 
) -from sqlalchemy import Enum as SAEnum from sqlalchemy.engine.default import DefaultExecutionContext from sqlalchemy.ext.asyncio import async_object_session from sqlalchemy.ext.mutable import MutableDict, MutableList -from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property, selectinload, contains_eager +from sqlalchemy.orm import Mapped, mapped_column, relationship from argilla_server.api.schemas.v1.questions import QuestionSettings from argilla_server.enums import ( @@ -43,7 +38,6 @@ FieldType, MetadataPropertyType, QuestionType, - RecordStatus, ResponseStatus, SuggestionType, UserRole, From 031a40730d28b15e662ad9d15ff1d5eb22f7ccf0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 07:29:51 +0000 Subject: [PATCH 36/36] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- argilla-server/CHANGELOG.md | 2 +- argilla-server/src/argilla_server/contexts/accounts.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 4ef90db35a..42e3c8e15f 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -20,7 +20,7 @@ These are the section headers that we use: - Added filtering by `name`, and `status` support to endpoint `GET /api/v1/me/datasets`. ([#5374](https://github.com/argilla-io/argilla/pull/5374)) -## [2.2.0]() +## [2.2.0]() ### Added diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index 723b68a5c3..712a0dea24 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -73,9 +73,7 @@ async def create_workspace(db: AsyncSession, workspace_attrs: dict) -> Workspace async def delete_workspace(db: AsyncSession, workspace: Workspace): if await datasets.list_datasets(db, workspace_id=workspace.id): - raise NotUniqueError( - f"Cannot delete the workspace {workspace.id}. This workspace has some datasets linked" - ) + raise NotUniqueError(f"Cannot delete the workspace {workspace.id}. This workspace has some datasets linked") return await workspace.delete(db)