1717
1818use crate :: { Credential , constants:: * } ;
1919use async_trait:: async_trait;
20+ use form_urlencoded:: Serializer ;
2021use reqsign_core:: Result ;
2122use reqsign_core:: time:: Timestamp ;
2223use reqsign_core:: { Context , ProvideCredential } ;
@@ -26,12 +27,14 @@ use serde::Deserialize;
2627///
2728/// This provider reads configuration from environment variables at runtime:
2829/// - `ALIBABA_CLOUD_ROLE_ARN`: The ARN of the role to assume
30+ /// - `ALIBABA_CLOUD_ROLE_SESSION_NAME`: Optional role session name
2931/// - `ALIBABA_CLOUD_OIDC_PROVIDER_ARN`: The ARN of the OIDC provider
3032/// - `ALIBABA_CLOUD_OIDC_TOKEN_FILE`: Path to the OIDC token file
3133/// - `ALIBABA_CLOUD_STS_ENDPOINT`: Optional custom STS endpoint
3234#[ derive( Debug , Default , Clone ) ]
3335pub struct AssumeRoleWithOidcCredentialProvider {
3436 sts_endpoint : Option < String > ,
37+ role_session_name : Option < String > ,
3538}
3639
3740impl AssumeRoleWithOidcCredentialProvider {
@@ -47,6 +50,14 @@ impl AssumeRoleWithOidcCredentialProvider {
4750 self
4851 }
4952
53+ /// Set the role session name.
54+ ///
55+ /// This setting takes precedence over `ALIBABA_CLOUD_ROLE_SESSION_NAME`.
56+ pub fn with_role_session_name ( mut self , name : impl Into < String > ) -> Self {
57+ self . role_session_name = Some ( name. into ( ) ) ;
58+ self
59+ }
60+
5061 fn get_sts_endpoint ( & self , envs : & std:: collections:: HashMap < String , String > ) -> String {
5162 if let Some ( endpoint) = & self . sts_endpoint {
5263 return endpoint. clone ( ) ;
@@ -57,6 +68,16 @@ impl AssumeRoleWithOidcCredentialProvider {
5768 None => "https://sts.aliyuncs.com" . to_string ( ) ,
5869 }
5970 }
71+
72+ fn get_role_session_name ( & self , envs : & std:: collections:: HashMap < String , String > ) -> String {
73+ if let Some ( name) = & self . role_session_name {
74+ return name. clone ( ) ;
75+ }
76+
77+ envs. get ( ALIBABA_CLOUD_ROLE_SESSION_NAME )
78+ . cloned ( )
79+ . unwrap_or_else ( || "reqsign" . to_string ( ) )
80+ }
6081}
6182
6283#[ async_trait]
@@ -76,20 +97,22 @@ impl ProvideCredential for AssumeRoleWithOidcCredentialProvider {
7697 _ => return Ok ( None ) ,
7798 } ;
7899
79- let token = ctx. file_read ( token_file) . await ?;
80- let token = String :: from_utf8 ( token) ? ;
81- let role_session_name = "reqsign" ; // Default session name
100+ let token = ctx. file_read_as_string ( token_file) . await ?;
101+ let token = token. trim ( ) ;
102+ let role_session_name = self . get_role_session_name ( & envs ) ;
82103
83104 // Construct request to Aliyun STS Service.
84- let url = format ! (
85- "{}/?Action=AssumeRoleWithOIDC&OIDCProviderArn={}&RoleArn={}&RoleSessionName={}&Format=JSON&Version=2015-04-01&Timestamp={}&OIDCToken={}" ,
86- self . get_sts_endpoint( & envs) ,
87- provider_arn,
88- role_arn,
89- role_session_name,
90- Timestamp :: now( ) . format_rfc3339_zulu( ) ,
91- token
92- ) ;
105+ let query = Serializer :: new ( String :: new ( ) )
106+ . append_pair ( "Action" , "AssumeRoleWithOIDC" )
107+ . append_pair ( "OIDCProviderArn" , provider_arn)
108+ . append_pair ( "RoleArn" , role_arn)
109+ . append_pair ( "RoleSessionName" , & role_session_name)
110+ . append_pair ( "Format" , "JSON" )
111+ . append_pair ( "Version" , "2015-04-01" )
112+ . append_pair ( "Timestamp" , & Timestamp :: now ( ) . format_rfc3339_zulu ( ) )
113+ . append_pair ( "OIDCToken" , token)
114+ . finish ( ) ;
115+ let url = format ! ( "{}/?{query}" , self . get_sts_endpoint( & envs) ) ;
93116
94117 let req = http:: Request :: builder ( )
95118 . method ( http:: Method :: GET )
@@ -145,10 +168,14 @@ struct AssumeRoleWithOidcCredentials {
145168#[ cfg( test) ]
146169mod tests {
147170 use super :: * ;
171+ use async_trait:: async_trait;
172+ use bytes:: Bytes ;
148173 use reqsign_core:: StaticEnv ;
174+ use reqsign_core:: { Context , FileRead , HttpSend } ;
149175 use reqsign_file_read_tokio:: TokioFileRead ;
150176 use reqsign_http_send_reqwest:: ReqwestHttpSend ;
151177 use std:: collections:: HashMap ;
178+ use std:: sync:: { Arc , Mutex } ;
152179
153180 #[ test]
154181 fn test_parse_assume_role_with_oidc_response ( ) -> Result < ( ) > {
@@ -206,4 +233,183 @@ mod tests {
206233
207234 assert ! ( credential. is_none( ) ) ;
208235 }
236+
237+ #[ derive( Debug ) ]
238+ struct TestFileRead {
239+ expected_path : String ,
240+ content : Vec < u8 > ,
241+ }
242+
243+ #[ async_trait]
244+ impl FileRead for TestFileRead {
245+ async fn file_read ( & self , path : & str ) -> Result < Vec < u8 > > {
246+ assert_eq ! ( path, self . expected_path) ;
247+ Ok ( self . content . clone ( ) )
248+ }
249+ }
250+
251+ #[ derive( Clone , Debug ) ]
252+ struct CaptureHttpSend {
253+ uri : Arc < Mutex < Option < String > > > ,
254+ body : String ,
255+ }
256+
257+ impl CaptureHttpSend {
258+ fn new ( body : impl Into < String > ) -> Self {
259+ Self {
260+ uri : Arc :: new ( Mutex :: new ( None ) ) ,
261+ body : body. into ( ) ,
262+ }
263+ }
264+
265+ fn uri ( & self ) -> Option < String > {
266+ self . uri . lock ( ) . unwrap ( ) . clone ( )
267+ }
268+ }
269+
270+ #[ async_trait]
271+ impl HttpSend for CaptureHttpSend {
272+ async fn http_send ( & self , req : http:: Request < Bytes > ) -> Result < http:: Response < Bytes > > {
273+ * self . uri . lock ( ) . unwrap ( ) = Some ( req. uri ( ) . to_string ( ) ) ;
274+ let resp = http:: Response :: builder ( )
275+ . status ( http:: StatusCode :: OK )
276+ . body ( Bytes :: from ( self . body . clone ( ) ) )
277+ . expect ( "response must build" ) ;
278+ Ok ( resp)
279+ }
280+ }
281+
282+ #[ tokio:: test]
283+ async fn test_assume_role_with_oidc_supports_role_session_name ( ) -> Result < ( ) > {
284+ let _ = env_logger:: builder ( ) . is_test ( true ) . try_init ( ) ;
285+
286+ let token_path = "/mock/token" ;
287+ let raw_token = "header.payload.signature\n " ;
288+
289+ let file_read = TestFileRead {
290+ expected_path : token_path. to_string ( ) ,
291+ content : raw_token. as_bytes ( ) . to_vec ( ) ,
292+ } ;
293+
294+ let http_body = r#"{"Credentials":{"SecurityToken":"security_token","Expiration":"2124-05-25T11:45:17Z","AccessKeySecret":"secret_access_key","AccessKeyId":"access_key_id"}}"# ;
295+ let http_send = CaptureHttpSend :: new ( http_body) ;
296+
297+ let ctx = Context :: new ( )
298+ . with_file_read ( file_read)
299+ . with_http_send ( http_send. clone ( ) )
300+ . with_env ( StaticEnv {
301+ home_dir : None ,
302+ envs : HashMap :: from_iter ( [
303+ (
304+ ALIBABA_CLOUD_OIDC_TOKEN_FILE . to_string ( ) ,
305+ token_path. to_string ( ) ,
306+ ) ,
307+ (
308+ ALIBABA_CLOUD_ROLE_ARN . to_string ( ) ,
309+ "acs:ram::123456789012:role/test-role" . to_string ( ) ,
310+ ) ,
311+ (
312+ ALIBABA_CLOUD_OIDC_PROVIDER_ARN . to_string ( ) ,
313+ "acs:ram::123456789012:oidc-provider/test-provider" . to_string ( ) ,
314+ ) ,
315+ (
316+ ALIBABA_CLOUD_ROLE_SESSION_NAME . to_string ( ) ,
317+ "my-session" . to_string ( ) ,
318+ ) ,
319+ ] ) ,
320+ } ) ;
321+
322+ let provider = AssumeRoleWithOidcCredentialProvider :: new ( ) ;
323+ let cred = provider
324+ . provide_credential ( & ctx)
325+ . await ?
326+ . expect ( "credential must be loaded" ) ;
327+
328+ assert_eq ! ( cred. access_key_id, "access_key_id" ) ;
329+ assert_eq ! ( cred. access_key_secret, "secret_access_key" ) ;
330+ assert_eq ! ( cred. security_token. as_deref( ) , Some ( "security_token" ) ) ;
331+
332+ let recorded_uri = http_send
333+ . uri ( )
334+ . expect ( "http_send must capture outgoing uri" ) ;
335+ let uri: http:: Uri = recorded_uri. parse ( ) . expect ( "uri must parse" ) ;
336+ let query = uri. query ( ) . expect ( "query must exist" ) ;
337+ let params: HashMap < String , String > = form_urlencoded:: parse ( query. as_bytes ( ) )
338+ . into_owned ( )
339+ . collect ( ) ;
340+
341+ assert_eq ! (
342+ params. get( "RoleSessionName" ) . map( String :: as_str) ,
343+ Some ( "my-session" )
344+ ) ;
345+ assert_eq ! (
346+ params. get( "OIDCToken" ) . map( String :: as_str) ,
347+ Some ( "header.payload.signature" )
348+ ) ;
349+
350+ Ok ( ( ) )
351+ }
352+
353+ #[ tokio:: test]
354+ async fn test_assume_role_with_oidc_role_session_name_overrides_env ( ) -> Result < ( ) > {
355+ let _ = env_logger:: builder ( ) . is_test ( true ) . try_init ( ) ;
356+
357+ let token_path = "/mock/token" ;
358+
359+ let file_read = TestFileRead {
360+ expected_path : token_path. to_string ( ) ,
361+ content : b"token" . to_vec ( ) ,
362+ } ;
363+
364+ let http_body = r#"{"Credentials":{"SecurityToken":"security_token","Expiration":"2124-05-25T11:45:17Z","AccessKeySecret":"secret_access_key","AccessKeyId":"access_key_id"}}"# ;
365+ let http_send = CaptureHttpSend :: new ( http_body) ;
366+
367+ let ctx = Context :: new ( )
368+ . with_file_read ( file_read)
369+ . with_http_send ( http_send. clone ( ) )
370+ . with_env ( StaticEnv {
371+ home_dir : None ,
372+ envs : HashMap :: from_iter ( [
373+ (
374+ ALIBABA_CLOUD_OIDC_TOKEN_FILE . to_string ( ) ,
375+ token_path. to_string ( ) ,
376+ ) ,
377+ (
378+ ALIBABA_CLOUD_ROLE_ARN . to_string ( ) ,
379+ "acs:ram::123456789012:role/test-role" . to_string ( ) ,
380+ ) ,
381+ (
382+ ALIBABA_CLOUD_OIDC_PROVIDER_ARN . to_string ( ) ,
383+ "acs:ram::123456789012:oidc-provider/test-provider" . to_string ( ) ,
384+ ) ,
385+ (
386+ ALIBABA_CLOUD_ROLE_SESSION_NAME . to_string ( ) ,
387+ "env-session" . to_string ( ) ,
388+ ) ,
389+ ] ) ,
390+ } ) ;
391+
392+ let provider =
393+ AssumeRoleWithOidcCredentialProvider :: new ( ) . with_role_session_name ( "override-session" ) ;
394+ let _ = provider
395+ . provide_credential ( & ctx)
396+ . await ?
397+ . expect ( "credential must be loaded" ) ;
398+
399+ let recorded_uri = http_send
400+ . uri ( )
401+ . expect ( "http_send must capture outgoing uri" ) ;
402+ let uri: http:: Uri = recorded_uri. parse ( ) . expect ( "uri must parse" ) ;
403+ let query = uri. query ( ) . expect ( "query must exist" ) ;
404+ let params: HashMap < String , String > = form_urlencoded:: parse ( query. as_bytes ( ) )
405+ . into_owned ( )
406+ . collect ( ) ;
407+
408+ assert_eq ! (
409+ params. get( "RoleSessionName" ) . map( String :: as_str) ,
410+ Some ( "override-session" )
411+ ) ;
412+
413+ Ok ( ( ) )
414+ }
209415}
0 commit comments