Skip to content

Commit

Permalink
Add HTTP retry handling into task SDK api.client
Browse files Browse the repository at this point in the history
  • Loading branch information
jscheffl committed Dec 20, 2024
1 parent b130758 commit bc58ae6
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions task_sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"msgspec>=0.18.6",
"psutil>=6.1.0",
"structlog>=24.4.0",
"retryhttp>=1.2.0",
]
classifiers = [
"Framework :: Apache Airflow",
Expand Down
25 changes: 25 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import os
import sys
import uuid
from http import HTTPStatus
Expand All @@ -26,6 +27,8 @@
import msgspec
import structlog
from pydantic import BaseModel
from retryhttp import retry, wait_retry_after
from tenacity import wait_random_exponential
from uuid6 import uuid7

from airflow.sdk import __version__
Expand Down Expand Up @@ -263,6 +266,14 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"text": "Hello, world!"})


# Settings controlling how the SDK retries failed HTTP requests.
# Note: with these defaults, attempts are made after about 1, 3, 7, 15 and 31 seconds, then after
# 1:03, 2:07 and 3:37 minutes, and the request finally fails after 5:07 minutes.
# The SDK has no other configuration facility yet, so environment variables are used for the moment.
API_RETRIES = int(os.environ.get("AIRFLOW__WORKERS__API_RETRIES", 10))
API_RETRY_WAIT_MIN = int(os.environ.get("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1))
API_RETRY_WAIT_MAX = int(os.environ.get("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90))


class Client(httpx.Client):
def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
if (not base_url) ^ dry_run:
Expand All @@ -284,6 +295,20 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
**kwargs,
)

# Shared wait strategy: randomized exponential back-off bounded by the configured min/max.
_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)

@retry(
    max_attempt_number=API_RETRIES,
    wait_network_errors=_default_wait,
    wait_timeouts=_default_wait,
    wait_server_errors=_default_wait,
    # The fallback wait ensures there is no infinite timeout on HTTP 429.
    wait_rate_limited=wait_retry_after(fallback=_default_wait),
    reraise=True,
)
def request(self, *args, **kwargs):
    """
    Send a request through ``httpx.Client.request`` with a retry layer around it.

    All positional and keyword arguments are forwarded unchanged to the parent implementation,
    and its response is returned as-is.
    """
    response = super().request(*args, **kwargs)
    return response

# We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
# methods on one object prefixed with the object type (`.task_instances.update` rather than
# `task_instance_update` etc.)
Expand Down
117 changes: 117 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
from unittest import mock

import httpx
import pytest
Expand Down Expand Up @@ -82,6 +83,122 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert err.value.args == ("Not found",)
assert err.value.detail is None

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_unrecoverable_error(self, mock_sleep):
    """The client gives up and re-raises when every retry attempt ends in an HTTP 500."""
    # Queue more failures than the assertions below expect to be consumed (10 of the 13 entries),
    # so the trailing 200 must never be reached. Each entry is a distinct Response instance so no
    # response state is shared between attempts.
    responses: list[httpx.Response] = [
        *(httpx.Response(500, text="Internal Server Error") for _ in range(11)),
        httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]

    def mock_handle_request(request: httpx.Request) -> httpx.Response:
        # Serve the queued responses in order, one per outgoing request.
        return responses.pop(0)

    client = Client(
        base_url=None,
        dry_run=True,
        token="",
        # Mount key must be a proper "scheme://" pattern (the original had a stray apostrophe).
        mounts={"http://": httpx.MockTransport(mock_handle_request)},
    )

    with pytest.raises(httpx.HTTPStatusError) as err:
        client.get("http://error")
    assert not isinstance(err.value, ServerResponseError)
    # 13 queued - 10 attempts = 3 left; sleep is called once between consecutive attempts.
    assert len(responses) == 3
    assert mock_sleep.call_count == 9

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_recovered(self, mock_sleep):
    """The client retries through three HTTP 500 responses and returns the following 200."""
    # Distinct Response instances per attempt, so no response state is shared between retries.
    responses: list[httpx.Response] = [
        *(httpx.Response(500, text="Internal Server Error") for _ in range(3)),
        httpx.Response(200, json={"detail": "Recovered from error"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]

    def mock_handle_request(request: httpx.Request) -> httpx.Response:
        # Serve the queued responses in order, one per outgoing request.
        return responses.pop(0)

    client = Client(
        base_url=None,
        dry_run=True,
        token="",
        # Mount key must be a proper "scheme://" pattern (the original had a stray apostrophe).
        mounts={"http://": httpx.MockTransport(mock_handle_request)},
    )

    response = client.get("http://error")
    assert response.status_code == 200
    # Only the trailing 400 remains unused; one sleep per failed attempt.
    assert len(responses) == 1
    assert mock_sleep.call_count == 3

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_overload(self, mock_sleep):
    """On HTTP 429 the client waits for the advertised ``Retry-After`` time and then retries."""
    responses: list[httpx.Response] = [
        httpx.Response(429, text="I am really busy atm, please back-off", headers={"Retry-After": "37"}),
        httpx.Response(200, json={"detail": "Recovered from error"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]

    def mock_handle_request(request: httpx.Request) -> httpx.Response:
        # Serve the queued responses in order, one per outgoing request.
        return responses.pop(0)

    client = Client(
        base_url=None,
        dry_run=True,
        token="",
        # Mount key must be a proper "scheme://" pattern (the original had a stray apostrophe).
        mounts={"http://": httpx.MockTransport(mock_handle_request)},
    )

    response = client.get("http://error")
    assert response.status_code == 200
    assert len(responses) == 1
    # Exactly one wait, and its duration is the value from the Retry-After header.
    assert mock_sleep.call_count == 1
    assert mock_sleep.call_args[0][0] == 37

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_non_retry_error(self, mock_sleep):
    """An HTTP 422 is raised as ``ServerResponseError`` immediately, without any retry."""
    responses: list[httpx.Response] = [
        httpx.Response(422, json={"detail": "Somehow this is a bad request"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]

    def mock_handle_request(request: httpx.Request) -> httpx.Response:
        # Serve the queued responses in order, one per outgoing request.
        return responses.pop(0)

    client = Client(
        base_url=None,
        dry_run=True,
        token="",
        # Mount key must be a proper "scheme://" pattern (the original had a stray apostrophe).
        mounts={"http://": httpx.MockTransport(mock_handle_request)},
    )

    with pytest.raises(ServerResponseError) as err:
        client.get("http://error")
    # Only the first response was consumed and no back-off sleep happened.
    assert len(responses) == 1
    assert mock_sleep.call_count == 0
    assert err.value.args == ("Somehow this is a bad request",)

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_ok(self, mock_sleep):
    """A successful first response is returned directly, with no retry and no sleep."""
    responses: list[httpx.Response] = [
        httpx.Response(200, json={"detail": "Recovered from error"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]

    def mock_handle_request(request: httpx.Request) -> httpx.Response:
        # Serve the queued responses in order, one per outgoing request.
        return responses.pop(0)

    client = Client(
        base_url=None,
        dry_run=True,
        token="",
        # Mount key must be a proper "scheme://" pattern (the original had a stray apostrophe).
        mounts={"http://": httpx.MockTransport(mock_handle_request)},
    )

    response = client.get("http://error")
    assert response.status_code == 200
    # Only the first response was consumed and no back-off sleep happened.
    assert len(responses) == 1
    assert mock_sleep.call_count == 0


def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
Expand Down

0 comments on commit bc58ae6

Please sign in to comment.