From 211fa52649f3205a275e8f49d7b99089edee459e Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Wed, 22 Jan 2025 16:47:41 -0600 Subject: [PATCH] Favor provided instances over shared objects Prior to this commit, providing oauth2Login() and oauth2Client() with clientRegistrationRepository() and authorizedClientRepository() caused objects to be shared across both configurers. These configurers will now prefer explicitly provided instances of those objects when they are available. Closes gh-16105 --- .../oauth2/client/OAuth2ClientConfigurer.java | 27 +++++- .../oauth2/client/OAuth2LoginConfigurer.java | 27 ++++-- .../client/OAuth2ClientConfigurerTests.java | 93 ++++++++++++++++++- .../client/OAuth2LoginConfigurerTests.java | 68 +++++++++++++- 4 files changed, 201 insertions(+), 14 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index 24b24909c99..3af7db6d02f 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -98,6 +98,10 @@ public final class OAuth2ClientConfigurer> private AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer = new AuthorizationCodeGrantConfigurer(); + private ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -107,6 +111,7 @@ public OAuth2ClientConfigurer clientRegistrationRepository( ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + this.clientRegistrationRepository = clientRegistrationRepository; return this; } @@ -119,6 +124,7 @@ public OAuth2ClientConfigurer authorizedClientRepository( OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); + this.authorizedClientRepository = authorizedClientRepository; return this; } @@ -283,8 +289,7 @@ private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() { if (this.authorizationRequestResolver != null) { return this.authorizationRequestResolver; } - ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils - .getClientRegistrationRepository(getBuilder()); + ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(getBuilder()); return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository, OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI); } @@ -292,8 +297,8 @@ private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() { private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) { AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class); OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter( - OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), - OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder), authenticationManager); + getClientRegistrationRepository(builder), getAuthorizedClientRepository(builder), + authenticationManager); if (this.authorizationRequestRepository != null) { authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } @@ -315,6 +320,18 @@ private OAuth2AccessTokenResponseClient get return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient(); } + private ClientRegistrationRepository getClientRegistrationRepository(B builder) { + return (OAuth2ClientConfigurer.this.clientRegistrationRepository != null) + ? OAuth2ClientConfigurer.this.clientRegistrationRepository + : OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder); + } + + private OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) { + return (OAuth2ClientConfigurer.this.authorizedClientRepository != null) + ? OAuth2ClientConfigurer.this.authorizedClientRepository + : OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder); + } + @SuppressWarnings("unchecked") private T getBeanOrNull(ResolvableType type) { ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class); diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index a6b5f7c52bf..913a8f1211e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -172,6 +172,10 @@ public final class OAuth2LoginConfigurer> private String loginProcessingUrl = OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI; + private ClientRegistrationRepository clientRegistrationRepository; + + private OAuth2AuthorizedClientRepository authorizedClientRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -181,6 +185,7 @@ public OAuth2LoginConfigurer clientRegistrationRepository( ClientRegistrationRepository clientRegistrationRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository); + this.clientRegistrationRepository = clientRegistrationRepository; return this; } @@ -194,6 +199,7 @@ public OAuth2LoginConfigurer authorizedClientRepository( OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository); + this.authorizedClientRepository = authorizedClientRepository; return this; } @@ -339,8 +345,7 @@ public OAuth2LoginConfigurer userInfoEndpoint(Customizer getJwtDecoderFactoryBean() { ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class); @@ -529,8 +543,7 @@ private void initDefaultLoginFilter(B http) { @SuppressWarnings("unchecked") private Map getLoginLinks() { Iterable clientRegistrations = null; - ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils - .getClientRegistrationRepository(this.getBuilder()); + ClientRegistrationRepository clientRegistrationRepository = this.getClientRegistrationRepository(); ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class); if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) { clientRegistrations = (Iterable) clientRegistrationRepository; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index 41e74807cdb..83dacaa265b 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,6 +75,7 @@ import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; @@ -285,6 +286,49 @@ public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorization verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString()); } + @Test + public void configureWhenOAuth2LoginBeansConfiguredThenNotShared() throws Exception { + this.spring.register(OAuth2ClientConfigWithOAuth2Login.class).autowire(); + // Setup the Authorization Request in the session + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + // @formatter:off + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) + .clientId(this.registration1.getClientId()) + .redirectUri("http://localhost/client-1") + .state("state") + .attributes(attributes) + .build(); + // @formatter:on + AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + MockHttpSession session = (MockHttpSession) request.getSession(); + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + // @formatter:off + MockHttpServletRequestBuilder clientRequest = get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session); + this.mockMvc.perform(clientRequest) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = authorizedClientRepository + .loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request); + assertThat(authorizedClient).isNotNull(); + // Ensure shared objects set for OAuth2 Client are not used + ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() + .getBean(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(OAuth2AuthorizedClientRepository.class); + verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); + } + @EnableWebSecurity @Configuration @EnableWebMvc @@ -362,4 +406,51 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() { } + @Configuration + @EnableWebSecurity + @EnableWebMvc + static class OAuth2ClientConfigWithOAuth2Login { + + private final ClientRegistrationRepository clientRegistrationRepository = mock( + ClientRegistrationRepository.class); + + private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock( + OAuth2AuthorizedClientRepository.class); + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .authorizeHttpRequests((authorize) -> authorize + .anyRequest().authenticated() + ) + .oauth2Client((oauth2Client) -> oauth2Client + .clientRegistrationRepository(OAuth2ClientConfigurerTests.clientRegistrationRepository) + .authorizedClientService(OAuth2ClientConfigurerTests.authorizedClientService) + .authorizationCodeGrant((authorizationCode) -> authorizationCode + .authorizationRequestResolver(authorizationRequestResolver) + .authorizationRedirectStrategy(authorizationRedirectStrategy) + .accessTokenResponseClient(accessTokenResponseClient) + ) + ) + .oauth2Login((oauth2Login) -> oauth2Login + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizedClientRepository(this.authorizedClientRepository) + ); + // @formatter:on + return http.build(); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return this.clientRegistrationRepository; + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index b56d047a5f7..dfe6fea28fd 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,9 @@ import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository; import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository; +import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -115,6 +117,7 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; @@ -669,6 +672,30 @@ public void oauth2LoginWhenDefaultsThenNoOidcSessionRegistry() { .collect(Collectors.toList())).isEmpty(); } + @Test + public void oidcLoginWhenOAuth2ClientBeansConfiguredThenNotShared() throws Exception { + this.spring.register(OAuth2LoginConfigWithOAuth2Client.class, JwtDecoderFactoryConfig.class).autowire(); + OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid"); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response); + this.request.setParameter("code", "code123"); + this.request.setParameter("state", authorizationRequest.getState()); + this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain); + Authentication authentication = this.securityContextRepository + .loadContext(new HttpRequestResponseHolder(this.request, this.response)) + .getAuthentication(); + assertThat(authentication.getAuthorities()).hasSize(1); + assertThat(authentication.getAuthorities()).first() + .isInstanceOf(OidcUserAuthority.class) + .hasToString("OIDC_USER"); + + // Ensure shared objects set for OAuth2 Client are not used + ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext() + .getBean(ClientRegistrationRepository.class); + OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext() + .getBean(OAuth2AuthorizedClientRepository.class); + verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -1192,6 +1219,45 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigWithOAuth2Client extends CommonLambdaSecurityFilterChainConfig { + + private final ClientRegistrationRepository clientRegistrationRepository = mock( + ClientRegistrationRepository.class); + + private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock( + OAuth2AuthorizedClientRepository.class); + + @Bean + SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((oauth2Login) -> oauth2Login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .authorizedClientRepository(new HttpSessionOAuth2AuthorizedClientRepository()) + ) + .oauth2Client((oauth2Client) -> oauth2Client + .clientRegistrationRepository(this.clientRegistrationRepository) + .authorizedClientRepository(this.authorizedClientRepository) + ); + // @formatter:on + return super.configureFilterChain(http); + } + + @Bean + ClientRegistrationRepository clientRegistrationRepository() { + return this.clientRegistrationRepository; + } + + @Bean + OAuth2AuthorizedClientRepository authorizedClientRepository() { + return this.authorizedClientRepository; + } + + } + private abstract static class CommonSecurityFilterChainConfig { SecurityFilterChain configureFilterChain(HttpSecurity http) throws Exception {