Skip to content

Commit

Permalink
rework nats init
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ete committed Nov 5, 2024
1 parent 8f82b99 commit dfdb689
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions shared/services/nats_jetstream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, AsyncGenerator

import nats
Expand All @@ -13,29 +14,58 @@

async def init_nats_client() -> AsyncGenerator[JetStreamContext, Any]:
"""
Initialize a connection to the NATS server.
Initialize a connection to the NATS server with automatic reconnection handling.
"""
logger.debug("Initialise NATS server ...")

connect_kwargs = {"servers": [NATS_SERVER]}
connect_kwargs = {
"servers": [NATS_SERVER],
"reconnect_time_wait": 0.5, # Shorter wait time for faster reconnection
"max_reconnect_attempts": -1, # Infinite reconnection attempts
"error_cb": error_callback,
"disconnected_cb": disconnected_callback,
"reconnected_cb": reconnected_callback,
"closed_cb": closed_callback,
}

if NATS_CREDS_FILE:
connect_kwargs["user_credentials"] = NATS_CREDS_FILE
else:
logger.warning("No NATS credentials file found, assuming local development")

logger.info("Connecting to NATS server with kwargs {} ...", connect_kwargs)
try:
nats_client: NATS = await nats.connect(**connect_kwargs)

except (ErrConnectionClosed, ErrTimeout, ErrNoServers) as e:
logger.error("Error connecting to NATS server: {}", e)
raise e
logger.debug("Connected to NATS server")
while True:
try:
nats_client: NATS = await nats.connect(**connect_kwargs)
break
except (ErrConnectionClosed, ErrTimeout, ErrNoServers) as e:
logger.error("Error connecting to NATS server: {}", e)
await asyncio.sleep(1.0) # Wait before retrying
continue
except Exception as e:
logger.error("Unexpected error connecting to NATS server: {}", e)
raise e

logger.debug("Connected to NATS server")
jetstream: JetStreamContext = nats_client.jetstream()
logger.debug("Yielding JetStream context ...")
yield jetstream

logger.debug("Closing NATS connection ...")
await nats_client.close()
logger.debug("NATS connection closed")
try:
yield jetstream
finally:
logger.debug("Closing NATS connection ...")
await nats_client.close()
logger.debug("NATS connection closed")

async def error_callback(e):
logger.error("NATS error: {}", str(e))

async def disconnected_callback():
logger.warning("Disconnected from NATS server")

async def reconnected_callback():
logger.info("Reconnected to NATS server")

async def closed_callback():
logger.warning("NATS connection closed")

0 comments on commit dfdb689

Please sign in to comment.