Skip to content

Commit

Permalink
⚡ Waypoint start time (#1137)
Browse files Browse the repository at this point in the history
* imports

* add look back param

* format date time string

* add consumer config with start time

* add config to subscription

* convert to str

* update test

* 🎨

* add look_back query param

* add look back query param

* remove env var

* refactor code

* update tests

* update default value of look_back

* add look_back to wait for state

* 🎨

* remove import
  • Loading branch information
cl0ete authored Oct 25, 2024
1 parent ca5433d commit 87557d6
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 29 deletions.
6 changes: 6 additions & 0 deletions app/routes/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ async def get_sse_subscribe_event_with_field_and_state(
field_id: str,
desired_state: str,
group_id: Optional[str] = group_id_query,
look_back: Optional[int] = Query(
default=300, description="Number of seconds to look back for events"
),
auth: AcaPyAuthVerified = Depends(acapy_auth_verified),
) -> StreamingResponse:
"""
Expand Down Expand Up @@ -63,6 +66,8 @@ async def get_sse_subscribe_event_with_field_and_state(
The ID of the field subscribing to the events.
desired_state:
The desired state to be reached.
look_back:
Number of seconds to look back for events before subscribing.
"""
logger.bind(
body={
Expand All @@ -87,6 +92,7 @@ async def get_sse_subscribe_event_with_field_and_state(
field=field,
field_id=field_id,
desired_state=desired_state,
look_back=look_back,
),
media_type="text/event-stream",
)
5 changes: 4 additions & 1 deletion app/services/event_handling/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ async def sse_subscribe_event_with_field_and_state(
field: str,
field_id: str,
desired_state: str,
look_back: int = 300,
) -> AsyncGenerator[str, None]:
"""
Subscribe to server-side events for a specific wallet ID and topic.
Expand All @@ -56,8 +57,10 @@ async def sse_subscribe_event_with_field_and_state(
)

params = {}
if group_id: # Optional param
if group_id: # Optional params
params["group_id"] = group_id
if look_back:
params["look_back"] = look_back

try:
async with RichAsyncClient(timeout=event_timeout) as client:
Expand Down
2 changes: 2 additions & 0 deletions app/tests/routes/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def test_get_sse_subscribe_event_with_field_and_state(
desired_state=state,
group_id=group_id,
auth=mock_auth,
look_back=300,
)

assert response.media_type == "text/event-stream"
Expand All @@ -67,4 +68,5 @@ async def test_get_sse_subscribe_event_with_field_and_state(
field=field,
field_id=field_id,
desired_state=state,
look_back=300,
)
2 changes: 1 addition & 1 deletion app/tests/services/event_handling/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ async def test_sse_subscribe_event_with_field_and_state_success(
patch_yield_lines_with_disconnect_check, # pylint: disable=redefined-outer-name
group_id: Optional[str],
):
expected_params = {}
expected_params = {"look_back": 300}
if group_id: # Optional param
expected_params["group_id"] = group_id

Expand Down
2 changes: 1 addition & 1 deletion app/tests/util/sse_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def wait_for_event(
"""
Start listening for SSE events. When an event is received that matches the specified parameters.
"""
url = f"{waypoint_base_url}/{self.wallet_id}/{self.topic}/{field}/{field_id}/{desired_state}"
url = f"{waypoint_base_url}/{self.wallet_id}/{self.topic}/{field}/{field_id}/{desired_state}?look_back=5"

timeout = Timeout(timeout)
async with RichAsyncClient(timeout=timeout) as client:
Expand Down
7 changes: 7 additions & 0 deletions waypoint/routers/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def nats_event_stream_generator(
field_id: str,
desired_state: str,
group_id: Optional[str],
look_back: Optional[int],
nats_processor: NatsEventsProcessor,
) -> AsyncGenerator[str, None]:
"""
Expand All @@ -53,6 +54,7 @@ async def nats_event_stream_generator(
topic=topic,
stop_event=stop_event,
duration=SSE_TIMEOUT,
look_back=look_back,
) as event_generator:
background_tasks.add_task(check_disconnect, request, stop_event)

Expand Down Expand Up @@ -94,6 +96,10 @@ async def sse_wait_for_event_with_field_and_state(
group_id: Optional[str] = Query(
default=None, description="Group ID to which the wallet belongs"
),
look_back: Optional[int] = Query(
default=300,
description="Number of seconds to look back for events before subscribing",
),
nats_processor: NatsEventsProcessor = Depends(
Provide[Container.nats_events_processor]
),
Expand Down Expand Up @@ -121,6 +127,7 @@ async def sse_wait_for_event_with_field_and_state(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=look_back,
nats_processor=nats_processor,
)

Expand Down
50 changes: 33 additions & 17 deletions waypoint/services/nats_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone

import orjson
from nats.errors import BadSubscriptionError, Error, TimeoutError
from nats.js.api import ConsumerConfig, DeliverPolicy
from nats.js.client import JetStreamContext

from shared.constants import NATS_STREAM, NATS_SUBJECT
Expand All @@ -23,24 +25,35 @@ def __init__(self, jetstream: JetStreamContext):
self.js_context: JetStreamContext = jetstream

async def _subscribe(
self, group_id: str, wallet_id: str
self, group_id: str, wallet_id: str, look_back: int
) -> JetStreamContext.PullSubscription:
try:
logger.debug("Subscribing to JetStream...")
if group_id:

logger.trace("Tenant-admin call got group_id: {}", group_id)
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.{group_id}.{wallet_id}",
"stream": NATS_STREAM,
}
else:
logger.trace("Tenant call got no group_id")
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.*.{wallet_id}",
"stream": NATS_STREAM,
}
subscription = await self.js_context.pull_subscribe(**subscribe_kwargs)
logger.trace(
"Subscribing to JetStream for wallet_id: {}, group_id: {}",
wallet_id,
group_id,
)
group_id = group_id or "*"
subscribe_kwargs = {
"subject": f"{NATS_SUBJECT}.{group_id}.{wallet_id}",
"stream": NATS_STREAM,
}

# Get the current time in UTC
current_time = datetime.now(timezone.utc)

# Subtract 30 seconds
look_back_time = current_time - timedelta(seconds=look_back)

# Format the time in the required format
start_time = look_back_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
config = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time=start_time,
)
subscription = await self.js_context.pull_subscribe(
config=config, **subscribe_kwargs
)

return subscription

Expand All @@ -62,6 +75,7 @@ async def process_events(
topic: str,
stop_event: asyncio.Event,
duration: int = 10,
look_back: int = 300,
):
logger.debug(
"Processing events for group {} and wallet {} on topic {}",
Expand All @@ -70,7 +84,9 @@ async def process_events(
topic,
)

subscription = await self._subscribe(group_id=group_id, wallet_id=wallet_id)
subscription = await self._subscribe(
group_id=group_id, wallet_id=wallet_id, look_back=look_back
)

async def event_generator():
end_time = time.time() + duration
Expand Down
5 changes: 5 additions & 0 deletions waypoint/tests/routers/test_waypoint_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ async def mock_event_generator():
desired_state=desired_state,
group_id=group_id,
nats_processor=nats_processor_mock,
look_back=300,
):
events.append(event)

Expand Down Expand Up @@ -131,6 +132,7 @@ async def mock_event_generator():
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
):
pass
Expand Down Expand Up @@ -163,6 +165,7 @@ async def mock_event_generator():
field_id="some_field_id",
desired_state="some_state",
group_id="some_group",
look_back=300,
nats_processor=nats_processor_mock,
)

Expand Down Expand Up @@ -196,6 +199,7 @@ async def test_sse_event_stream(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
)

Expand All @@ -211,5 +215,6 @@ async def test_sse_event_stream(
field_id=field_id,
desired_state=desired_state,
group_id=group_id,
look_back=300,
nats_processor=nats_processor_mock,
)
25 changes: 16 additions & 9 deletions waypoint/tests/services/test_nats_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nats.aio.client import Client as NATS
from nats.aio.errors import ErrConnectionClosed, ErrNoServers, ErrTimeout
from nats.errors import BadSubscriptionError, Error, TimeoutError
from nats.js.api import ConsumerConfig, DeliverPolicy
from nats.js.client import JetStreamContext

from shared.constants import NATS_STREAM, NATS_SUBJECT
Expand Down Expand Up @@ -53,14 +54,20 @@ async def test_nats_events_processor_subscribe(
mock_nats_client.pull_subscribe.return_value = AsyncMock(
spec=JetStreamContext.PullSubscription
)

subscription = await processor._subscribe( # pylint: disable=protected-access
"group_id", "wallet_id"
)
mock_nats_client.pull_subscribe.assert_called_once_with(
subject=f"{NATS_SUBJECT}.group_id.wallet_id", stream=NATS_STREAM
)
assert isinstance(subscription, JetStreamContext.PullSubscription)
with patch("waypoint.services.nats_service.ConsumerConfig") as mock_config:
mock_config.return_value = ConsumerConfig(
deliver_policy=DeliverPolicy.BY_START_TIME,
opt_start_time="2024-10-24T09:17:17.998149541Z",
)
subscription = await processor._subscribe( # pylint: disable=protected-access
"group_id", "wallet_id", 300
)
mock_nats_client.pull_subscribe.assert_called_once_with(
subject=f"{NATS_SUBJECT}.group_id.wallet_id",
stream=NATS_STREAM,
config=mock_config.return_value,
)
assert isinstance(subscription, JetStreamContext.PullSubscription)


@pytest.mark.anyio
Expand All @@ -72,7 +79,7 @@ async def test_nats_events_processor_subscribe_error(
mock_nats_client.pull_subscribe.side_effect = exception

with pytest.raises(exception):
await processor._subscribe("group_id", "wallet_id")
await processor._subscribe("group_id", "wallet_id", 300)


@pytest.mark.anyio
Expand Down

0 comments on commit 87557d6

Please sign in to comment.