Skip to content

Commit 44b55b6

Browse files
author
Sundar Raghavan
committed
fix: address SDE feedback. Add @requires_iam_access_token decorator for AWS STS JWT tokens
- Add new @requires_iam_access_token decorator for AWS IAM JWT federation - Decorator calls STS:GetWebIdentityToken directly (not via AgentCore Identity) - Separate from @requires_access_token to avoid confusion between OAuth and STS flows - Bump boto3 minimum version to support new STS API
1 parent 3771e9c commit 44b55b6

File tree

5 files changed

+305
-684
lines changed

5 files changed

+305
-684
lines changed

src/bedrock_agentcore/identity/auth.py

Lines changed: 133 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Callable, Dict, List, Literal, Optional
99

1010
import boto3
11+
from botocore.exceptions import ClientError
1112

1213
from bedrock_agentcore.runtime import BedrockAgentCoreContext
1314
from bedrock_agentcore.services.identity import IdentityClient, TokenPoller
@@ -20,94 +21,41 @@
2021

2122
def requires_access_token(
2223
*,
23-
# OAuth parameters (required for M2M and USER_FEDERATION)
24-
provider_name: Optional[str] = None,
25-
scopes: Optional[List[str]] = None,
24+
provider_name: str,
25+
into: str = "access_token",
26+
scopes: List[str],
2627
on_auth_url: Optional[Callable[[str], Any]] = None,
28+
auth_flow: Literal["M2M", "USER_FEDERATION"],
2729
callback_url: Optional[str] = None,
2830
force_authentication: bool = False,
2931
token_poller: Optional[TokenPoller] = None,
3032
custom_state: Optional[str] = None,
3133
custom_parameters: Optional[Dict[str, str]] = None,
32-
# AWS JWT parameters (required for AWS_JWT)
33-
audience: Optional[List[str]] = None,
34-
signing_algorithm: str = "ES384",
35-
duration_seconds: int = 300,
36-
tags: Optional[List[Dict[str, str]]] = None,
37-
# Common parameters
38-
into: str = "access_token",
39-
auth_flow: Literal["M2M", "USER_FEDERATION", "AWS_JWT"] = "USER_FEDERATION",
4034
) -> Callable:
41-
"""Decorator that fetches an access token before calling the decorated function.
42-
43-
Supports three authentication flows:
35+
"""Decorator that fetches an OAuth2 access token before calling the decorated function.
4436
45-
1. USER_FEDERATION (OAuth 3LO): User consent required, uses credential provider
46-
2. M2M (OAuth client credentials): Machine-to-machine, uses credential provider
47-
3. AWS_JWT: Direct AWS STS JWT, no secrets required
48-
49-
OAuth Parameters (for M2M and USER_FEDERATION):
50-
provider_name: The credential provider name (required for OAuth flows)
51-
scopes: OAuth2 scopes to request (required for OAuth flows)
52-
on_auth_url: Callback for handling authorization URLs (USER_FEDERATION only)
37+
Args:
38+
provider_name: The credential provider name
39+
into: Parameter name to inject the token into
40+
scopes: OAuth2 scopes to request
41+
on_auth_url: Callback for handling authorization URLs
42+
auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
5343
callback_url: OAuth2 callback URL
5444
force_authentication: Force re-authentication
5545
token_poller: Custom token poller implementation
56-
custom_state: State for callback verification
57-
custom_parameters: Additional OAuth parameters
58-
59-
AWS JWT Parameters (for AWS_JWT):
60-
audience: List of intended token recipients (required for AWS_JWT)
61-
signing_algorithm: 'ES384' (default) or 'RS256'
62-
duration_seconds: Token lifetime 60-3600 (default 300)
63-
tags: Custom claims as [{'Key': str, 'Value': str}, ...]
64-
65-
Common Parameters:
66-
into: Parameter name to inject the token into (default: 'access_token')
67-
auth_flow: Authentication flow type
46+
custom_state: A state that allows applications to verify the validity of callbacks to callback_url
47+
custom_parameters: A map of custom parameters to include in authorization request to the credential provider
48+
Note: these parameters are in addition to standard OAuth 2.0 flow parameters
6849
6950
Returns:
7051
Decorator function
71-
72-
Examples:
73-
# OAuth USER_FEDERATION flow
74-
@requires_access_token(
75-
provider_name="CognitoProvider",
76-
scopes=["openid"],
77-
auth_flow="USER_FEDERATION",
78-
on_auth_url=lambda url: print(f"Please authorize: {url}")
79-
)
80-
async def call_oauth_api(*, access_token: str):
81-
...
82-
83-
# AWS JWT flow (no secrets!)
84-
@requires_access_token(
85-
auth_flow="AWS_JWT",
86-
audience=["https://api.example.com"],
87-
signing_algorithm="ES384",
88-
)
89-
async def call_external_api(*, access_token: str):
90-
...
9152
"""
92-
# Validate parameters based on flow
93-
if auth_flow in ["M2M", "USER_FEDERATION"]:
94-
if not provider_name:
95-
raise ValueError(f"provider_name is required for auth_flow='{auth_flow}'")
96-
if not scopes:
97-
raise ValueError(f"scopes is required for auth_flow='{auth_flow}'")
98-
elif auth_flow == "AWS_JWT":
99-
if not audience:
100-
raise ValueError("audience is required for auth_flow='AWS_JWT'")
101-
if signing_algorithm not in ["ES384", "RS256"]:
102-
raise ValueError("signing_algorithm must be 'ES384' or 'RS256'")
103-
if not (60 <= duration_seconds <= 3600):
104-
raise ValueError("duration_seconds must be between 60 and 3600")
10553

10654
def decorator(func: Callable) -> Callable:
10755
client = IdentityClient(_get_region())
10856

109-
async def _get_oauth_token() -> str:
110-
"""Get token via OAuth flow (existing logic)."""
57+
async def _get_token() -> str:
58+
"""Common token fetching logic."""
11159
return await client.get_token(
11260
provider_name=provider_name,
11361
agent_identity_token=await _get_workload_access_token(client),
@@ -121,23 +69,6 @@ async def _get_oauth_token() -> str:
12169
custom_parameters=custom_parameters,
12270
)
12371

124-
async def _get_aws_jwt_token() -> str:
125-
"""Get token via AWS STS (new logic)."""
126-
result = client.get_aws_jwt_token_sync(
127-
audience=audience,
128-
signing_algorithm=signing_algorithm,
129-
duration_seconds=duration_seconds,
130-
tags=tags,
131-
)
132-
return result["token"]
133-
134-
async def _get_token() -> str:
135-
"""Route to appropriate token retrieval method."""
136-
if auth_flow == "AWS_JWT":
137-
return await _get_aws_jwt_token()
138-
else:
139-
return await _get_oauth_token()
140-
14172
@wraps(func)
14273
async def async_wrapper(*args: Any, **kwargs_func: Any) -> Any:
14374
token = await _get_token()
@@ -170,6 +101,122 @@ def sync_wrapper(*args: Any, **kwargs_func: Any) -> Any:
170101
return decorator
171102

172103

104+
def requires_iam_access_token(
105+
*,
106+
audience: List[str],
107+
signing_algorithm: str = "ES384",
108+
duration_seconds: int = 300,
109+
tags: Optional[List[Dict[str, str]]] = None,
110+
into: str = "access_token",
111+
) -> Callable:
112+
"""Decorator that fetches an AWS IAM JWT token before calling the decorated function.
113+
114+
This decorator obtains a signed JWT from AWS STS using the GetWebIdentityToken API.
115+
The JWT can be used to authenticate with external services that support OIDC token
116+
validation. No client secrets are required - the token is signed by AWS.
117+
118+
This is separate from @requires_access_token which uses AgentCore Identity for
119+
OAuth 2.0 flows. Use this decorator for M2M authentication with services that
120+
accept AWS-signed JWTs.
121+
122+
Args:
123+
audience: List of intended token recipients (populates 'aud' claim in JWT).
124+
Must match what the external service expects.
125+
signing_algorithm: Algorithm for signing the JWT.
126+
'ES384' (default) or 'RS256'.
127+
duration_seconds: Token lifetime in seconds (60-3600, default 300).
128+
tags: Optional custom claims as [{'Key': str, 'Value': str}, ...].
129+
These are added to the JWT as additional claims.
130+
into: Parameter name to inject the token into (default: 'access_token').
131+
132+
Returns:
133+
Decorator function that wraps the target function.
134+
135+
Raises:
136+
ValueError: If parameters are invalid.
137+
RuntimeError: If AWS JWT federation is not enabled for the account.
138+
ClientError: If the STS API call fails.
139+
140+
Example:
141+
@tool
142+
@requires_iam_access_token(
143+
audience=["https://api.example.com"],
144+
signing_algorithm="ES384",
145+
duration_seconds=300,
146+
)
147+
def call_external_api(query: str, *, access_token: str) -> str:
148+
'''Call external API with AWS JWT authentication.'''
149+
import requests
150+
response = requests.get(
151+
"https://api.example.com/data",
152+
headers={"Authorization": f"Bearer {access_token}"},
153+
params={"q": query},
154+
)
155+
return response.text
156+
157+
Note:
158+
Before using this decorator, you must:
159+
1. Enable AWS IAM Outbound Web Identity Federation for your account
160+
(via `agentcore identity setup-aws-jwt` or IAM API)
161+
2. Ensure the execution role has `sts:GetWebIdentityToken` permission
162+
3. Configure the external service to trust your AWS account's issuer URL
163+
"""
164+
# Validate parameters
165+
if not audience:
166+
raise ValueError("audience is required")
167+
if signing_algorithm not in ["ES384", "RS256"]:
168+
raise ValueError("signing_algorithm must be 'ES384' or 'RS256'")
169+
if not (60 <= duration_seconds <= 3600):
170+
raise ValueError("duration_seconds must be between 60 and 3600")
171+
172+
logger = logging.getLogger(__name__)
173+
174+
def _get_iam_jwt_token(region: str) -> str:
175+
"""Get JWT from AWS STS - NO IdentityClient involved."""
176+
logger.info("Getting AWS IAM JWT token from STS...")
177+
sts_client = boto3.client("sts", region_name=region)
178+
179+
params = {
180+
"Audience": audience,
181+
"SigningAlgorithm": signing_algorithm,
182+
"DurationSeconds": duration_seconds,
183+
}
184+
if tags:
185+
params["Tags"] = tags
186+
187+
try:
188+
response = sts_client.get_web_identity_token(**params)
189+
logger.info("Successfully obtained AWS IAM JWT token")
190+
return response["WebIdentityToken"]
191+
except ClientError as e:
192+
error_code = e.response.get("Error", {}).get("Code", "")
193+
if error_code in ["FeatureDisabledException", "FeatureDisabled"]:
194+
raise RuntimeError("AWS IAM Outbound Web Identity Federation is not enabled.") from e
195+
logger.error("Failed to get AWS IAM JWT token: %s", str(e))
196+
raise
197+
198+
def decorator(func: Callable) -> Callable:
199+
@wraps(func)
200+
async def async_wrapper(*args: Any, **kwargs_func: Any) -> Any:
201+
region = _get_region()
202+
token = _get_iam_jwt_token(region)
203+
kwargs_func[into] = token
204+
return await func(*args, **kwargs_func)
205+
206+
@wraps(func)
207+
def sync_wrapper(*args: Any, **kwargs_func: Any) -> Any:
208+
region = _get_region()
209+
token = _get_iam_jwt_token(region)
210+
kwargs_func[into] = token
211+
return func(*args, **kwargs_func)
212+
213+
if asyncio.iscoroutinefunction(func):
214+
return async_wrapper
215+
return sync_wrapper
216+
217+
return decorator
218+
219+
173220
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
174221
"""Decorator that fetches an API key before calling the decorated function.
175222

src/bedrock_agentcore/services/identity.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -249,97 +249,3 @@ async def get_api_key(self, *, provider_name: str, agent_identity_token: str) ->
249249
req = {"resourceCredentialProviderName": provider_name, "workloadIdentityToken": agent_identity_token}
250250

251251
return self.dp_client.get_resource_api_key(**req)["apiKey"]
252-
253-
async def get_aws_jwt_token(
254-
self,
255-
*,
256-
audience: List[str],
257-
signing_algorithm: str = "ES384",
258-
duration_seconds: int = 300,
259-
tags: Optional[List[Dict[str, str]]] = None,
260-
) -> Dict[str, Any]:
261-
"""Get a signed JWT from AWS STS for external service authentication.
262-
263-
This method calls STS:GetWebIdentityToken directly - no AgentCore Identity
264-
service involvement, no secrets required.
265-
266-
Args:
267-
audience: List of intended recipients (populates 'aud' claim)
268-
signing_algorithm: Algorithm for signing ('ES384' or 'RS256')
269-
duration_seconds: Token lifetime (60-3600, default 300)
270-
tags: Optional list of {'Key': str, 'Value': str} for custom claims
271-
272-
Returns:
273-
Dict with 'token' (the JWT string) and 'expiration' (datetime)
274-
275-
Raises:
276-
Exception: If token generation fails
277-
"""
278-
self.logger.info("Getting AWS JWT token...")
279-
280-
# Create regional STS client (GetWebIdentityToken not available on global endpoint)
281-
sts_client = boto3.client("sts", region_name=self.region)
282-
283-
# Build request parameters
284-
params = {
285-
"Audience": audience,
286-
"SigningAlgorithm": signing_algorithm,
287-
"DurationSeconds": duration_seconds,
288-
}
289-
290-
# Add tags if provided
291-
if tags:
292-
params["Tags"] = tags
293-
294-
try:
295-
response = sts_client.get_web_identity_token(**params)
296-
297-
self.logger.info("Successfully obtained AWS JWT token")
298-
299-
return {
300-
"token": response["WebIdentityToken"],
301-
"expiration": response["Expiration"],
302-
}
303-
304-
except Exception as e:
305-
self.logger.error("Failed to get AWS JWT token: %s", str(e))
306-
raise
307-
308-
def get_aws_jwt_token_sync(
309-
self,
310-
*,
311-
audience: List[str],
312-
signing_algorithm: str = "ES384",
313-
duration_seconds: int = 300,
314-
tags: Optional[List[Dict[str, str]]] = None,
315-
) -> Dict[str, Any]:
316-
"""Synchronous version of get_aws_jwt_token.
317-
318-
See get_aws_jwt_token for full documentation.
319-
"""
320-
self.logger.info("Getting AWS JWT token (sync)...")
321-
322-
sts_client = boto3.client("sts", region_name=self.region)
323-
324-
params = {
325-
"Audience": audience,
326-
"SigningAlgorithm": signing_algorithm,
327-
"DurationSeconds": duration_seconds,
328-
}
329-
330-
if tags:
331-
params["Tags"] = tags
332-
333-
try:
334-
response = sts_client.get_web_identity_token(**params)
335-
336-
self.logger.info("Successfully obtained AWS JWT token")
337-
338-
return {
339-
"token": response["WebIdentityToken"],
340-
"expiration": response["Expiration"],
341-
}
342-
343-
except Exception as e:
344-
self.logger.error("Failed to get AWS JWT token: %s", str(e))
345-
raise

0 commit comments

Comments
 (0)