Skip to content

Commit

Permalink
Favor provided instances over shared objects
Browse files Browse the repository at this point in the history
Prior to this commit, providing oauth2Login() and oauth2Client() with
clientRegistrationRepository() and authorizedClientRepository() caused
objects to be shared across both configurers.

These configurers will now prefer explicitly provided instances of
those objects when they are available.

Closes gh-16105
  • Loading branch information
sjohnr committed Jan 22, 2025
1 parent 7f410ce commit 211fa52
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -98,6 +98,10 @@ public final class OAuth2ClientConfigurer<B extends HttpSecurityBuilder<B>>

private AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer = new AuthorizationCodeGrantConfigurer();

private ClientRegistrationRepository clientRegistrationRepository;

private OAuth2AuthorizedClientRepository authorizedClientRepository;

/**
* Sets the repository of client registrations.
* @param clientRegistrationRepository the repository of client registrations
Expand All @@ -107,6 +111,7 @@ public OAuth2ClientConfigurer<B> clientRegistrationRepository(
ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository);
this.clientRegistrationRepository = clientRegistrationRepository;
return this;
}

Expand All @@ -119,6 +124,7 @@ public OAuth2ClientConfigurer<B> authorizedClientRepository(
OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
this.authorizedClientRepository = authorizedClientRepository;
return this;
}

Expand Down Expand Up @@ -283,17 +289,16 @@ private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
if (this.authorizationRequestResolver != null) {
return this.authorizationRequestResolver;
}
ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
.getClientRegistrationRepository(getBuilder());
ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(getBuilder());
return new DefaultOAuth2AuthorizationRequestResolver(clientRegistrationRepository,
OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI);
}

private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B builder) {
AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class);
OAuth2AuthorizationCodeGrantFilter authorizationCodeGrantFilter = new OAuth2AuthorizationCodeGrantFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder),
OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder), authenticationManager);
getClientRegistrationRepository(builder), getAuthorizedClientRepository(builder),
authenticationManager);
if (this.authorizationRequestRepository != null) {
authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
}
Expand All @@ -315,6 +320,18 @@ private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> get
return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient();
}

private ClientRegistrationRepository getClientRegistrationRepository(B builder) {
return (OAuth2ClientConfigurer.this.clientRegistrationRepository != null)
? OAuth2ClientConfigurer.this.clientRegistrationRepository
: OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder);
}

private OAuth2AuthorizedClientRepository getAuthorizedClientRepository(B builder) {
return (OAuth2ClientConfigurer.this.authorizedClientRepository != null)
? OAuth2ClientConfigurer.this.authorizedClientRepository
: OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(builder);
}

@SuppressWarnings("unchecked")
private <T> T getBeanOrNull(ResolvableType type) {
ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -172,6 +172,10 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>>

private String loginProcessingUrl = OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;

private ClientRegistrationRepository clientRegistrationRepository;

private OAuth2AuthorizedClientRepository authorizedClientRepository;

/**
* Sets the repository of client registrations.
* @param clientRegistrationRepository the repository of client registrations
Expand All @@ -181,6 +185,7 @@ public OAuth2LoginConfigurer<B> clientRegistrationRepository(
ClientRegistrationRepository clientRegistrationRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
this.getBuilder().setSharedObject(ClientRegistrationRepository.class, clientRegistrationRepository);
this.clientRegistrationRepository = clientRegistrationRepository;
return this;
}

Expand All @@ -194,6 +199,7 @@ public OAuth2LoginConfigurer<B> authorizedClientRepository(
OAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.getBuilder().setSharedObject(OAuth2AuthorizedClientRepository.class, authorizedClientRepository);
this.authorizedClientRepository = authorizedClientRepository;
return this;
}

Expand Down Expand Up @@ -339,8 +345,7 @@ public OAuth2LoginConfigurer<B> userInfoEndpoint(Customizer<UserInfoEndpointConf
@Override
public void init(B http) throws Exception {
OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()),
OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder()), this.loginProcessingUrl);
this.getClientRegistrationRepository(), this.getAuthorizedClientRepository(), this.loginProcessingUrl);
authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
this.setAuthenticationFilter(authenticationFilter);
super.loginProcessingUrl(this.loginProcessingUrl);
Expand Down Expand Up @@ -406,8 +411,7 @@ public void configure(B http) throws Exception {
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
}
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()),
authorizationRequestBaseUri);
this.getClientRegistrationRepository(), authorizationRequestBaseUri);
}
if (this.authorizationEndpointConfig.authorizationRequestRepository != null) {
authorizationRequestFilter
Expand Down Expand Up @@ -439,6 +443,16 @@ protected RequestMatcher createLoginProcessingUrlMatcher(String loginProcessingU
return new AntPathRequestMatcher(loginProcessingUrl);
}

private ClientRegistrationRepository getClientRegistrationRepository() {
return (this.clientRegistrationRepository != null) ? this.clientRegistrationRepository
: OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder());
}

private OAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
return (this.authorizedClientRepository != null) ? this.authorizedClientRepository
: OAuth2ClientConfigurerUtils.getAuthorizedClientRepository(this.getBuilder());
}

@SuppressWarnings("unchecked")
private JwtDecoderFactory<ClientRegistration> getJwtDecoderFactoryBean() {
ResolvableType type = ResolvableType.forClassWithGenerics(JwtDecoderFactory.class, ClientRegistration.class);
Expand Down Expand Up @@ -529,8 +543,7 @@ private void initDefaultLoginFilter(B http) {
@SuppressWarnings("unchecked")
private Map<String, String> getLoginLinks() {
Iterable<ClientRegistration> clientRegistrations = null;
ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
.getClientRegistrationRepository(this.getBuilder());
ClientRegistrationRepository clientRegistrationRepository = this.getClientRegistrationRepository();
ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class);
if (type != ResolvableType.NONE && ClientRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) {
clientRegistrations = (Iterable<ClientRegistration>) clientRegistrationRepository;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -75,6 +75,7 @@
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
Expand Down Expand Up @@ -285,6 +286,49 @@ public void configureWhenCustomAuthorizationRedirectStrategySetThenAuthorization
verify(authorizationRedirectStrategy).sendRedirect(any(), any(), anyString());
}

@Test
public void configureWhenOAuth2LoginBeansConfiguredThenNotShared() throws Exception {
this.spring.register(OAuth2ClientConfigWithOAuth2Login.class).autowire();
// Setup the Authorization Request in the session
Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId());
// @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri())
.clientId(this.registration1.getClientId())
.redirectUri("http://localhost/client-1")
.state("state")
.attributes(attributes)
.build();
// @formatter:on
AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
MockHttpServletRequest request = new MockHttpServletRequest("GET", "");
MockHttpServletResponse response = new MockHttpServletResponse();
authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response);
MockHttpSession session = (MockHttpSession) request.getSession();
String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");
// @formatter:off
MockHttpServletRequestBuilder clientRequest = get("/client-1")
.param(OAuth2ParameterNames.CODE, "code")
.param(OAuth2ParameterNames.STATE, "state")
.with(authentication(authentication))
.session(session);
this.mockMvc.perform(clientRequest)
.andExpect(status().is3xxRedirection())
.andExpect(redirectedUrl("http://localhost/client-1"));
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientRepository
.loadAuthorizedClient(this.registration1.getRegistrationId(), authentication, request);
assertThat(authorizedClient).isNotNull();
// Ensure shared objects set for OAuth2 Client are not used
ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext()
.getBean(ClientRegistrationRepository.class);
OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext()
.getBean(OAuth2AuthorizedClientRepository.class);
verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository);
}

@EnableWebSecurity
@Configuration
@EnableWebMvc
Expand Down Expand Up @@ -362,4 +406,51 @@ OAuth2AuthorizedClientRepository authorizedClientRepository() {

}

@Configuration
@EnableWebSecurity
@EnableWebMvc
static class OAuth2ClientConfigWithOAuth2Login {

private final ClientRegistrationRepository clientRegistrationRepository = mock(
ClientRegistrationRepository.class);

private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock(
OAuth2AuthorizedClientRepository.class);

@Bean
SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeHttpRequests((authorize) -> authorize
.anyRequest().authenticated()
)
.oauth2Client((oauth2Client) -> oauth2Client
.clientRegistrationRepository(OAuth2ClientConfigurerTests.clientRegistrationRepository)
.authorizedClientService(OAuth2ClientConfigurerTests.authorizedClientService)
.authorizationCodeGrant((authorizationCode) -> authorizationCode
.authorizationRequestResolver(authorizationRequestResolver)
.authorizationRedirectStrategy(authorizationRedirectStrategy)
.accessTokenResponseClient(accessTokenResponseClient)
)
)
.oauth2Login((oauth2Login) -> oauth2Login
.clientRegistrationRepository(this.clientRegistrationRepository)
.authorizedClientRepository(this.authorizedClientRepository)
);
// @formatter:on
return http.build();
}

@Bean
ClientRegistrationRepository clientRegistrationRepository() {
return this.clientRegistrationRepository;
}

@Bean
OAuth2AuthorizedClientRepository authorizedClientRepository() {
return this.authorizedClientRepository;
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2024 the original author or authors.
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -73,7 +73,9 @@
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
Expand Down Expand Up @@ -115,6 +117,7 @@
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
Expand Down Expand Up @@ -669,6 +672,30 @@ public void oauth2LoginWhenDefaultsThenNoOidcSessionRegistry() {
.collect(Collectors.toList())).isEmpty();
}

@Test
public void oidcLoginWhenOAuth2ClientBeansConfiguredThenNotShared() throws Exception {
this.spring.register(OAuth2LoginConfigWithOAuth2Client.class, JwtDecoderFactoryConfig.class).autowire();
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest("openid");
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response);
this.request.setParameter("code", "code123");
this.request.setParameter("state", authorizationRequest.getState());
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
Authentication authentication = this.securityContextRepository
.loadContext(new HttpRequestResponseHolder(this.request, this.response))
.getAuthentication();
assertThat(authentication.getAuthorities()).hasSize(1);
assertThat(authentication.getAuthorities()).first()
.isInstanceOf(OidcUserAuthority.class)
.hasToString("OIDC_USER");

// Ensure shared objects set for OAuth2 Client are not used
ClientRegistrationRepository clientRegistrationRepository = this.spring.getContext()
.getBean(ClientRegistrationRepository.class);
OAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext()
.getBean(OAuth2AuthorizedClientRepository.class);
verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository);
}

private void loadConfig(Class<?>... configs) {
AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext();
applicationContext.register(configs);
Expand Down Expand Up @@ -1192,6 +1219,45 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception {

}

@Configuration
@EnableWebSecurity
static class OAuth2LoginConfigWithOAuth2Client extends CommonLambdaSecurityFilterChainConfig {

private final ClientRegistrationRepository clientRegistrationRepository = mock(
ClientRegistrationRepository.class);

private final OAuth2AuthorizedClientRepository authorizedClientRepository = mock(
OAuth2AuthorizedClientRepository.class);

@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
// @formatter:off
http
.oauth2Login((oauth2Login) -> oauth2Login
.clientRegistrationRepository(
new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION))
.authorizedClientRepository(new HttpSessionOAuth2AuthorizedClientRepository())
)
.oauth2Client((oauth2Client) -> oauth2Client
.clientRegistrationRepository(this.clientRegistrationRepository)
.authorizedClientRepository(this.authorizedClientRepository)
);
// @formatter:on
return super.configureFilterChain(http);
}

@Bean
ClientRegistrationRepository clientRegistrationRepository() {
return this.clientRegistrationRepository;
}

@Bean
OAuth2AuthorizedClientRepository authorizedClientRepository() {
return this.authorizedClientRepository;
}

}

private abstract static class CommonSecurityFilterChainConfig {

SecurityFilterChain configureFilterChain(HttpSecurity http) throws Exception {
Expand Down

0 comments on commit 211fa52

Please sign in to comment.