
[ENHANCEMENT] argilla-server: List records endpoint using db #5170

Merged

Changes from 60 commits (66 commits in total)

Commits
f62d58a
feat: add dataset support to be created using distribution settings (…
jfcalvo Jul 1, 2024
017001f
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 1, 2024
f084ab7
✨ Remove unused method
damianpumar Jul 4, 2024
c8ef4c6
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 4, 2024
6df5256
feat: improve Records `responses_submitted` relationship to be view o…
jfcalvo Jul 4, 2024
dbae135
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 4, 2024
cf3408c
feat: change metrics to support new distribution task logic (#5140)
jfcalvo Jul 4, 2024
267811c
[REFACTOR] `argilla-server`: Remove list current user records endpoin…
frascuchon Jul 4, 2024
89f9bde
[BREAKING- REFACTOR] `argilla-server`: remove metadata filter query p…
frascuchon Jul 4, 2024
0404465
[BREAKING - REFACTOR] `argilla-server`: remove user response status s…
frascuchon Jul 4, 2024
20d4ab8
refactor: Remove sort_by argument
frascuchon Jul 4, 2024
5f4e5b0
[breaking] refactor: Remove sort_by query param
frascuchon Jul 4, 2024
c885392
tests: Adapt tests
frascuchon Jul 4, 2024
28b2998
chore: Update changelog
frascuchon Jul 5, 2024
209d64d
feat: Define new repositories
frascuchon Jul 5, 2024
a350b0c
chore: Rewrite list endpoint using repositories
frascuchon Jul 5, 2024
3537941
tests: Enable skip tests for list dataset records
frascuchon Jul 5, 2024
8e8b116
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
frascuchon Jul 5, 2024
808c837
[ENHANCEMENT]: `argilla-server`: allow update distribution for non an…
frascuchon Jul 8, 2024
ba417dc
[BREAKING - REFACTOR] `argilla-server`: remove `sort_by` query param …
frascuchon Jul 8, 2024
f241e41
fix: wrong filter naming after merge from develop
frascuchon Jul 9, 2024
67d4ee3
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 9, 2024
3e06890
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 9, 2024
b15de8f
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
frascuchon Jul 11, 2024
f497140
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 11, 2024
bec0b0d
feat: add session helper with serializable isolation level (#5165)
jfcalvo Jul 12, 2024
8bf8abb
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
85e847f
[REFACTOR] `argilla-server`: remove deprecated records endpoint (#5206)
frascuchon Jul 12, 2024
22263d8
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
c219764
[ENHANCEMENT] `argilla`: add record `status` property (#5184)
frascuchon Jul 12, 2024
ced0220
Merge branch 'develop' into feat/add-dataset-automatic-task-distribution
jfcalvo Jul 12, 2024
aa9bf1f
Merge branch 'feat/add-dataset-automatic-task-distribution' into refa…
frascuchon Jul 12, 2024
0b73b3f
Merge branch 'refactor/cleaning-list-records-endpoints' into refactor…
frascuchon Jul 12, 2024
dcfbfaf
Merge branch 'feat/add-dataset-automatic-task-distribution' into refa…
frascuchon Jul 12, 2024
4d3f668
Merge branch 'refactor/cleaning-list-records-endpoints' into refactor…
frascuchon Jul 12, 2024
2941072
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 18, 2024
9c9aa26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
11ef168
Update argilla-frontend/components/features/datasets/dataset-progress…
frascuchon Jul 19, 2024
0e525b4
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 25, 2024
1526e33
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 29, 2024
bca45ff
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Jul 31, 2024
7356451
chore: Remove repositories
frascuchon Jul 31, 2024
39e6bd7
refactor: Moving logic to contexts
frascuchon Jul 31, 2024
e4eb17f
refactor: using contexts
frascuchon Jul 31, 2024
6326a54
tests: Mock db for contexts
frascuchon Jul 31, 2024
3ce1f84
refactor: Reusing depends get_dataset
frascuchon Jul 31, 2024
82e306e
refactor: Moving query builder to models
frascuchon Jul 31, 2024
423466a
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
jfcalvo Jul 31, 2024
7205842
chore: Update CHANGELOG
frascuchon Aug 1, 2024
59d05c5
chore: Change order
frascuchon Aug 1, 2024
61bc08f
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 1, 2024
d656bbc
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 2, 2024
d8aa03e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
ee3fa63
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Aug 26, 2024
6d9ecfb
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
jfcalvo Aug 30, 2024
d4f70de
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 9, 2024
4b25753
chore: Apply PR comments
frascuchon Sep 9, 2024
c8d3be8
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 9, 2024
31390ab
chore: Revert newline
frascuchon Sep 9, 2024
a7b6713
Merge branch 'refactor/argilla-server/list-records-endpoint-using-db'…
frascuchon Sep 9, 2024
a7dc205
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 17, 2024
b00f404
chore: Apply suggestions
frascuchon Sep 17, 2024
a9e6795
revert code changes
frascuchon Sep 17, 2024
acfe981
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 18, 2024
031a407
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
22539d6
Merge branch 'develop' into refactor/argilla-server/list-records-endp…
frascuchon Sep 23, 2024
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from fastapi import APIRouter, Depends, Query, Security, status
@@ -43,17 +43,16 @@
SearchSuggestionsOptions,
SuggestionFilterScope,
)
from argilla_server.contexts import datasets, search
from argilla_server.contexts import datasets, search, records
from argilla_server.database import get_async_db
from argilla_server.enums import RecordSortField, ResponseStatusFilter
from argilla_server.enums import RecordSortField
from argilla_server.errors.future import MissingVectorError, NotFoundError, UnprocessableEntityError
from argilla_server.errors.future.base_errors import MISSING_VECTOR_ERROR_CODE
from argilla_server.models import Dataset, Field, Record, User, VectorSettings
from argilla_server.search_engine import (
AndFilter,
SearchEngine,
SearchResponses,
UserResponseStatusFilter,
get_search_engine,
)
from argilla_server.security import auth
@@ -72,42 +71,13 @@
router = APIRouter()


async def _filter_records_using_search_engine(
db: "AsyncSession",
search_engine: "SearchEngine",
dataset: Dataset,
limit: int,
offset: int,
user: Optional[User] = None,
include: Optional[RecordIncludeParam] = None,
) -> Tuple[List[Record], int]:
search_responses = await _get_search_responses(
db=db,
search_engine=search_engine,
dataset=dataset,
limit=limit,
offset=offset,
user=user,
)

record_ids = [response.record_id for response in search_responses.items]
user_id = user.id if user else None

return (
await datasets.get_records_by_ids(
db=db, dataset_id=dataset.id, user_id=user_id, records_ids=record_ids, include=include
),
search_responses.total,
)


def _to_search_engine_filter_scope(scope: FilterScope, user: Optional[User]) -> search_engine.FilterScope:
if isinstance(scope, RecordFilterScope):
return search_engine.RecordFilterScope(property=scope.property)
elif isinstance(scope, MetadataFilterScope):
return search_engine.MetadataFilterScope(metadata_property=scope.metadata_property)
elif isinstance(scope, SuggestionFilterScope):
return search_engine.SuggestionFilterScope(question=scope.question, property=scope.property)
return search_engine.SuggestionFilterScope(question=scope.question, property=str(scope.property))
elif isinstance(scope, ResponseFilterScope):
return search_engine.ResponseFilterScope(question=scope.question, property=scope.property, user=user)
else:
@@ -223,18 +193,6 @@ async def _get_search_responses(
return await search_engine.search(**search_params)


async def _build_response_status_filter_for_search(
response_statuses: Optional[List[ResponseStatusFilter]] = None, user: Optional[User] = None
) -> Optional[UserResponseStatusFilter]:
user_response_status_filter = None

if response_statuses:
# TODO(@frascuchon): user response and status responses should be split into different filter types
user_response_status_filter = UserResponseStatusFilter(user=user, statuses=response_statuses)

return user_response_status_filter


async def _validate_search_records_query(db: "AsyncSession", query: SearchRecordsQuery, dataset: Dataset):
try:
await search.validate_search_records_query(db, query, dataset)
@@ -246,27 +204,34 @@ async def _validate_search_records_query(db: "AsyncSession", query: SearchRecord
async def list_dataset_records(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
dataset_id: UUID,
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ge=1, le=LIST_DATASET_RECORDS_LIMIT_LE),
current_user: User = Security(auth.get_current_user),
):
dataset = await Dataset.get_or_raise(db, dataset_id)

await authorize(current_user, DatasetPolicy.list_records_with_all_responses(dataset))

records, total = await _filter_records_using_search_engine(
db,
search_engine,
dataset=dataset,
limit=limit,
include_args = (
dict(
with_responses=include.with_responses,
with_suggestions=include.with_suggestions,
with_vectors=include.with_all_vectors or include.vectors,
)
if include
else {}
)

dataset_records, total = await records.list_records_by_dataset_id(
db=db,
dataset_id=dataset.id,
offset=offset,
include=include,
limit=limit,
**include_args,
)

return Records(items=records, total=total)
return Records(items=dataset_records, total=total)


@router.delete("/datasets/{dataset_id}/records", status_code=status.HTTP_204_NO_CONTENT)
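
With this change, `list_dataset_records` resolves records straight from the database via the new `records` context instead of going through the search engine. As a rough usage sketch — only the `records.list_records_by_dataset_id` signature comes from this diff; the wrapper function and its defaults are illustrative assumptions:

```python
# Illustrative sketch only: calling the new context function the same way the
# endpoint above does after resolving `include_args`. The wrapper is hypothetical.
from typing import List, Sequence, Tuple, Union
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.contexts import records
from argilla_server.models import Record


async def first_page(
    db: AsyncSession,
    dataset_id: UUID,
    vectors: Union[bool, List[str]] = False,
) -> Tuple[Sequence[Record], int]:
    # Eager-load responses/suggestions, optionally vectors, paginate in SQL,
    # and get the dataset-wide total in a single context call.
    items, total = await records.list_records_by_dataset_id(
        db=db,
        dataset_id=dataset_id,
        offset=0,
        limit=50,
        with_responses=True,
        with_suggestions=True,
        with_vectors=vectors,  # True loads all vectors; a list restricts by settings name
    )
    return items, total
```
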
79 changes: 70 additions & 9 deletions argilla-server/src/argilla_server/contexts/records.py
@@ -12,32 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Sequence
from typing import Dict, Sequence, Union, List, Tuple, Optional
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import select, and_, func, Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import selectinload, contains_eager

from argilla_server.models import Dataset, Record
from argilla_server.database import get_async_db
from argilla_server.models import Dataset, Record, VectorSettings, Vector


async def list_records_by_dataset_id(
db: AsyncSession,
dataset_id: UUID,
offset: int,
limit: int,
with_responses: bool = False,
with_suggestions: bool = False,
with_vectors: Union[bool, List[str]] = False,
) -> Tuple[Sequence[Record], int]:
query = _record_by_dataset_id_query(
dataset_id=dataset_id,
offset=offset,
limit=limit,
with_responses=with_responses,
with_suggestions=with_suggestions,
with_vectors=with_vectors,
)

records = (await db.scalars(query)).unique().all()
total = await db.scalar(select(func.count(Record.id)).filter_by(dataset_id=dataset_id))

return records, total


async def list_dataset_records_by_ids(
db: AsyncSession, dataset_id: UUID, record_ids: Sequence[UUID]
) -> Sequence[Record]:
query = select(Record).filter(Record.id.in_(record_ids), Record.dataset_id == dataset_id)
return (await db.execute(query)).unique().scalars().all()
query = _record_by_dataset_id_query(dataset_id).where(Record.id.in_(record_ids))
return (await db.scalars(query)).unique().all()


async def list_dataset_records_by_external_ids(
db: AsyncSession, dataset_id: UUID, external_ids: Sequence[str]
) -> Sequence[Record]:
query = (
select(Record)
.filter(Record.external_id.in_(external_ids), Record.dataset_id == dataset_id)
_record_by_dataset_id_query(dataset_id)
.where(Record.external_id.in_(external_ids))
.options(selectinload(Record.dataset))
)
return (await db.execute(query)).unique().scalars().all()

return (await db.scalars(query)).unique().all()


async def fetch_records_by_ids_as_dict(
@@ -52,3 +78,38 @@
) -> Dict[str, Record]:
records_by_external_ids = await list_dataset_records_by_external_ids(db, dataset.id, external_ids)
return {record.external_id: record for record in records_by_external_ids}


def _record_by_dataset_id_query(
dataset_id,
offset: Optional[int] = None,
limit: Optional[int] = None,
with_responses: bool = False,
with_suggestions: bool = False,
with_vectors: Union[bool, List[str]] = False,
) -> Select:
query = select(Record).filter_by(dataset_id=dataset_id)

if with_responses:
query = query.options(selectinload(Record.responses))

if with_suggestions:
query = query.options(selectinload(Record.suggestions))

[Codecov / codecov/patch annotation — warning on argilla-server/src/argilla_server/contexts/records.py#L97: added line #L97 was not covered by tests]

if with_vectors is True:
query = query.options(selectinload(Record.vectors))
elif isinstance(with_vectors, list):
subquery = select(VectorSettings.id).filter(
and_(VectorSettings.dataset_id == dataset_id, VectorSettings.name.in_(with_vectors))
)
query = query.outerjoin(
Vector, and_(Vector.record_id == Record.id, Vector.vector_settings_id.in_(subquery))
).options(contains_eager(Record.vectors))

if offset is not None:
query = query.offset(offset)

if limit is not None:
query = query.limit(limit)

return query.order_by(Record.inserted_at)
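
A quick way to see what `_record_by_dataset_id_query` produces is to compile it. This is only an inspection sketch, not code from the PR; the vector settings name is made up:

```python
# Inspection sketch: compile the query built by the helper above to see the
# SELECT with LIMIT/OFFSET, the ORDER BY on inserted_at, and — because
# with_vectors is a list — the LEFT OUTER JOIN on vectors restricted to the
# named vector settings.
from uuid import uuid4

from argilla_server.contexts.records import _record_by_dataset_id_query

query = _record_by_dataset_id_query(
    dataset_id=uuid4(),
    offset=0,
    limit=10,
    with_responses=True,
    with_vectors=["sentence-embedding"],  # hypothetical vector settings name
)

print(query)  # str() compiles the Select with the default dialect
```

Note that the paginated records query and the `func.count` total in `list_records_by_dataset_id` are two separate statements, so `total` always reflects the whole dataset rather than the returned page.
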
17 changes: 15 additions & 2 deletions argilla-server/src/argilla_server/models/database.py
@@ -17,12 +17,25 @@
from typing import Any, List, Optional, Union
from uuid import UUID

from sqlalchemy import JSON, ForeignKey, String, Text, UniqueConstraint, and_, sql, select, func, text
from sqlalchemy import (
JSON,
ForeignKey,
String,
Text,
UniqueConstraint,
and_,
sql,
select,
func,
text,
Select,
ColumnExpressionArgument,
)
from sqlalchemy import Enum as SAEnum
from sqlalchemy.engine.default import DefaultExecutionContext
from sqlalchemy.ext.asyncio import async_object_session
from sqlalchemy.ext.mutable import MutableDict, MutableList
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property
from sqlalchemy.orm import Mapped, mapped_column, relationship, column_property, selectinload, contains_eager

from argilla_server.api.schemas.v1.questions import QuestionSettings
from argilla_server.enums import (
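
The `models/database.py` hunk above only shows the new imports; the rest of that file's diff is not captured here. Purely as a hypothetical illustration of what those imports enable — `Select`, `ColumnExpressionArgument`, `selectinload`, and `contains_eager` are the building blocks of a model-level query helper — something of this general shape could be attached to `Record`; the actual code in the PR may differ:

```python
# Hypothetical sketch, not taken from this PR: a model-side query builder that the
# newly imported names would support. Signature and behaviour are assumptions.
from uuid import UUID

from sqlalchemy import ColumnExpressionArgument, Select, select
from sqlalchemy.orm import selectinload

from argilla_server.models import Record


def records_by_dataset_id_query(
    dataset_id: UUID,
    *extra_criteria: ColumnExpressionArgument[bool],
    with_responses: bool = False,
) -> Select:
    # Base SELECT for a dataset's records, optional extra WHERE clauses, optional
    # eager loading of responses, and a stable ordering by insertion time.
    query = select(Record).filter_by(dataset_id=dataset_id).where(*extra_criteria)
    if with_responses:
        query = query.options(selectinload(Record.responses))
    return query.order_by(Record.inserted_at)
```
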