diff --git a/argilla-server/.env.dev b/argilla-server/.env.dev
index 76c10523d0..18fcb79f01 100644
--- a/argilla-server/.env.dev
+++ b/argilla-server/.env.dev
@@ -2,3 +2,5 @@ OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES # Needed by RQ to work with forked proce
 ALEMBIC_CONFIG=src/argilla_server/alembic.ini
 ARGILLA_AUTH_SECRET_KEY=8VO7na5N/jQx+yP/N+HlE8q51vPdrxqlh6OzoebIyko= # With this we avoid using a different key every time the server is reloaded
 ARGILLA_DATABASE_URL=sqlite+aiosqlite:///${HOME}/.argilla/argilla.db?check_same_thread=False
+# For mac users only https://github.com/rq/rq/issues/1418
+OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
diff --git a/argilla-server/src/argilla_server/contexts/distribution.py b/argilla-server/src/argilla_server/contexts/distribution.py
index 410c375178..85062f8c4c 100644
--- a/argilla-server/src/argilla_server/contexts/distribution.py
+++ b/argilla-server/src/argilla_server/contexts/distribution.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import backoff
-import sqlalchemy
-
 from typing import List
 from uuid import UUID
 
-from sqlalchemy.orm import selectinload
+import backoff
+import sqlalchemy
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import selectinload
 
 from argilla_server.api.webhooks.v1.enums import RecordEvent
 from argilla_server.api.webhooks.v1.records import notify_record_event as notify_record_event_v1
diff --git a/argilla/src/argilla/_models/_record/_record.py b/argilla/src/argilla/_models/_record/_record.py
index 9a9bfed19e..b2ba679a41 100644
--- a/argilla/src/argilla/_models/_record/_record.py
+++ b/argilla/src/argilla/_models/_record/_record.py
@@ -39,6 +39,8 @@ class RecordModel(ResourceModel):
     suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple)
     external_id: Optional[Any] = Field(default=None)
 
+    dataset_id: Optional[uuid.UUID] = Field(default=None)
+
     @field_serializer("external_id", when_used="unless-none")
     def serialize_external_id(self, value: str) -> str:
         return str(value)
@@ -77,3 +79,29 @@ def validate_external_id(cls, external_id: Any) -> Union[str, int, uuid.UUID]:
         if external_id is None:
             external_id = uuid.uuid4()
         return external_id
+
+    @field_validator("vectors", mode="before")
+    @classmethod
+    def empty_vectors_if_none(cls, vectors: Optional[List[VectorModel]]) -> Optional[List[VectorModel]]:
+        """Ensure vectors is an empty list if not provided."""
+        if vectors is None:
+            return []
+        return vectors
+
+    @field_validator("responses", mode="before")
+    @classmethod
+    def empty_responses_if_none(cls, responses: Optional[List[UserResponseModel]]) -> Optional[List[UserResponseModel]]:
+        """Ensure responses is an empty list if not provided."""
+        if responses is None:
+            return []
+        return responses
+
+    @field_validator("suggestions", mode="before")
+    @classmethod
+    def empty_suggestions_if_none(
+        cls, suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]]
+    ) -> Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]]:
+        """Ensure suggestions is an empty list if not provided."""
+        if suggestions is None:
+            return []
+        return suggestions
diff --git a/examples/webhooks/distilabel_trigger/.gitignore b/examples/webhooks/distilabel_trigger/.gitignore
new file mode 100644
index 0000000000..efa407c35f
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/.gitignore
@@ -0,0 +1,162 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
\ No newline at end of file
diff --git a/examples/webhooks/distilabel_trigger/README.md b/examples/webhooks/distilabel_trigger/README.md
new file mode 100644
index 0000000000..5b13c44bd6
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/README.md
@@ -0,0 +1,31 @@
+
+
+
+## Running the app
+
+1. Start the Argilla server and an Argilla worker:
+```bash
+pdm server start
+pdm worker
+```
+
+2. Add a `localhost.org` alias to the `/etc/hosts` file so the webhook URL satisfies the top-level-domain requirement:
+```
+##
+# Host Database
+#
+# localhost is used to configure the loopback interface
+# when the system is booting. Do not change this entry.
+##
+127.0.0.1 localhost localhost.org
+```
+
+3. Start the app:
+```bash
+uvicorn webhook:server
+```
+
+## Testing the app
+
+Annotate some records in the Argilla UI and check the app logs to see the webhook being triggered.
diff --git a/examples/webhooks/distilabel_trigger/configure_models.py b/examples/webhooks/distilabel_trigger/configure_models.py
new file mode 100644
index 0000000000..f4a3cc46c8
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/configure_models.py
@@ -0,0 +1,49 @@
+import os
+from typing import List
+
+from distilabel.llms import InferenceEndpointsLLM
+from distilabel.steps.tasks import TextGeneration, Task, UltraFeedback
+
+LLAMA_MODEL_ID = os.environ.get(
+    "LLAMA_MODEL_ID", "meta-llama/Meta-Llama-3.1-8B-Instruct"
+)
+GEMMA_MODEL_ID = os.environ.get("GEMMA_MODEL_ID", "google/gemma-1.1-7b-it")
+ULTRAFEEDBACK_MODEL_ID = os.environ.get(
+    "ULTRAFEEDBACK_MODEL_ID", "meta-llama/Meta-Llama-3.1-70B-Instruct"
+)
+
+
+def initialize_text_generation_models() -> List["Task"]:
+    llama31 = TextGeneration(
+        name="text-generation",
+        llm=InferenceEndpointsLLM(
+            model_id=LLAMA_MODEL_ID,
+            tokenizer_id=LLAMA_MODEL_ID,
+        ),
+    )
+    llama31.load()
+
+    gemma_tiny = TextGeneration(
+        name="text-generation",
+        llm=InferenceEndpointsLLM(
+            model_id=GEMMA_MODEL_ID,
+            tokenizer_id=GEMMA_MODEL_ID,
+        ),
+    )
+    gemma_tiny.load()
+
+    return [llama31, gemma_tiny]
+
+
+def initialize_ultrafeedback() -> UltraFeedback:
+    ultrafeedback = UltraFeedback(
+        aspect="overall-rating",
+        llm=InferenceEndpointsLLM(
+            model_id=ULTRAFEEDBACK_MODEL_ID,
+            tokenizer_id=ULTRAFEEDBACK_MODEL_ID,
+        ),
+    )
+
+    ultrafeedback.load()
+
+    return ultrafeedback
diff --git a/examples/webhooks/distilabel_trigger/configure_webhook.py b/examples/webhooks/distilabel_trigger/configure_webhook.py
new file mode 100644
index 0000000000..b1ca5548cd
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/configure_webhook.py
@@ -0,0 +1,24 @@
+import os
+
+import argilla as rg
+from argilla._api._webhooks import WebhookModel
+from standardwebhooks.webhooks import Webhook
+
+WEBHOOK_BASE_URL = os.getenv("WEBHOOK_BASE_URL", "http://localhost.org:8000")
+
+
+def configure_webhook(client: rg.Argilla, path: str) -> Webhook:
+    # Remove any webhooks left over from previous runs before creating a new one
+    for wh_model in client.api.webhooks.list():
+        client.api.webhooks.delete(wh_model.id)
+
+    model = WebhookModel(
+        url=f"{WEBHOOK_BASE_URL}{path}",
+        events=["record.completed"],
+        description="Webhook for record completion",
+    )
+
+    webhook_model = client.api.webhooks.create(model)
+    webhook = Webhook(whsecret=webhook_model.secret)
+
+    return webhook
diff --git a/examples/webhooks/distilabel_trigger/dataset_setup.py b/examples/webhooks/distilabel_trigger/dataset_setup.py
new file mode 100644
index 0000000000..9d9275c9ff
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/dataset_setup.py
@@ -0,0 +1,90 @@
+import os
+from datetime import datetime
+
+import argilla as rg
+from datasets import load_dataset
+
+MAX_RECORDS = int(os.environ.get("MAX_RECORDS", 10))
+
+
+def prepare_dataset(client: rg.Argilla) -> rg.Dataset:
+    workspace = client.workspaces(name="argilla")
+    if workspace is None:
+        workspace = rg.Workspace(name="argilla", client=client).create()
+
+    dataset = create_dataset(client, workspace)
+    load_and_upload_records(dataset)
+
+    return dataset
+
+
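+# The dataset pairs each PersonaHub instruction with two (initially empty)
+# response fields; annotators answer the "respond" label question first, and
+# the webhook then fills response1/response2 and adds rating/rationale suggestions.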
+def create_dataset(client: rg.Argilla, workspace: rg.Workspace) -> rg.Dataset:
+    return rg.Dataset(
+        client=client,
+        workspace=workspace,
+        name=f"triggers_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}",
+        settings=rg.Settings(
+            fields=[
+                rg.TextField("persona"),
+                rg.TextField("instruction"),
+                rg.TextField("response1"),
+                rg.TextField("response2"),
+            ],
+            questions=[
+                rg.LabelQuestion(name="respond", labels=["yes", "no"], required=True),
+                rg.TextQuestion(name="improved_instruction", required=False),
+                rg.TextQuestion(name="response1_rationale", required=False),
+                rg.TextQuestion(name="response2_rationale", required=False),
+                rg.RatingQuestion(
+                    name="response1_rating", values=[1, 2, 3, 4, 5], required=False
+                ),
+                rg.RatingQuestion(
+                    name="response2_rating", values=[1, 2, 3, 4, 5], required=False
+                ),
+            ],
+        ),
+    ).create()
+
+
+def load_and_upload_records(dataset: rg.Dataset):
+    ds = load_dataset("proj-persona/PersonaHub", "instruction")
+    records_to_upload = []
+    for sample in ds["train"].to_iterable_dataset():
+        record = rg.Record(
+            fields={
+                "persona": sample["input persona"],
+                "instruction": sample["synthesized text"],
+                "response1": "",
+                "response2": "",
+            },
+            id=str(hash(sample["synthesized text"])),  # note: hash() is salted per process, so ids differ across runs
+        )
+        records_to_upload.append(record)
+
+        if len(records_to_upload) == MAX_RECORDS:
+            break
+
+    dataset.records.log(records=records_to_upload)
+
+
+# def update_record_fields(record_id, updated_fields):
+#     url = f"{API_URL}/api/v1/records/{record_id}"
+#     headers = {
+#         "accept": "application/json",
+#         "X-Argilla-Api-Key": API_KEY,
+#         "Content-Type": "application/json",
+#     }
+#     data = {"fields": updated_fields}
+#     response = requests.patch(url, headers=headers, json=data)
+#     return response.json()
+
+
+# def delete_response(response_id):
+#     url = f"{API_URL}/api/v1/responses/{response_id}"
+#     headers = {
+#         "accept": "application/json",
+#         "X-Argilla-Api-Key": API_KEY,
+#         "Content-Type": "application/json",
+#     }
+#     response = requests.delete(url, headers=headers)
+#     return response.json()
diff --git a/examples/webhooks/distilabel_trigger/requirements.txt b/examples/webhooks/distilabel_trigger/requirements.txt
new file mode 100644
index 0000000000..de784fe5fc
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/requirements.txt
@@ -0,0 +1,6 @@
+argilla
+distilabel
+transformers
+fastapi
+uvicorn[standard]
+standardwebhooks
diff --git a/examples/webhooks/distilabel_trigger/webhook.py b/examples/webhooks/distilabel_trigger/webhook.py
new file mode 100644
index 0000000000..0b004f4b5a
--- /dev/null
+++ b/examples/webhooks/distilabel_trigger/webhook.py
@@ -0,0 +1,144 @@
+import asyncio
+import http
+import os
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from typing import List, Literal
+
+import argilla as rg
+from argilla._models import RecordModel
+from distilabel.steps.tasks import UltraFeedback, Task
+from fastapi import FastAPI, Request
+from fastapi.responses import PlainTextResponse
+from pydantic import BaseModel
+from standardwebhooks.webhooks import WebhookVerificationError
+
+from configure_models import initialize_text_generation_models, initialize_ultrafeedback
+from configure_webhook import configure_webhook
+from dataset_setup import prepare_dataset
+
+# Environment variables with defaults
+API_KEY = os.environ.get("ARGILLA_API_KEY", "argilla.apikey")
+API_URL = os.environ.get("ARGILLA_API_URL", "http://localhost:6900")
+
+# Initialize Argilla client
+client = rg.Argilla(api_key=API_KEY, api_url=API_URL)
+
+dataset = prepare_dataset(client)
+text_generation_models = initialize_text_generation_models()
+ultrafeedback = initialize_ultrafeedback()
+
+thread_pool = ThreadPoolExecutor(max_workers=2)
+webhook = configure_webhook(client, "/webhook")
+
+server = FastAPI()
+
+
+@server.middleware("http")
+async def webhook_verify(request: Request, call_next):
+    try:
+        body = await request.body()
+        webhook.verify(body, dict(request.headers))
+    except WebhookVerificationError as e:
+        # Return a 401 instead of raising: exceptions raised inside HTTP
+        # middleware bypass FastAPI's exception handlers and surface as 500s.
+        return PlainTextResponse(str(e), status_code=http.HTTPStatus.UNAUTHORIZED)
+    return await call_next(request)
+
+
+class RecordCompletedEvent(BaseModel):
+    type: Literal["record.completed"] = "record.completed"
+    timestamp: datetime
+    data: RecordModel  # Events work at API Model level, not resource model level
+
+
+@server.post("/webhook")
+async def webhook_handler(body: dict):
+    if body["type"] != "record.completed":
+        return
+
+    event = RecordCompletedEvent.model_validate(body)
+    # Run the blocking distilabel work in the thread pool so the event loop
+    # stays free while we wait for it to finish.
+    await asyncio.get_running_loop().run_in_executor(thread_pool, handle_event, event)
+
+
+def handle_event(event: RecordCompletedEvent) -> None:
+    print("Received webhook payload:", event)
+
+    if event.data.dataset_id != dataset.id:
+        print("Ignoring webhook payload")
+        return
+
+    record = rg.Record.from_model(event.data, dataset=dataset)
+    try:
+        respond_to_good_instructions(record, text_generation_models, ultrafeedback)
+    except Exception as e:
+        print("Error processing record", record.id, e)
+
+
+def respond_to_record(record: rg.Record, models: List[Task]):
+    responses = []
+    for task in models:
+        print(task.name)
+        output = list(task.process([{"instruction": record.fields["instruction"]}]))[0][
+            0
+        ]
+        generation = output["generation"]
+        responses.append(generation)
+    return responses
+
+
+def add_feedback_suggestions(
+    record: rg.Record, response_1, response_2, ultrafeedback: UltraFeedback
+) -> None:
+    response = ultrafeedback.process(
+        [
+            {
+                "instruction": record.fields["instruction"],
+                "generations": [
+                    response_1,
+                    response_2,
+                ],
+            }
+        ],
+    )
+    response = list(response)[0][0]
+    ratings = response["ratings"]
+    rationales = response["rationales"]
+
+    for n, (rating, rationale) in enumerate(zip(ratings, rationales), 1):
+        if rating is not None:
+            record.suggestions.add(
+                suggestion=rg.Suggestion(
+                    question_name=f"response{n}_rating",
+                    value=rating,
+                )
+            )
+        if rationale is not None:
+            record.suggestions.add(
+                suggestion=rg.Suggestion(
+                    question_name=f"response{n}_rationale",
+                    value=rationale,
+                )
+            )
+
+    # Move the existing "respond" answers back to draft so the record returns
+    # to the pending queue with the new suggestions attached.
+    for response in record.responses["respond"]:
+        response.status = "draft"
+
+
+def respond_to_good_instructions(
+    record: rg.Record, models: List[Task], ultrafeedback: UltraFeedback
+) -> None:
+    if not record.responses["respond"] or record.responses["respond"][0].value != "yes":
+        return
+
+    response_1, response_2 = respond_to_record(record=record, models=models)
+
+    # Write the generated responses into the record's fields (dict-style
+    # assignment) so the log() call below persists them; previously they were
+    # collected in an unused updated_fields dict and never applied.
+    record.fields["response1"] = response_1
+    record.fields["response2"] = response_2
+
+    add_feedback_suggestions(
+        record=record,
+        response_1=response_1,
+        response_2=response_2,
+        ultrafeedback=ultrafeedback,
+    )
+
+    dataset.records.log([record])
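+
+
+# Optional entry point (a minimal sketch: the README starts the app with
+# `uvicorn webhook:server`, but the module can also be run directly; assumes
+# uvicorn is installed, per requirements.txt).
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(server, host="127.0.0.1", port=8000)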