Skip to content

Commit

Permalink
refactor: webhook server url setup
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Sep 18, 2024
1 parent 3464d85 commit 8c95e3f
Showing 1 changed file with 52 additions and 21 deletions.
73 changes: 52 additions & 21 deletions argilla/src/argilla/webhooks/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,56 @@
# limitations under the License.

import os
import warnings
from typing import TYPE_CHECKING, Optional, Callable, Union, List

import argilla as rg
from argilla import Argilla
from argilla.webhooks._resource import Webhook
from argilla.webhooks._handler import WebhookHandler
from argilla.webhooks._resource import Webhook

if TYPE_CHECKING:
from fastapi import FastAPI

WEBHOOK_SERVER_URL = os.getenv("WEBHOOK_SERVER_URL", "http://127.0.0.1:8000")

def _compute_webhook_server_url() -> str:
"""
Compute the webhook server URL.
Returns:
str: The webhook server URL. If the environment variable `SPACE_HOST` is set, it will return `https://<SPACE_HOST>`.
Otherwise, it will return the value of the environment variable `WEBHOOK_SERVER_URL` or `http://127.0.0.1:8000`.
"""
if space_host := os.getenv("SPACE_HOST"):
return f"https://{space_host}"
return os.getenv("WEBHOOK_SERVER_URL", "http://127.0.0.1:8000")


def _webhook_url_for_func(func: Callable) -> str:
return f"{WEBHOOK_SERVER_URL}/{func.__name__}"
"""
Compute the full webhook URL for a given function.
Parameters:
func (Callable): The function to compute the webhook URL for.
Returns:
str: The full webhook URL.
"""
webhook_server_url = _compute_webhook_server_url()

return f"{webhook_server_url}/{func.__name__}"


def get_webhook_server() -> "FastAPI":
"""
Get the webhook server.
Returns:
FastAPI: The webhook server.
"""
from fastapi import FastAPI

global _server
Expand All @@ -40,6 +72,13 @@ def get_webhook_server() -> "FastAPI":


def set_webhook_server(app: "FastAPI"):
"""
Set the webhook server. It can only be set once.
Parameters:
app (FastAPI): The webhook server.
"""
global _server

if _server:
Expand All @@ -64,29 +103,21 @@ def webhook_listener(
if isinstance(events, str):
events = [events]

def decorator(func: Callable) -> Callable:
webhook_url = _webhook_url_for_func(func)
def wrapper(func: Callable) -> Callable:
webhook = Webhook(
url=_webhook_url_for_func(func),
events=events,
description=description or f"Webhook for {func.__name__}",
).create()

webhook = None
for argilla_webhook in client.webhooks:
if webhook_url == argilla_webhook.url:
webhook = argilla_webhook
break

if webhook:
webhook.description = description or webhook.description
webhook.events = events
webhook.update()
else:
webhook = Webhook(
url=webhook_url,
events=events,
description=description or f"Webhook for {func.__name__}",
).create()
if argilla_webhook.url == webhook.url and argilla_webhook.id != webhook.id:
warnings.warn(f"Deleting existing webhook with URL {argilla_webhook.url}: {argilla_webhook}")
argilla_webhook.delete()

request_handler = WebhookHandler(webhook).handle(func, raw_event)
server.post(f"/{func.__name__}", tags=["Argilla Webhooks"])(request_handler)

return request_handler

return decorator
return wrapper

0 comments on commit 8c95e3f

Please sign in to comment.