88from typing import Any , Callable , Dict , List , Literal , Optional
99
1010import boto3
11+ from botocore .exceptions import ClientError
1112
1213from bedrock_agentcore .runtime import BedrockAgentCoreContext
1314from bedrock_agentcore .services .identity import IdentityClient , TokenPoller
2021
2122def requires_access_token (
2223 * ,
23- # OAuth parameters (required for M2M and USER_FEDERATION)
24- provider_name : Optional [ str ] = None ,
25- scopes : Optional [ List [str ]] = None ,
24+ provider_name : str ,
25+ into : str = "access_token" ,
26+ scopes : List [str ],
2627 on_auth_url : Optional [Callable [[str ], Any ]] = None ,
28+ auth_flow : Literal ["M2M" , "USER_FEDERATION" ],
2729 callback_url : Optional [str ] = None ,
2830 force_authentication : bool = False ,
2931 token_poller : Optional [TokenPoller ] = None ,
3032 custom_state : Optional [str ] = None ,
3133 custom_parameters : Optional [Dict [str , str ]] = None ,
32- # AWS JWT parameters (required for AWS_JWT)
33- audience : Optional [List [str ]] = None ,
34- signing_algorithm : str = "ES384" ,
35- duration_seconds : int = 300 ,
36- tags : Optional [List [Dict [str , str ]]] = None ,
37- # Common parameters
38- into : str = "access_token" ,
39- auth_flow : Literal ["M2M" , "USER_FEDERATION" , "AWS_JWT" ] = "USER_FEDERATION" ,
4034) -> Callable :
41- """Decorator that fetches an access token before calling the decorated function.
42-
43- Supports three authentication flows:
35+ """Decorator that fetches an OAuth2 access token before calling the decorated function.
4436
45- 1. USER_FEDERATION (OAuth 3LO): User consent required, uses credential provider
46- 2. M2M (OAuth client credentials): Machine-to-machine, uses credential provider
47- 3. AWS_JWT: Direct AWS STS JWT, no secrets required
48-
49- OAuth Parameters (for M2M and USER_FEDERATION):
50- provider_name: The credential provider name (required for OAuth flows)
51- scopes: OAuth2 scopes to request (required for OAuth flows)
52- on_auth_url: Callback for handling authorization URLs (USER_FEDERATION only)
37+ Args:
38+ provider_name: The credential provider name
39+ into: Parameter name to inject the token into
40+ scopes: OAuth2 scopes to request
41+ on_auth_url: Callback for handling authorization URLs
42+ auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
5343 callback_url: OAuth2 callback URL
5444 force_authentication: Force re-authentication
5545 token_poller: Custom token poller implementation
56- custom_state: State for callback verification
57- custom_parameters: Additional OAuth parameters
58-
59- AWS JWT Parameters (for AWS_JWT):
60- audience: List of intended token recipients (required for AWS_JWT)
61- signing_algorithm: 'ES384' (default) or 'RS256'
62- duration_seconds: Token lifetime 60-3600 (default 300)
63- tags: Custom claims as [{'Key': str, 'Value': str}, ...]
64-
65- Common Parameters:
66- into: Parameter name to inject the token into (default: 'access_token')
67- auth_flow: Authentication flow type
46+ custom_state: A state that allows applications to verify the validity of callbacks to callback_url
47+ custom_parameters: A map of custom parameters to include in authorization request to the credential provider
48+ Note: these parameters are in addition to standard OAuth 2.0 flow parameters
6849
6950 Returns:
7051 Decorator function
71-
72- Examples:
73- # OAuth USER_FEDERATION flow
74- @requires_access_token(
75- provider_name="CognitoProvider",
76- scopes=["openid"],
77- auth_flow="USER_FEDERATION",
78- on_auth_url=lambda url: print(f"Please authorize: {url}")
79- )
80- async def call_oauth_api(*, access_token: str):
81- ...
82-
83- # AWS JWT flow (no secrets!)
84- @requires_access_token(
85- auth_flow="AWS_JWT",
86- audience=["https://api.example.com"],
87- signing_algorithm="ES384",
88- )
89- async def call_external_api(*, access_token: str):
90- ...
9152 """
92- # Validate parameters based on flow
93- if auth_flow in ["M2M" , "USER_FEDERATION" ]:
94- if not provider_name :
95- raise ValueError (f"provider_name is required for auth_flow='{ auth_flow } '" )
96- if not scopes :
97- raise ValueError (f"scopes is required for auth_flow='{ auth_flow } '" )
98- elif auth_flow == "AWS_JWT" :
99- if not audience :
100- raise ValueError ("audience is required for auth_flow='AWS_JWT'" )
101- if signing_algorithm not in ["ES384" , "RS256" ]:
102- raise ValueError ("signing_algorithm must be 'ES384' or 'RS256'" )
103- if not (60 <= duration_seconds <= 3600 ):
104- raise ValueError ("duration_seconds must be between 60 and 3600" )
10553
10654 def decorator (func : Callable ) -> Callable :
10755 client = IdentityClient (_get_region ())
10856
109- async def _get_oauth_token () -> str :
110- """Get token via OAuth flow (existing logic) ."""
57+ async def _get_token () -> str :
58+ """Common token fetching logic."""
11159 return await client .get_token (
11260 provider_name = provider_name ,
11361 agent_identity_token = await _get_workload_access_token (client ),
@@ -121,23 +69,6 @@ async def _get_oauth_token() -> str:
12169 custom_parameters = custom_parameters ,
12270 )
12371
124- async def _get_aws_jwt_token () -> str :
125- """Get token via AWS STS (new logic)."""
126- result = client .get_aws_jwt_token_sync (
127- audience = audience ,
128- signing_algorithm = signing_algorithm ,
129- duration_seconds = duration_seconds ,
130- tags = tags ,
131- )
132- return result ["token" ]
133-
134- async def _get_token () -> str :
135- """Route to appropriate token retrieval method."""
136- if auth_flow == "AWS_JWT" :
137- return await _get_aws_jwt_token ()
138- else :
139- return await _get_oauth_token ()
140-
14172 @wraps (func )
14273 async def async_wrapper (* args : Any , ** kwargs_func : Any ) -> Any :
14374 token = await _get_token ()
@@ -170,6 +101,122 @@ def sync_wrapper(*args: Any, **kwargs_func: Any) -> Any:
170101 return decorator
171102
172103
104+ def requires_iam_access_token (
105+ * ,
106+ audience : List [str ],
107+ signing_algorithm : str = "ES384" ,
108+ duration_seconds : int = 300 ,
109+ tags : Optional [List [Dict [str , str ]]] = None ,
110+ into : str = "access_token" ,
111+ ) -> Callable :
112+ """Decorator that fetches an AWS IAM JWT token before calling the decorated function.
113+
114+ This decorator obtains a signed JWT from AWS STS using the GetWebIdentityToken API.
115+ The JWT can be used to authenticate with external services that support OIDC token
116+ validation. No client secrets are required - the token is signed by AWS.
117+
118+ This is separate from @requires_access_token which uses AgentCore Identity for
119+ OAuth 2.0 flows. Use this decorator for M2M authentication with services that
120+ accept AWS-signed JWTs.
121+
122+ Args:
123+ audience: List of intended token recipients (populates 'aud' claim in JWT).
124+ Must match what the external service expects.
125+ signing_algorithm: Algorithm for signing the JWT.
126+ 'ES384' (default) or 'RS256'.
127+ duration_seconds: Token lifetime in seconds (60-3600, default 300).
128+ tags: Optional custom claims as [{'Key': str, 'Value': str}, ...].
129+ These are added to the JWT as additional claims.
130+ into: Parameter name to inject the token into (default: 'access_token').
131+
132+ Returns:
133+ Decorator function that wraps the target function.
134+
135+ Raises:
136+ ValueError: If parameters are invalid.
137+ RuntimeError: If AWS JWT federation is not enabled for the account.
138+ ClientError: If the STS API call fails.
139+
140+ Example:
141+ @tool
142+ @requires_iam_access_token(
143+ audience=["https://api.example.com"],
144+ signing_algorithm="ES384",
145+ duration_seconds=300,
146+ )
147+ def call_external_api(query: str, *, access_token: str) -> str:
148+ '''Call external API with AWS JWT authentication.'''
149+ import requests
150+ response = requests.get(
151+ "https://api.example.com/data",
152+ headers={"Authorization": f"Bearer {access_token}"},
153+ params={"q": query},
154+ )
155+ return response.text
156+
157+ Note:
158+ Before using this decorator, you must:
159+ 1. Enable AWS IAM Outbound Web Identity Federation for your account
160+ (via `agentcore identity setup-aws-jwt` or IAM API)
161+ 2. Ensure the execution role has `sts:GetWebIdentityToken` permission
162+ 3. Configure the external service to trust your AWS account's issuer URL
163+ """
164+ # Validate parameters
165+ if not audience :
166+ raise ValueError ("audience is required" )
167+ if signing_algorithm not in ["ES384" , "RS256" ]:
168+ raise ValueError ("signing_algorithm must be 'ES384' or 'RS256'" )
169+ if not (60 <= duration_seconds <= 3600 ):
170+ raise ValueError ("duration_seconds must be between 60 and 3600" )
171+
172+ logger = logging .getLogger (__name__ )
173+
174+ def _get_iam_jwt_token (region : str ) -> str :
175+ """Get JWT from AWS STS - NO IdentityClient involved."""
176+ logger .info ("Getting AWS IAM JWT token from STS..." )
177+ sts_client = boto3 .client ("sts" , region_name = region )
178+
179+ params = {
180+ "Audience" : audience ,
181+ "SigningAlgorithm" : signing_algorithm ,
182+ "DurationSeconds" : duration_seconds ,
183+ }
184+ if tags :
185+ params ["Tags" ] = tags
186+
187+ try :
188+ response = sts_client .get_web_identity_token (** params )
189+ logger .info ("Successfully obtained AWS IAM JWT token" )
190+ return response ["WebIdentityToken" ]
191+ except ClientError as e :
192+ error_code = e .response .get ("Error" , {}).get ("Code" , "" )
193+ if error_code in ["FeatureDisabledException" , "FeatureDisabled" ]:
194+ raise RuntimeError ("AWS IAM Outbound Web Identity Federation is not enabled." ) from e
195+ logger .error ("Failed to get AWS IAM JWT token: %s" , str (e ))
196+ raise
197+
198+ def decorator (func : Callable ) -> Callable :
199+ @wraps (func )
200+ async def async_wrapper (* args : Any , ** kwargs_func : Any ) -> Any :
201+ region = _get_region ()
202+ token = _get_iam_jwt_token (region )
203+ kwargs_func [into ] = token
204+ return await func (* args , ** kwargs_func )
205+
206+ @wraps (func )
207+ def sync_wrapper (* args : Any , ** kwargs_func : Any ) -> Any :
208+ region = _get_region ()
209+ token = _get_iam_jwt_token (region )
210+ kwargs_func [into ] = token
211+ return func (* args , ** kwargs_func )
212+
213+ if asyncio .iscoroutinefunction (func ):
214+ return async_wrapper
215+ return sync_wrapper
216+
217+ return decorator
218+
219+
173220def requires_api_key (* , provider_name : str , into : str = "api_key" ) -> Callable :
174221 """Decorator that fetches an API key before calling the decorated function.
175222
0 commit comments