2020
2121def requires_access_token (
2222 * ,
23- provider_name : str ,
24- into : str = "access_token" ,
25- scopes : List [str ],
23+ # OAuth parameters (required for M2M and USER_FEDERATION)
24+ provider_name : Optional [ str ] = None ,
25+ scopes : Optional [ List [str ]] = None ,
2626 on_auth_url : Optional [Callable [[str ], Any ]] = None ,
27- auth_flow : Literal ["M2M" , "USER_FEDERATION" ],
2827 callback_url : Optional [str ] = None ,
2928 force_authentication : bool = False ,
3029 token_poller : Optional [TokenPoller ] = None ,
3130 custom_state : Optional [str ] = None ,
3231 custom_parameters : Optional [Dict [str , str ]] = None ,
32+ # AWS JWT parameters (required for AWS_JWT)
33+ audience : Optional [List [str ]] = None ,
34+ signing_algorithm : str = "ES384" ,
35+ duration_seconds : int = 300 ,
36+ tags : Optional [List [Dict [str , str ]]] = None ,
37+ # Common parameters
38+ into : str = "access_token" ,
39+ auth_flow : Literal ["M2M" , "USER_FEDERATION" , "AWS_JWT" ] = "USER_FEDERATION" ,
3340) -> Callable :
34- """Decorator that fetches an OAuth2 access token before calling the decorated function.
41+ """Decorator that fetches an access token before calling the decorated function.
3542
36- Args:
37- provider_name: The credential provider name
38- into: Parameter name to inject the token into
39- scopes: OAuth2 scopes to request
40- on_auth_url: Callback for handling authorization URLs
41- auth_flow: Authentication flow type ("M2M" or "USER_FEDERATION")
43+ Supports three authentication flows:
44+
45+ 1. USER_FEDERATION (OAuth 3LO): User consent required, uses credential provider
46+ 2. M2M (OAuth client credentials): Machine-to-machine, uses credential provider
47+ 3. AWS_JWT: Direct AWS STS JWT, no secrets required
48+
49+ OAuth Parameters (for M2M and USER_FEDERATION):
50+ provider_name: The credential provider name (required for OAuth flows)
51+ scopes: OAuth2 scopes to request (required for OAuth flows)
52+ on_auth_url: Callback for handling authorization URLs (USER_FEDERATION only)
4253 callback_url: OAuth2 callback URL
4354 force_authentication: Force re-authentication
4455 token_poller: Custom token poller implementation
45- custom_state: A state that allows applications to verify the validity of callbacks to callback_url
46- custom_parameters: A map of custom parameters to include in authorization request to the credential provider
47- Note: these parameters are in addition to standard OAuth 2.0 flow parameters
56+ custom_state: State for callback verification
57+ custom_parameters: Additional OAuth parameters
58+
59+ AWS JWT Parameters (for AWS_JWT):
60+ audience: List of intended token recipients (required for AWS_JWT)
61+ signing_algorithm: 'ES384' (default) or 'RS256'
62+ duration_seconds: Token lifetime 60-3600 (default 300)
63+ tags: Custom claims as [{'Key': str, 'Value': str}, ...]
64+
65+ Common Parameters:
66+ into: Parameter name to inject the token into (default: 'access_token')
67+ auth_flow: Authentication flow type
4868
4969 Returns:
5070 Decorator function
71+
72+ Examples:
73+ # OAuth USER_FEDERATION flow
74+ @requires_access_token(
75+ provider_name="CognitoProvider",
76+ scopes=["openid"],
77+ auth_flow="USER_FEDERATION",
78+ on_auth_url=lambda url: print(f"Please authorize: {url}")
79+ )
80+ async def call_oauth_api(*, access_token: str):
81+ ...
82+
83+ # AWS JWT flow (no secrets!)
84+ @requires_access_token(
85+ auth_flow="AWS_JWT",
86+ audience=["https://api.example.com"],
87+ signing_algorithm="ES384",
88+ )
89+ async def call_external_api(*, access_token: str):
90+ ...
5191 """
92+ # Validate parameters based on flow
93+ if auth_flow in ["M2M" , "USER_FEDERATION" ]:
94+ if not provider_name :
95+ raise ValueError (f"provider_name is required for auth_flow='{ auth_flow } '" )
96+ if not scopes :
97+ raise ValueError (f"scopes is required for auth_flow='{ auth_flow } '" )
98+ elif auth_flow == "AWS_JWT" :
99+ if not audience :
100+ raise ValueError ("audience is required for auth_flow='AWS_JWT'" )
101+ if signing_algorithm not in ["ES384" , "RS256" ]:
102+ raise ValueError ("signing_algorithm must be 'ES384' or 'RS256'" )
103+ if not (60 <= duration_seconds <= 3600 ):
104+ raise ValueError ("duration_seconds must be between 60 and 3600" )
52105
53106 def decorator (func : Callable ) -> Callable :
54107 client = IdentityClient (_get_region ())
55108
56- async def _get_token () -> str :
57- """Common token fetching logic."""
109+ async def _get_oauth_token () -> str :
110+ """Get token via OAuth flow (existing logic) ."""
58111 return await client .get_token (
59112 provider_name = provider_name ,
60113 agent_identity_token = await _get_workload_access_token (client ),
@@ -68,6 +121,23 @@ async def _get_token() -> str:
68121 custom_parameters = custom_parameters ,
69122 )
70123
124+ async def _get_aws_jwt_token () -> str :
125+ """Get token via AWS STS (new logic)."""
126+ result = client .get_aws_jwt_token_sync (
127+ audience = audience ,
128+ signing_algorithm = signing_algorithm ,
129+ duration_seconds = duration_seconds ,
130+ tags = tags ,
131+ )
132+ return result ["token" ]
133+
134+ async def _get_token () -> str :
135+ """Route to appropriate token retrieval method."""
136+ if auth_flow == "AWS_JWT" :
137+ return await _get_aws_jwt_token ()
138+ else :
139+ return await _get_oauth_token ()
140+
71141 @wraps (func )
72142 async def async_wrapper (* args : Any , ** kwargs_func : Any ) -> Any :
73143 token = await _get_token ()
0 commit comments