Skip to content

Commit 3771e9c

Browse files
author
Sundar Raghavan
committed
feat(identity): Add AWS JWT support to @requires_access_token decorator
- Extend auth_flow parameter to accept AWS_JWT in addition to M2M/USER_FEDERATION - Add get_aws_jwt_token() and get_aws_jwt_token_sync() methods to IdentityClient
1 parent 18a78b9 commit 3771e9c

File tree

5 files changed

+892
-20
lines changed

5 files changed

+892
-20
lines changed

src/bedrock_agentcore/identity/auth.py

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,94 @@
2020

2121
def requires_access_token(
2222
*,
23-
provider_name: str,
24-
into: str = "access_token",
25-
scopes: List[str],
23+
# OAuth parameters (required for M2M and USER_FEDERATION)
24+
provider_name: Optional[str] = None,
25+
scopes: Optional[List[str]] = None,
2626
on_auth_url: Optional[Callable[[str], Any]] = None,
27-
auth_flow: Literal["M2M", "USER_FEDERATION"],
2827
callback_url: Optional[str] = None,
2928
force_authentication: bool = False,
3029
token_poller: Optional[TokenPoller] = None,
3130
custom_state: Optional[str] = None,
3231
custom_parameters: Optional[Dict[str, str]] = None,
32+
# AWS JWT parameters (required for AWS_JWT)
33+
audience: Optional[List[str]] = None,
34+
signing_algorithm: str = "ES384",
35+
duration_seconds: int = 300,
36+
tags: Optional[List[Dict[str, str]]] = None,
37+
# Common parameters
38+
into: str = "access_token",
39+
auth_flow: Literal["M2M", "USER_FEDERATION", "AWS_JWT"] = "USER_FEDERATION",
3340
) -> Callable:
34-
"""Decorator that fetches an OAuth2 access token before calling the decorated function.
41+
"""Decorator that fetches an access token before calling the decorated function.
3542
36-
Args:
37-
provider_name: The credential provider name
38-
into: Parameter name to inject the token into
39-
scopes: OAuth2 scopes to request
40-
on_auth_url: Callback for handling authorization URLs
41-
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
43+
Supports three authentication flows:
44+
45+
1. USER_FEDERATION (OAuth 3LO): User consent required, uses credential provider
46+
2. M2M (OAuth client credentials): Machine-to-machine, uses credential provider
47+
3. AWS_JWT: Direct AWS STS JWT, no secrets required
48+
49+
OAuth Parameters (for M2M and USER_FEDERATION):
50+
provider_name: The credential provider name (required for OAuth flows)
51+
scopes: OAuth2 scopes to request (required for OAuth flows)
52+
on_auth_url: Callback for handling authorization URLs (USER_FEDERATION only)
4253
callback_url: OAuth2 callback URL
4354
force_authentication: Force re-authentication
4455
token_poller: Custom token poller implementation
45-
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
46-
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
47-
Note: these parameters are in addition to standard OAuth 2.0 flow parameters
56+
custom_state: State for callback verification
57+
custom_parameters: Additional OAuth parameters
58+
59+
AWS JWT Parameters (for AWS_JWT):
60+
audience: List of intended token recipients (required for AWS_JWT)
61+
signing_algorithm: 'ES384' (default) or 'RS256'
62+
duration_seconds: Token lifetime 60-3600 (default 300)
63+
tags: Custom claims as [{'Key': str, 'Value': str}, ...]
64+
65+
Common Parameters:
66+
into: Parameter name to inject the token into (default: 'access_token')
67+
auth_flow: Authentication flow type
4868
4969
Returns:
5070
Decorator function
71+
72+
Examples:
73+
# OAuth USER_FEDERATION flow
74+
@requires_access_token(
75+
provider_name="CognitoProvider",
76+
scopes=["openid"],
77+
auth_flow="USER_FEDERATION",
78+
on_auth_url=lambda url: print(f"Please authorize: {url}")
79+
)
80+
async def call_oauth_api(*, access_token: str):
81+
...
82+
83+
# AWS JWT flow (no secrets!)
84+
@requires_access_token(
85+
auth_flow="AWS_JWT",
86+
audience=["https://api.example.com"],
87+
signing_algorithm="ES384",
88+
)
89+
async def call_external_api(*, access_token: str):
90+
...
5191
"""
92+
# Validate parameters based on flow
93+
if auth_flow in ["M2M", "USER_FEDERATION"]:
94+
if not provider_name:
95+
raise ValueError(f"provider_name is required for auth_flow='{auth_flow}'")
96+
if not scopes:
97+
raise ValueError(f"scopes is required for auth_flow='{auth_flow}'")
98+
elif auth_flow == "AWS_JWT":
99+
if not audience:
100+
raise ValueError("audience is required for auth_flow='AWS_JWT'")
101+
if signing_algorithm not in ["ES384", "RS256"]:
102+
raise ValueError("signing_algorithm must be 'ES384' or 'RS256'")
103+
if not (60 <= duration_seconds <= 3600):
104+
raise ValueError("duration_seconds must be between 60 and 3600")
52105

53106
def decorator(func: Callable) -> Callable:
54107
client = IdentityClient(_get_region())
55108

56-
async def _get_token() -> str:
57-
"""Common token fetching logic."""
109+
async def _get_oauth_token() -> str:
110+
"""Get token via OAuth flow (existing logic)."""
58111
return await client.get_token(
59112
provider_name=provider_name,
60113
agent_identity_token=await _get_workload_access_token(client),
@@ -68,6 +121,23 @@ async def _get_token() -> str:
68121
custom_parameters=custom_parameters,
69122
)
70123

124+
async def _get_aws_jwt_token() -> str:
125+
"""Get token via AWS STS (new logic)."""
126+
result = client.get_aws_jwt_token_sync(
127+
audience=audience,
128+
signing_algorithm=signing_algorithm,
129+
duration_seconds=duration_seconds,
130+
tags=tags,
131+
)
132+
return result["token"]
133+
134+
async def _get_token() -> str:
135+
"""Route to appropriate token retrieval method."""
136+
if auth_flow == "AWS_JWT":
137+
return await _get_aws_jwt_token()
138+
else:
139+
return await _get_oauth_token()
140+
71141
@wraps(func)
72142
async def async_wrapper(*args: Any, **kwargs_func: Any) -> Any:
73143
token = await _get_token()

src/bedrock_agentcore/services/identity.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,97 @@ async def get_api_key(self, *, provider_name: str, agent_identity_token: str) ->
249249
req = {"resourceCredentialProviderName": provider_name, "workloadIdentityToken": agent_identity_token}
250250

251251
return self.dp_client.get_resource_api_key(**req)["apiKey"]
252+
253+
async def get_aws_jwt_token(
254+
self,
255+
*,
256+
audience: List[str],
257+
signing_algorithm: str = "ES384",
258+
duration_seconds: int = 300,
259+
tags: Optional[List[Dict[str, str]]] = None,
260+
) -> Dict[str, Any]:
261+
"""Get a signed JWT from AWS STS for external service authentication.
262+
263+
This method calls STS:GetWebIdentityToken directly - no AgentCore Identity
264+
service involvement, no secrets required.
265+
266+
Args:
267+
audience: List of intended recipients (populates 'aud' claim)
268+
signing_algorithm: Algorithm for signing ('ES384' or 'RS256')
269+
duration_seconds: Token lifetime (60-3600, default 300)
270+
tags: Optional list of {'Key': str, 'Value': str} for custom claims
271+
272+
Returns:
273+
Dict with 'token' (the JWT string) and 'expiration' (datetime)
274+
275+
Raises:
276+
Exception: If token generation fails
277+
"""
278+
self.logger.info("Getting AWS JWT token...")
279+
280+
# Create regional STS client (GetWebIdentityToken not available on global endpoint)
281+
sts_client = boto3.client("sts", region_name=self.region)
282+
283+
# Build request parameters
284+
params = {
285+
"Audience": audience,
286+
"SigningAlgorithm": signing_algorithm,
287+
"DurationSeconds": duration_seconds,
288+
}
289+
290+
# Add tags if provided
291+
if tags:
292+
params["Tags"] = tags
293+
294+
try:
295+
response = sts_client.get_web_identity_token(**params)
296+
297+
self.logger.info("Successfully obtained AWS JWT token")
298+
299+
return {
300+
"token": response["WebIdentityToken"],
301+
"expiration": response["Expiration"],
302+
}
303+
304+
except Exception as e:
305+
self.logger.error("Failed to get AWS JWT token: %s", str(e))
306+
raise
307+
308+
def get_aws_jwt_token_sync(
309+
self,
310+
*,
311+
audience: List[str],
312+
signing_algorithm: str = "ES384",
313+
duration_seconds: int = 300,
314+
tags: Optional[List[Dict[str, str]]] = None,
315+
) -> Dict[str, Any]:
316+
"""Synchronous version of get_aws_jwt_token.
317+
318+
See get_aws_jwt_token for full documentation.
319+
"""
320+
self.logger.info("Getting AWS JWT token (sync)...")
321+
322+
sts_client = boto3.client("sts", region_name=self.region)
323+
324+
params = {
325+
"Audience": audience,
326+
"SigningAlgorithm": signing_algorithm,
327+
"DurationSeconds": duration_seconds,
328+
}
329+
330+
if tags:
331+
params["Tags"] = tags
332+
333+
try:
334+
response = sts_client.get_web_identity_token(**params)
335+
336+
self.logger.info("Successfully obtained AWS JWT token")
337+
338+
return {
339+
"token": response["WebIdentityToken"],
340+
"expiration": response["Expiration"],
341+
}
342+
343+
except Exception as e:
344+
self.logger.error("Failed to get AWS JWT token: %s", str(e))
345+
raise

0 commit comments

Comments
 (0)