diff --git a/argilla-server/src/argilla_server/contexts/datasets.py b/argilla-server/src/argilla_server/contexts/datasets.py index 177a6c352b..e8d4f703e9 100644 --- a/argilla-server/src/argilla_server/contexts/datasets.py +++ b/argilla-server/src/argilla_server/contexts/datasets.py @@ -61,8 +61,14 @@ ) from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema from argilla_server.webhooks.v1.enums import DatasetEvent, ResponseEvent, RecordEvent -from argilla_server.webhooks.v1.records import notify_record_event as notify_record_event_v1 -from argilla_server.webhooks.v1.responses import notify_response_event as notify_response_event_v1 +from argilla_server.webhooks.v1.records import ( + notify_record_event as notify_record_event_v1, + build_record_event as build_record_event_v1, +) +from argilla_server.webhooks.v1.responses import ( + notify_response_event as notify_response_event_v1, + build_response_event as build_response_event_v1, +) from argilla_server.webhooks.v1.datasets import notify_dataset_event as notify_dataset_event_v1 from argilla_server.contexts import accounts, distribution from argilla_server.database import get_async_db @@ -812,6 +818,7 @@ async def preload_records_relationships_before_validate(db: AsyncSession, record ) +# TODO: Use build_record_event_v1 instead async def delete_records( db: AsyncSession, search_engine: "SearchEngine", dataset: Dataset, records_ids: List[UUID] ) -> None: @@ -860,13 +867,14 @@ async def update_record( async def delete_record(db: AsyncSession, search_engine: "SearchEngine", record: Record) -> Record: + deleted_record_event_v1 = await build_record_event_v1(db, RecordEvent.deleted, record) + async with db.begin_nested(): record = await record.delete(db=db, autocommit=False) await search_engine.delete_records(dataset=record.dataset, records=[record]) await db.commit() - - await notify_record_event_v1(db, RecordEvent.deleted, record) + await deleted_record_event_v1.notify(db) return record @@ -962,6 +970,8 @@ async def upsert_response( async def delete_response(db: AsyncSession, search_engine: SearchEngine, response: Response) -> Response: + deleted_response_event_v1 = await build_response_event_v1(db, ResponseEvent.deleted, response) + async with db.begin_nested(): response = await response.delete(db, autocommit=False) @@ -971,7 +981,8 @@ async def delete_response(db: AsyncSession, search_engine: SearchEngine, respons await db.commit() await distribution.update_record_status(search_engine, response.record_id) - await notify_response_event_v1(db, ResponseEvent.deleted, response) + # TODO: think about the record status being updated after the event being build + await deleted_response_event_v1.notify(db) return response diff --git a/argilla-server/src/argilla_server/webhooks/v1/event.py b/argilla-server/src/argilla_server/webhooks/v1/event.py new file mode 100644 index 0000000000..5efa831af7 --- /dev/null +++ b/argilla-server/src/argilla_server/webhooks/v1/event.py @@ -0,0 +1,36 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from datetime import datetime + +from rq.job import Job +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.jobs.webhook_jobs import enqueue_notify_events + + +class Event: + def __init__(self, type: str, timestamp: datetime, data: dict): + self.type = type + self.timestamp = timestamp + self.data = data + + async def notify(self, db: AsyncSession) -> List[Job]: + return await enqueue_notify_events( + db, + event=self.type, + timestamp=self.timestamp, + data=self.data, + ) diff --git a/argilla-server/src/argilla_server/webhooks/v1/records.py b/argilla-server/src/argilla_server/webhooks/v1/records.py index b60d95d884..f322b4c2bd 100644 --- a/argilla-server/src/argilla_server/webhooks/v1/records.py +++ b/argilla-server/src/argilla_server/webhooks/v1/records.py @@ -20,16 +20,20 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from argilla_server.webhooks.v1.event import Event from argilla_server.webhooks.v1.enums import RecordEvent from argilla_server.webhooks.v1.schemas import RecordEventSchema -from argilla_server.jobs.webhook_jobs import enqueue_notify_events from argilla_server.models import Record, Dataset async def notify_record_event(db: AsyncSession, record_event: RecordEvent, record: Record) -> List[Job]: - if record_event == RecordEvent.deleted: - return await _notify_record_deleted_event(db, record) + event = await build_record_event(db, record_event, record) + return await event.notify(db) + + +async def build_record_event(db: AsyncSession, record_event: RecordEvent, record: Record) -> Event: + # NOTE: Force loading required association resources required by the event schema ( await db.execute( select(Dataset) @@ -44,18 +48,8 @@ async def notify_record_event(db: AsyncSession, record_event: RecordEvent, recor ) ).scalar_one() - return await enqueue_notify_events( - db, - event=record_event, + return Event( + type=record_event, timestamp=datetime.utcnow(), data=RecordEventSchema.from_orm(record).dict(), ) - - -async def _notify_record_deleted_event(db: AsyncSession, record: Record) -> List[Job]: - return await enqueue_notify_events( - db, - event=RecordEvent.deleted, - timestamp=datetime.utcnow(), - data={"id": record.id}, - ) diff --git a/argilla-server/src/argilla_server/webhooks/v1/responses.py b/argilla-server/src/argilla_server/webhooks/v1/responses.py index a4048094c0..22f80fa2e2 100644 --- a/argilla-server/src/argilla_server/webhooks/v1/responses.py +++ b/argilla-server/src/argilla_server/webhooks/v1/responses.py @@ -22,15 +22,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from argilla_server.models import Response, Record, Dataset -from argilla_server.jobs.webhook_jobs import enqueue_notify_events +from argilla_server.webhooks.v1.event import Event from argilla_server.webhooks.v1.schemas import ResponseEventSchema -from argilla_server.webhooks.v1.enums import ResponseEvent +from argilla_server.webhooks.v1.enums import ResponseEvent, WebhookEvent -async def notify_response_event(db: AsyncSession, response_event: ResponseEvent, response: Response) -> List[Job]: - if response_event == ResponseEvent.deleted: - return await _notify_response_deleted_event(db, response) - +async def build_response_event(db: AsyncSession, response_event: ResponseEvent, response: Response) -> Event: # NOTE: Force loading required association resources required by the event schema ( await db.execute( @@ -51,18 +48,14 @@ async def notify_response_event(db: AsyncSession, response_event: ResponseEvent, ) ).scalar_one() - return await enqueue_notify_events( - db, - event=response_event, + return Event( + type=response_event, timestamp=datetime.utcnow(), data=ResponseEventSchema.from_orm(response).dict(), ) -async def _notify_response_deleted_event(db: AsyncSession, response: Response) -> List[Job]: - return await enqueue_notify_events( - db, - event=ResponseEvent.deleted, - timestamp=datetime.utcnow(), - data={"id": response.id}, - ) +async def notify_response_event(db: AsyncSession, response_event: ResponseEvent, response: Response) -> List[Job]: + event = await build_response_event(db, response_event, response) + + return await event.notify(db)