diff --git a/waypoint/poetry.lock b/waypoint/poetry.lock index b2901377d..311befa4a 100644 --- a/waypoint/poetry.lock +++ b/waypoint/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -166,13 +166,13 @@ files = [ [[package]] name = "astroid" -version = "3.3.5" +version = "3.3.6" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.9.0" files = [ - {file = "astroid-3.3.5-py3-none-any.whl", hash = "sha256:a9d1c946ada25098d790e079ba2a1b112157278f3fb7e718ae6a9252f5835dc8"}, - {file = "astroid-3.3.5.tar.gz", hash = "sha256:5cfc40ae9f68311075d27ef68a4841bdc5cc7f6cf86671b49f00607d30188e2d"}, + {file = "astroid-3.3.6-py3-none-any.whl", hash = "sha256:db676dc4f3ae6bfe31cda227dc60e03438378d7a896aec57422c95634e8d722f"}, + {file = "astroid-3.3.6.tar.gz", hash = "sha256:6aaea045f938c735ead292204afdb977a36e989522b7833ef6fea94de743f442"}, ] [[package]] @@ -1847,6 +1847,21 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tomlkit" version = "0.13.2" @@ -1909,13 +1924,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "win32-setctime" -version = "1.1.0" +version = "1.2.0" description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" files = [ - {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, - {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, + {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, + {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, ] [package.extras] @@ -2124,4 +2139,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "a5173ffafecb2339d8a05f8b46ae0c72aad33e1bccceb21b5ae6f264d7993a51" +content-hash = "26c68bc77024a0f0363ea6ef1652d0eb95d7b9e1201402b747b8a08277159af6" diff --git a/waypoint/pyproject.toml b/waypoint/pyproject.toml index bd570bbe6..4967543ba 100644 --- a/waypoint/pyproject.toml +++ b/waypoint/pyproject.toml @@ -21,6 +21,7 @@ uvicorn = "~0.32.1" sse-starlette = "~=2.1.3" ddtrace = "^2.17.0" scalar-fastapi = "^1.0.3" +tenacity = "^9.0.0" [tool.poetry.dev-dependencies] anyio = "~4.7.0" diff --git a/waypoint/services/nats_service.py b/waypoint/services/nats_service.py index 2f72ec13c..53e041486 100644 --- a/waypoint/services/nats_service.py +++ b/waypoint/services/nats_service.py @@ -7,6 +7,14 @@ from nats.errors import BadSubscriptionError, Error, TimeoutError from nats.js.api import ConsumerConfig, DeliverPolicy from nats.js.client import JetStreamContext +from nats.js.errors import FetchTimeoutError +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + stop_never, + wait_exponential, +) from shared.constants import NATS_STATE_STREAM, NATS_STATE_SUBJECT from shared.log_config import get_logger @@ -25,48 +33,80 @@ def __init__(self, jetstream: JetStreamContext): self.js_context: JetStreamContext = jetstream async def _subscribe( - self, *, group_id: str, wallet_id: str, topic: str, state: str, look_back: int + self, + *, + group_id: str, + wallet_id: str, + topic: str, + state: str, + start_time: str = None, ) -> JetStreamContext.PullSubscription: - try: - logger.trace( - "Subscribing to JetStream for wallet_id: {}, group_id: {}", - wallet_id, - group_id, - ) - group_id = group_id or "*" - subscribe_kwargs = { - "subject": f"{NATS_STATE_SUBJECT}.{group_id}.{wallet_id}.{topic}.{state}", - "stream": NATS_STATE_STREAM, - } - # Get the current time in UTC - current_time = datetime.now(timezone.utc) + logger.debug( + "Subscribing to JetStream for wallet_id: {}, group_id: {}", + wallet_id, + group_id, + ) - # Subtract look_back time from the current time - look_back_time = current_time - timedelta(seconds=look_back) + group_id = group_id or "*" + subscribe_kwargs = { + "subject": f"{NATS_STATE_SUBJECT}.{group_id}.{wallet_id}.{topic}.{state}", + "stream": NATS_STATE_STREAM, + } - # Format the time in the required format - start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" - config = ConsumerConfig( - deliver_policy=DeliverPolicy.BY_START_TIME, - opt_start_time=start_time, - ) - subscription = await self.js_context.pull_subscribe( - config=config, **subscribe_kwargs - ) + config = ConsumerConfig( + deliver_policy=DeliverPolicy.BY_START_TIME, + opt_start_time=start_time, + ) - return subscription + # This is a custom retry decorator that will retry on TimeoutError + # and wait exponentially up to a max of 16 seconds between retries indefinitely + @retry( + retry=retry_if_exception_type(TimeoutError), + wait=wait_exponential(multiplier=1, max=16), + after=self._retry_log, + stop=stop_never, + ) + async def pull_subscribe(config, **kwargs): + try: + logger.trace( + "Attempting to subscribe to JetStream for wallet_id: {}, group_id: {}", + wallet_id, + group_id, + ) + subscription = await self.js_context.pull_subscribe( + config=config, **kwargs + ) + logger.debug( + "Successfully subscribed to JetStream for wallet_id: {}, group_id: {}", + wallet_id, + group_id, + ) + return subscription + except BadSubscriptionError as e: + logger.error("BadSubscriptionError subscribing to NATS: {}", e) + raise + except Error as e: + logger.error("Error subscribing to NATS: {}", e) + raise - except BadSubscriptionError as e: - logger.error("BadSubscriptionError subscribing to NATS: {}", e) - raise - except Error as e: - logger.error("Error subscribing to NATS: {}", e) - raise + try: + return await pull_subscribe(config, **subscribe_kwargs) except Exception: - logger.exception("Unknown error subscribing to NATS") + logger.exception("An exception occurred subscribing to NATS") raise + def _retry_log(self, retry_state: RetryCallState): + """Custom logging for retry attempts.""" + if retry_state.outcome.failed: + exception = retry_state.outcome.exception() + logger.warning( + "Retry attempt {} failed due to {}: {}", + retry_state.attempt_number, + type(exception).__name__, + exception, + ) + @asynccontextmanager async def process_events( self, @@ -86,44 +126,110 @@ async def process_events( topic, state, ) - - subscription = await self._subscribe( - group_id=group_id, - wallet_id=wallet_id, - topic=topic, - state=state, - look_back=look_back, - ) - - async def event_generator(): - end_time = time.time() + duration - while not stop_event.is_set(): - remaining_time = end_time - time.time() - logger.trace("remaining_time: {}", remaining_time) - if remaining_time <= 0: - logger.debug("Timeout reached") - stop_event.set() - break - - try: - messages = await subscription.fetch(batch=5, timeout=0.2) - for message in messages: - event = orjson.loads(message.data) - yield CloudApiWebhookEventGeneric(**event) - await message.ack() - except TimeoutError: - logger.trace("Timeout fetching messages continuing...") - await asyncio.sleep(0.1) + # Get the current time in UTC + current_time = datetime.now(timezone.utc) + + # Subtract look_back time from the current time + look_back_time = current_time - timedelta(seconds=look_back) + + # Format the time in the required format + start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + + async def event_generator( + *, + subscription: JetStreamContext.PullSubscription, + group_id: str, + wallet_id: str, + topic: str, + state: str, + stop_event: asyncio.Event, + ): + try: + end_time = time.time() + duration + while not stop_event.is_set(): + remaining_time = end_time - time.time() + logger.trace("remaining_time: {}", remaining_time) + if remaining_time <= 0: + logger.debug("Timeout reached") + stop_event.set() + break + + try: + messages = await subscription.fetch( + batch=5, timeout=0.5, heartbeat=0.2 + ) + for message in messages: + event = orjson.loads(message.data) + logger.trace("Received event: {}", event) + yield CloudApiWebhookEventGeneric(**event) + await message.ack() + + except FetchTimeoutError: + # Fetch timeout, continue + logger.trace("Timeout fetching messages continuing...") + await asyncio.sleep(0.1) + + except TimeoutError: + # Timeout error, resubscribe + logger.warning( + "Subscription lost connection, attempting to resubscribe..." + ) + try: + await subscription.unsubscribe() + except BadSubscriptionError as e: + # If we can't unsubscribe, log the error and continue + logger.warning( + "BadSubscriptionError unsubscribing from NATS: {}", e + ) + + subscription = await self._subscribe( + group_id=group_id, + wallet_id=wallet_id, + topic=topic, + state=state, + start_time=start_time, + ) + logger.debug("Successfully resubscribed to NATS.") + + except Exception: # pylint: disable=W0718 + logger.exception("Unexpected error in event generator") + stop_event.set() + raise + + except asyncio.CancelledError: + logger.debug("Event generator cancelled") + stop_event.set() try: - yield event_generator() - except asyncio.CancelledError: - logger.debug("Event generator cancelled") - stop_event.set() + subscription = await self._subscribe( + group_id=group_id, + wallet_id=wallet_id, + topic=topic, + state=state, + start_time=start_time, + ) + yield event_generator( + subscription=subscription, + stop_event=stop_event, + group_id=group_id, + wallet_id=wallet_id, + topic=topic, + state=state, + ) + except Exception as e: # pylint: disable=W0718 + logger.exception("Unexpected error processing events: {}") + raise e + finally: - logger.trace("Closing subscription...") - await subscription.unsubscribe() - logger.debug("Subscription closed") + if subscription: + try: + logger.trace("Closing subscription...") + await subscription.unsubscribe() + logger.debug("Subscription closed") + except BadSubscriptionError as e: + logger.warning( + "BadSubscriptionError unsubscribing from NATS: {}", e + ) async def check_jetstream(self): try: diff --git a/waypoint/tests/services/test_nats_service.py b/waypoint/tests/services/test_nats_service.py index dc9a61805..e7872e6fe 100644 --- a/waypoint/tests/services/test_nats_service.py +++ b/waypoint/tests/services/test_nats_service.py @@ -1,6 +1,6 @@ import asyncio import json -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from nats.aio.client import Client as NATS @@ -8,6 +8,8 @@ from nats.errors import BadSubscriptionError, Error, TimeoutError from nats.js.api import ConsumerConfig, DeliverPolicy from nats.js.client import JetStreamContext +from nats.js.errors import FetchTimeoutError +from tenacity import RetryCallState from shared.constants import NATS_STATE_STREAM, NATS_STATE_SUBJECT from shared.models.webhook_events import CloudApiWebhookEventGeneric @@ -64,7 +66,7 @@ async def test_nats_events_processor_subscribe( wallet_id="wallet_id", topic="proofs", state="done", - look_back=300, + start_time="2024-10-24T09:17:17.998149541Z", ) mock_nats_client.pull_subscribe.assert_called_once_with( subject=f"{NATS_STATE_SUBJECT}.group_id.wallet_id.proofs.done", @@ -88,7 +90,7 @@ async def test_nats_events_processor_subscribe_error( wallet_id="wallet_id", topic="proofs", state="done", - look_back=300, + start_time="2024-10-24T09:17:17.998149541Z", ) @@ -164,14 +166,14 @@ async def test_process_events_cancelled_error( @pytest.mark.anyio -async def test_process_events_timeout_error( +async def test_process_events_fetch_timeout_error( mock_nats_client, # pylint: disable=redefined-outer-name ): processor = NatsEventsProcessor(mock_nats_client) mock_subscription = AsyncMock() mock_nats_client.pull_subscribe.return_value = mock_subscription - mock_subscription.fetch.side_effect = TimeoutError + mock_subscription.fetch.side_effect = FetchTimeoutError stop_event = asyncio.Event() async with processor.process_events( @@ -190,6 +192,127 @@ async def test_process_events_timeout_error( assert stop_event.is_set() +@pytest.mark.anyio +async def test_process_events_timeout_error( + mock_nats_client, +): # pylint: disable=redefined-outer-name + processor = NatsEventsProcessor(mock_nats_client) + mock_subscription = AsyncMock() + mock_nats_client.pull_subscribe.return_value = mock_subscription + + # Mock fetch to raise TimeoutError + mock_subscription.fetch.side_effect = TimeoutError + + # Mock the _subscribe method to simulate resubscribe + mock_resubscribe = AsyncMock(return_value=mock_subscription) + processor._subscribe = mock_resubscribe # pylint: disable=protected-access + + stop_event = asyncio.Event() + + async with processor.process_events( + group_id="group_id", + wallet_id="wallet_id", + topic="test_topic", + state="state", + stop_event=stop_event, + duration=2, + ) as event_generator: + events = [] + async for event in event_generator: + events.append(event) + + # Assert no events are yielded + assert len(events) == 0 + + # Assert fetch was called + assert mock_subscription.fetch.called + + # Assert _subscribe was called again after TimeoutError + assert mock_resubscribe.called + + +@pytest.mark.anyio +async def test_process_events_bad_subscription_error_on_unsubscribe( + mock_nats_client, # pylint: disable=redefined-outer-name +): + processor = NatsEventsProcessor(mock_nats_client) + mock_subscription = AsyncMock() + mock_nats_client.pull_subscribe.return_value = mock_subscription + + # Mock fetch to raise TimeoutError to trigger unsubscribe logic + mock_subscription.fetch.side_effect = TimeoutError + + # Mock unsubscribe to raise BadSubscriptionError + mock_subscription.unsubscribe.side_effect = BadSubscriptionError("Test error") + + # Mock the _subscribe method to simulate resubscribe + mock_resubscribe = AsyncMock(return_value=mock_subscription) + processor._subscribe = mock_resubscribe # pylint: disable=protected-access + + stop_event = asyncio.Event() + + async with processor.process_events( + group_id="group_id", + wallet_id="wallet_id", + topic="test_topic", + state="state", + stop_event=stop_event, + duration=2, + ) as event_generator: + events = [] + async for event in event_generator: + events.append(event) + + # Assert no events are yielded + assert len(events) == 0 + + # Assert fetch was called + assert mock_subscription.fetch.called + + # Assert unsubscribe was called and raised BadSubscriptionError + assert mock_subscription.unsubscribe.called + + # Assert _subscribe was called again after the unsubscribe error + assert mock_resubscribe.called + + +@pytest.mark.anyio +async def test_process_events_base_exception( + mock_nats_client, # pylint: disable=redefined-outer-name +): + processor = NatsEventsProcessor(mock_nats_client) + mock_subscription = AsyncMock() + mock_nats_client.pull_subscribe.return_value = mock_subscription + + # Mock fetch to raise a generic exception + mock_subscription.fetch.side_effect = Exception("Test base exception") + + stop_event = asyncio.Event() + + # Process events + with pytest.raises(Exception): + async with processor.process_events( + group_id="group_id", + wallet_id="wallet_id", + topic="test_topic", + state="state", + stop_event=stop_event, + duration=2, + ) as event_generator: + events = [] + async for event in event_generator: + events.append(event) + + # Assert no events are yielded due to the base exception + assert len(events) == 0 + + # Verify unsubscribe was attempted + mock_subscription.unsubscribe.assert_called_once() + + # Verify fetch was called once before raising the exception + assert mock_subscription.fetch.call_count == 1 + + @pytest.mark.anyio async def test_check_jetstream_working( mock_nats_client, # pylint: disable=redefined-outer-name @@ -235,3 +358,41 @@ async def test_check_jetstream_exception( assert result == {"is_working": False} mock_nats_client.account_info.assert_called_once() + + +class MockFuture: + """A mock class to simulate the behavior of a Future object.""" + + def __init__(self, exception=None): + self._exception = exception + + @property + def failed(self): + return self._exception is not None + + def exception(self): + return self._exception + + +def test_retry_log(mock_nats_client): # pylint: disable=redefined-outer-name + processor = NatsEventsProcessor(mock_nats_client) + # Mock a retry state + mock_retry_state = MagicMock(spec=RetryCallState) + + # Mock the outcome attribute with a Future-like object + mock_retry_state.outcome = MockFuture(exception=ValueError("Test retry exception")) + mock_retry_state.attempt_number = 3 # Retry attempt number + + # Patch the logger to capture log calls + with patch("waypoint.services.nats_service.logger") as mock_logger: + processor._retry_log( # pylint: disable=protected-access + retry_state=mock_retry_state + ) + + # Assert that logger.warning was called with the expected message + mock_logger.warning.assert_called_once_with( + "Retry attempt {} failed due to {}: {}", + 3, + "ValueError", + mock_retry_state.outcome.exception(), + )