Skip to content

Commit 9317cfb

Browse files
Abishek10Abishek Kumar
authored andcommitted
feat: add WebSocket support with @app.websocket decorator (#41)
* feat: add WebSocket support with @app.websocket decorator Add bidirectional streaming support for agent invocations: - Add @app.websocket decorator for registering WebSocket handlers - Register /ws endpoint in BedrockAgentCoreApp constructor - Implement _handle_websocket method with proper error handling - Follow existing decorator patterns (entrypoint, ping) for consistency - Support WebSocketDisconnect for graceful connection handling - Integrate with existing RequestContext for session management test: add unit and integration tests for WebSocket support Add comprehensive test coverage for WebSocket decorator: Unit tests (10 tests): - WebSocket route initialization and registration - Decorator functionality and handler storage - Basic send/receive communication - Context integration with session IDs - Exception handling and error cases - Multiple message handling - Graceful disconnect handling - Custom request headers via context - Streaming data functionality Integration test: - End-to-end WebSocket echo server - Streaming multiple messages - Session ID propagation through headers - Real WebSocket client connection testing All tests pass successfully. * fix: remove exception details from WebSocket close reason Don't send exception messages to clients when closing WebSocket connections with code 1011. This prevents leaking internal error details to clients. Changes: - Remove reason parameter from websocket.close() calls - Exception details are still logged server-side - Tests continue to pass as they only check for disconnect exceptions --------- Co-authored-by: Abishek Kumar <[email protected]>
1 parent 95bbfa4 commit 9317cfb

File tree

6 files changed

+452
-3
lines changed

6 files changed

+452
-3
lines changed

.github/workflows/integration-testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ jobs:
130130
- name: Install dependencies
131131
run: |
132132
pip install -e .
133-
pip install --no-cache-dir pytest requests strands-agents uvicorn httpx starlette
133+
pip install --no-cache-dir pytest requests strands-agents uvicorn httpx starlette websockets
134134
135135
- name: Run integration tests
136136
env:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ dev = [
143143
"pytest-asyncio>=0.24.0",
144144
"pytest-cov>=6.0.0",
145145
"ruff>=0.12.0",
146+
"websockets>=14.1",
146147
"wheel>=0.45.1",
147148
"strands-agents>=1.18.0",
148149
]

src/bedrock_agentcore/runtime/app.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from starlette.applications import Starlette
1818
from starlette.middleware import Middleware
1919
from starlette.responses import JSONResponse, Response, StreamingResponse
20-
from starlette.routing import Route
20+
from starlette.routing import Route, WebSocketRoute
2121
from starlette.types import Lifespan
22+
from starlette.websockets import WebSocket, WebSocketDisconnect
2223

2324
from .context import BedrockAgentCoreContext, RequestContext
2425
from .models import (
@@ -90,6 +91,7 @@ def __init__(
9091
"""
9192
self.handlers: Dict[str, Callable] = {}
9293
self._ping_handler: Optional[Callable] = None
94+
self._websocket_handler: Optional[Callable] = None
9395
self._active_tasks: Dict[int, Dict[str, Any]] = {}
9496
self._task_counter_lock: threading.Lock = threading.Lock()
9597
self._forced_ping_status: Optional[PingStatus] = None
@@ -98,6 +100,7 @@ def __init__(
98100
routes = [
99101
Route("/invocations", self._handle_invocation, methods=["POST"]),
100102
Route("/ping", self._handle_ping, methods=["GET"]),
103+
WebSocketRoute("/ws", self._handle_websocket),
101104
]
102105
super().__init__(routes=routes, lifespan=lifespan, middleware=middleware)
103106
self.debug = debug # Set after super().__init__ to avoid override
@@ -135,6 +138,24 @@ def ping(self, func: Callable) -> Callable:
135138
self._ping_handler = func
136139
return func
137140

141+
def websocket(self, func: Callable) -> Callable:
142+
"""Decorator to register a WebSocket handler at /ws endpoint.
143+
144+
Args:
145+
func: The function to register as WebSocket handler
146+
147+
Returns:
148+
The decorated function
149+
150+
Example:
151+
@app.websocket
152+
async def handler(websocket, context):
153+
await websocket.accept()
154+
# ... handle messages ...
155+
"""
156+
self._websocket_handler = func
157+
return func
158+
138159
def async_task(self, func: Callable) -> Callable:
139160
"""Decorator to track async tasks for ping status.
140161
@@ -390,6 +411,29 @@ def _handle_ping(self, request):
390411
self.logger.exception("Ping endpoint failed")
391412
return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())})
392413

414+
async def _handle_websocket(self, websocket: WebSocket):
415+
"""Handle WebSocket connections."""
416+
request_context = self._build_request_context(websocket)
417+
418+
try:
419+
handler = self._websocket_handler
420+
if not handler:
421+
self.logger.error("No WebSocket handler defined")
422+
await websocket.close(code=1011)
423+
return
424+
425+
self.logger.debug("WebSocket connection established")
426+
await handler(websocket, request_context)
427+
428+
except WebSocketDisconnect:
429+
self.logger.debug("WebSocket disconnected")
430+
except Exception:
431+
self.logger.exception("WebSocket handler failed")
432+
try:
433+
await websocket.close(code=1011)
434+
except Exception:
435+
pass
436+
393437
def run(self, port: int = 8080, host: Optional[str] = None, **kwargs):
394438
"""Start the Bedrock AgentCore server.
395439

tests/bedrock_agentcore/runtime/test_app.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1949,3 +1949,202 @@ def __init__(self):
19491949
assert result["X-Amzn-Bedrock-AgentCore-Runtime-Custom-Empty"] == ""
19501950
assert result["X-Amzn-Bedrock-AgentCore-Runtime-Custom-Valid"] == "valid-value"
19511951
assert len(result) == 3
1952+
1953+
1954+
class TestWebSocketSupport:
1955+
"""Test WebSocket decorator and handler functionality."""
1956+
1957+
def test_websocket_initialization(self):
1958+
"""Test that WebSocket route is registered during initialization."""
1959+
app = BedrockAgentCoreApp()
1960+
routes = app.routes
1961+
route_paths = [route.path for route in routes] # type: ignore
1962+
1963+
assert "/ws" in route_paths
1964+
1965+
def test_websocket_decorator(self):
1966+
"""Test @app.websocket decorator registers handler."""
1967+
app = BedrockAgentCoreApp()
1968+
1969+
@app.websocket
1970+
async def test_handler(websocket, context):
1971+
await websocket.accept()
1972+
1973+
assert app._websocket_handler is not None
1974+
assert app._websocket_handler == test_handler
1975+
1976+
def test_websocket_no_handler_defined(self):
1977+
"""Test WebSocket endpoint when no handler is defined."""
1978+
from starlette.websockets import WebSocketDisconnect
1979+
1980+
app = BedrockAgentCoreApp()
1981+
client = TestClient(app)
1982+
1983+
with pytest.raises((WebSocketDisconnect, RuntimeError)):
1984+
with client.websocket_connect("/ws"):
1985+
pass
1986+
1987+
def test_websocket_basic_communication(self):
1988+
"""Test basic WebSocket send/receive."""
1989+
app = BedrockAgentCoreApp()
1990+
1991+
@app.websocket
1992+
async def handler(websocket, context):
1993+
await websocket.accept()
1994+
data = await websocket.receive_json()
1995+
await websocket.send_json({"echo": data})
1996+
await websocket.close()
1997+
1998+
client = TestClient(app)
1999+
2000+
with client.websocket_connect("/ws") as websocket:
2001+
websocket.send_json({"message": "Hello"})
2002+
response = websocket.receive_json()
2003+
assert response == {"echo": {"message": "Hello"}}
2004+
2005+
def test_websocket_with_context(self):
2006+
"""Test WebSocket handler receives context with session ID."""
2007+
app = BedrockAgentCoreApp()
2008+
2009+
received_context = None
2010+
2011+
@app.websocket
2012+
async def handler(websocket, context):
2013+
nonlocal received_context
2014+
received_context = context
2015+
await websocket.accept()
2016+
await websocket.send_json({"session_id": context.session_id})
2017+
await websocket.close()
2018+
2019+
client = TestClient(app)
2020+
2021+
with client.websocket_connect(
2022+
"/ws", headers={"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "ws-session-123"}
2023+
) as websocket:
2024+
response = websocket.receive_json()
2025+
assert response["session_id"] == "ws-session-123"
2026+
assert received_context is not None
2027+
assert received_context.session_id == "ws-session-123"
2028+
2029+
def test_websocket_handler_exception(self):
2030+
"""Test WebSocket handler exceptions are caught and logged."""
2031+
from starlette.websockets import WebSocketDisconnect
2032+
2033+
app = BedrockAgentCoreApp()
2034+
2035+
@app.websocket
2036+
async def handler(websocket, context):
2037+
await websocket.accept()
2038+
raise ValueError("Test WebSocket error")
2039+
2040+
client = TestClient(app)
2041+
2042+
with pytest.raises((WebSocketDisconnect, ValueError, RuntimeError)):
2043+
with client.websocket_connect("/ws") as websocket:
2044+
websocket.receive_json()
2045+
2046+
def test_websocket_multiple_messages(self):
2047+
"""Test WebSocket can handle multiple messages."""
2048+
app = BedrockAgentCoreApp()
2049+
2050+
@app.websocket
2051+
async def handler(websocket, context):
2052+
await websocket.accept()
2053+
for _ in range(3):
2054+
data = await websocket.receive_json()
2055+
await websocket.send_json({"received": data})
2056+
await websocket.close()
2057+
2058+
client = TestClient(app)
2059+
2060+
with client.websocket_connect("/ws") as websocket:
2061+
for i in range(3):
2062+
websocket.send_json({"count": i})
2063+
response = websocket.receive_json()
2064+
assert response == {"received": {"count": i}}
2065+
2066+
def test_websocket_disconnect_handling(self):
2067+
"""Test WebSocket gracefully handles client disconnect."""
2068+
from starlette.websockets import WebSocketDisconnect
2069+
2070+
app = BedrockAgentCoreApp()
2071+
2072+
disconnect_handled = False
2073+
2074+
@app.websocket
2075+
async def handler(websocket, context):
2076+
nonlocal disconnect_handled
2077+
await websocket.accept()
2078+
try:
2079+
while True:
2080+
await websocket.receive_json()
2081+
except WebSocketDisconnect:
2082+
disconnect_handled = True
2083+
raise
2084+
2085+
client = TestClient(app)
2086+
2087+
with client.websocket_connect("/ws") as websocket:
2088+
websocket.send_json({"message": "test"})
2089+
2090+
# Disconnect should be handled gracefully
2091+
assert disconnect_handled
2092+
2093+
def test_websocket_with_request_headers(self):
2094+
"""Test WebSocket handler receives custom request headers via context."""
2095+
app = BedrockAgentCoreApp()
2096+
2097+
received_headers = None
2098+
2099+
@app.websocket
2100+
async def handler(websocket, context):
2101+
nonlocal received_headers
2102+
received_headers = context.request_headers
2103+
await websocket.accept()
2104+
await websocket.send_json({"has_headers": context.request_headers is not None})
2105+
await websocket.close()
2106+
2107+
client = TestClient(app)
2108+
2109+
headers = {
2110+
"Authorization": "Bearer ws-token",
2111+
"X-Amzn-Bedrock-AgentCore-Runtime-Custom-ClientId": "ws-client-123",
2112+
}
2113+
2114+
with client.websocket_connect("/ws", headers=headers) as websocket:
2115+
response = websocket.receive_json()
2116+
assert response["has_headers"] is True
2117+
2118+
assert received_headers is not None
2119+
# Find authorization header (case-insensitive)
2120+
auth_key = next((k for k in received_headers.keys() if k.lower() == "authorization"), None)
2121+
assert auth_key is not None
2122+
assert received_headers[auth_key] == "Bearer ws-token"
2123+
2124+
def test_websocket_streaming_data(self):
2125+
"""Test WebSocket can stream multiple data chunks."""
2126+
app = BedrockAgentCoreApp()
2127+
2128+
@app.websocket
2129+
async def handler(websocket, context):
2130+
await websocket.accept()
2131+
# Stream data
2132+
for i in range(5):
2133+
await websocket.send_json({"chunk": i, "data": f"chunk_{i}"})
2134+
await websocket.send_json({"done": True})
2135+
await websocket.close()
2136+
2137+
client = TestClient(app)
2138+
2139+
with client.websocket_connect("/ws") as websocket:
2140+
chunks = []
2141+
for _ in range(5):
2142+
chunk = websocket.receive_json()
2143+
chunks.append(chunk)
2144+
2145+
final = websocket.receive_json()
2146+
2147+
assert len(chunks) == 5
2148+
assert chunks[0] == {"chunk": 0, "data": "chunk_0"}
2149+
assert chunks[4] == {"chunk": 4, "data": "chunk_4"}
2150+
assert final == {"done": True}

0 commit comments

Comments
 (0)