Skip to content

Commit

Permalink
⚡️Waypoint re-subscribe (#1217)
Browse files Browse the repository at this point in the history
* add retry logic to waypoint subscribe

* catch FetchTimeoutError

* add heartbeat

* raise timeout error

* add trace log when an event found

* 🎨 & update heartbeat value

* add retry

* use retry for backoff

* update test

* add tenacity

* use tenacity for retry

* ⚗️ testing some timeout config

* ⚗️ timeout config

* add resubscribe logic to nats_service if TimeoutError

* remove unused

* # pylint: disable=W0718

* 🎨

* fix lock file

* remove look_back

* remove look_back

* fix tests

* add debug logging

* add exception handling

* 🎨 add some debug handling

* simplify exception handling

* debug log, update exception log

* raise exceptions

* add unit test for nats_service

* import tenacity stuff

* make class function

* add comment use imports

* update logging

* remove start time

* use exception logging

* tweak times

* fix test

* fix test

* 🎨

* 🤡

* fix test

* fix test

* update log line
  • Loading branch information
cl0ete authored Dec 11, 2024
1 parent c62460c commit 458e7a4
Show file tree
Hide file tree
Showing 4 changed files with 364 additions and 81 deletions.
31 changes: 23 additions & 8 deletions waypoint/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions waypoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ uvicorn = "~0.32.1"
sse-starlette = "~=2.1.3"
ddtrace = "^2.17.0"
scalar-fastapi = "^1.0.3"
tenacity = "^9.0.0"

[tool.poetry.dev-dependencies]
anyio = "~4.7.0"
Expand Down
242 changes: 174 additions & 68 deletions waypoint/services/nats_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from nats.errors import BadSubscriptionError, Error, TimeoutError
from nats.js.api import ConsumerConfig, DeliverPolicy
from nats.js.client import JetStreamContext
from nats.js.errors import FetchTimeoutError
from tenacity import (
RetryCallState,
retry,
retry_if_exception_type,
stop_never,
wait_exponential,
)

from shared.constants import NATS_STATE_STREAM, NATS_STATE_SUBJECT
from shared.log_config import get_logger
Expand All @@ -25,48 +33,80 @@ def __init__(self, jetstream: JetStreamContext):
self.js_context: JetStreamContext = jetstream

async def _subscribe(
self, *, group_id: str, wallet_id: str, topic: str, state: str, look_back: int
self,
*,
group_id: str,
wallet_id: str,
topic: str,
state: str,
start_time: str = None,
) -> JetStreamContext.PullSubscription:
try:
logger.trace(
"Subscribing to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)
group_id = group_id or "*"
subscribe_kwargs = {
"subject": f"{NATS_STATE_SUBJECT}.{group_id}.{wallet_id}.{topic}.{state}",
"stream": NATS_STATE_STREAM,
}

# Get the current time in UTC
current_time = datetime.now(timezone.utc)
logger.debug(
"Subscribing to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)

# Subtract look_back time from the current time
look_back_time = current_time - timedelta(seconds=look_back)
group_id = group_id or "*"
subscribe_kwargs = {
"subject": f"{NATS_STATE_SUBJECT}.{group_id}.{wallet_id}.{topic}.{state}",
"stream": NATS_STATE_STREAM,
}

# Format the time in the required format
start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
config = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time=start_time,
)
subscription = await self.js_context.pull_subscribe(
config=config, **subscribe_kwargs
)
config = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time=start_time,
)

return subscription
# Retry indefinitely on TimeoutError, backing off exponentially
# up to a maximum of 16 seconds between attempts
@retry(
retry=retry_if_exception_type(TimeoutError),
wait=wait_exponential(multiplier=1, max=16),
after=self._retry_log,
stop=stop_never,
)
async def pull_subscribe(config, **kwargs):
try:
logger.trace(
"Attempting to subscribe to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)
subscription = await self.js_context.pull_subscribe(
config=config, **kwargs
)
logger.debug(
"Successfully subscribed to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)
return subscription
except BadSubscriptionError as e:
logger.error("BadSubscriptionError subscribing to NATS: {}", e)
raise
except Error as e:
logger.error("Error subscribing to NATS: {}", e)
raise

except BadSubscriptionError as e:
logger.error("BadSubscriptionError subscribing to NATS: {}", e)
raise
except Error as e:
logger.error("Error subscribing to NATS: {}", e)
raise
try:
return await pull_subscribe(config, **subscribe_kwargs)
except Exception:
logger.exception("Unknown error subscribing to NATS")
logger.exception("An exception occurred subscribing to NATS")
raise

def _retry_log(self, retry_state: RetryCallState):
    """Emit a warning describing a failed retry attempt.

    Passed as the ``after`` hook of the tenacity ``retry`` decorator used
    when subscribing. Nothing is logged when the attempt's outcome is not
    marked as failed.

    Args:
        retry_state: The tenacity call state for the attempt that just ran;
            its ``outcome`` and ``attempt_number`` are read.
    """
    outcome = retry_state.outcome
    if not outcome.failed:
        # Successful attempt: nothing to report.
        return

    error = outcome.exception()
    logger.warning(
        "Retry attempt {} failed due to {}: {}",
        retry_state.attempt_number,
        type(error).__name__,
        error,
    )

@asynccontextmanager
async def process_events(
self,
Expand All @@ -86,44 +126,110 @@ async def process_events(
topic,
state,
)

subscription = await self._subscribe(
group_id=group_id,
wallet_id=wallet_id,
topic=topic,
state=state,
look_back=look_back,
)

async def event_generator():
end_time = time.time() + duration
while not stop_event.is_set():
remaining_time = end_time - time.time()
logger.trace("remaining_time: {}", remaining_time)
if remaining_time <= 0:
logger.debug("Timeout reached")
stop_event.set()
break

try:
messages = await subscription.fetch(batch=5, timeout=0.2)
for message in messages:
event = orjson.loads(message.data)
yield CloudApiWebhookEventGeneric(**event)
await message.ack()
except TimeoutError:
logger.trace("Timeout fetching messages continuing...")
await asyncio.sleep(0.1)
# Get the current time in UTC
current_time = datetime.now(timezone.utc)

# Subtract look_back time from the current time
look_back_time = current_time - timedelta(seconds=look_back)

# Format the time in the required format
start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"

async def event_generator(
*,
subscription: JetStreamContext.PullSubscription,
group_id: str,
wallet_id: str,
topic: str,
state: str,
stop_event: asyncio.Event,
):
try:
end_time = time.time() + duration
while not stop_event.is_set():
remaining_time = end_time - time.time()
logger.trace("remaining_time: {}", remaining_time)
if remaining_time <= 0:
logger.debug("Timeout reached")
stop_event.set()
break

try:
messages = await subscription.fetch(
batch=5, timeout=0.5, heartbeat=0.2
)
for message in messages:
event = orjson.loads(message.data)
logger.trace("Received event: {}", event)
yield CloudApiWebhookEventGeneric(**event)
await message.ack()

except FetchTimeoutError:
# Fetch timeout, continue
logger.trace("Timeout fetching messages continuing...")
await asyncio.sleep(0.1)

except TimeoutError:
# Timeout error, resubscribe
logger.warning(
"Subscription lost connection, attempting to resubscribe..."
)
try:
await subscription.unsubscribe()
except BadSubscriptionError as e:
# If we can't unsubscribe, log the error and continue
logger.warning(
"BadSubscriptionError unsubscribing from NATS: {}", e
)

subscription = await self._subscribe(
group_id=group_id,
wallet_id=wallet_id,
topic=topic,
state=state,
start_time=start_time,
)
logger.debug("Successfully resubscribed to NATS.")

except Exception: # pylint: disable=W0718
logger.exception("Unexpected error in event generator")
stop_event.set()
raise

except asyncio.CancelledError:
logger.debug("Event generator cancelled")
stop_event.set()

try:
yield event_generator()
except asyncio.CancelledError:
logger.debug("Event generator cancelled")
stop_event.set()
subscription = await self._subscribe(
group_id=group_id,
wallet_id=wallet_id,
topic=topic,
state=state,
start_time=start_time,
)
yield event_generator(
subscription=subscription,
stop_event=stop_event,
group_id=group_id,
wallet_id=wallet_id,
topic=topic,
state=state,
)
except Exception as e: # pylint: disable=W0718
logger.exception("Unexpected error processing events: {}")
raise e

finally:
logger.trace("Closing subscription...")
await subscription.unsubscribe()
logger.debug("Subscription closed")
if subscription:
try:
logger.trace("Closing subscription...")
await subscription.unsubscribe()
logger.debug("Subscription closed")
except BadSubscriptionError as e:
logger.warning(
"BadSubscriptionError unsubscribing from NATS: {}", e
)

async def check_jetstream(self):
try:
Expand Down
Loading

0 comments on commit 458e7a4

Please sign in to comment.