Skip to content

Commit

Permalink
feat: basic SAML SP metadata for non-default ID zone
Browse files Browse the repository at this point in the history
- fix a mistake where we set assertingPartyDetails.wantAuthnRequestsSigned based
on the user config `login.saml.signRequest` (in reality, this assertingPartyDetails.wantAuthnRequestsSigned
should depend on the SAML IDP's declared preference, aka it's IDP metadata). Now, the impact
of `login.saml.signRequest` is more appropriately scoped to only controlling whether the SAML
SP metadata declares that the SP signs its outgoing requests.

- correctly populates the basic fields of non-default zone SAML SP metadata (such as
WantAssertionsSigned and AuthnRequestsSigned), so that for default vs. non-default zones, the
SP metadatas have feature parity.

[#187846376]

Signed-off-by: Duane May <[email protected]>
Signed-off-by: Peter Chen <[email protected]>
  • Loading branch information
duanemay authored and peterhaochen47 committed Jul 10, 2024
1 parent 55b998d commit 745fff3
Show file tree
Hide file tree
Showing 15 changed files with 205 additions and 118 deletions.
1 change: 1 addition & 0 deletions server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ dependencies {

testImplementation(libraries.jsonPathAssert)
testImplementation(libraries.guavaTestLib)
testImplementation(libraries.xmlUnit)

implementation(libraries.commonsIo)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public class IdentityZoneConfigurationBootstrap implements InitializingBean {
private String samlSpPrivateKeyPassphrase;
private String samlSpCertificate;
private boolean disableSamlInResponseToCheck = false;
private boolean samlWantAssertionSigned = true;
private boolean samlRequestSigned = true;

private Map<String, Map<String, String>> samlKeys;
private String activeKeyId;
Expand Down Expand Up @@ -89,6 +91,8 @@ public void afterPropertiesSet() throws InvalidIdentityZoneDetailsException {
definition.getSamlConfig().setPrivateKey(samlSpPrivateKey);
definition.getSamlConfig().setPrivateKeyPassword(samlSpPrivateKeyPassphrase);
definition.getSamlConfig().setDisableInResponseToCheck(disableSamlInResponseToCheck);
definition.getSamlConfig().setWantAssertionSigned(samlWantAssertionSigned);
definition.getSamlConfig().setRequestSigned(samlRequestSigned);
definition.setIdpDiscoveryEnabled(idpDiscoveryEnabled);
definition.setAccountChooserEnabled(accountChooserEnabled);
definition.setDefaultIdentityProvider(defaultIdentityProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@ public class ConfiguratorRelyingPartyRegistrationRepository

private final SamlIdentityProviderConfigurator configurator;
private final KeyWithCert keyWithCert;
private final Boolean samlSignRequest;
private final String samlEntityID;

public ConfiguratorRelyingPartyRegistrationRepository(Boolean samlSignRequest,
@Qualifier("samlEntityID") String samlEntityID,
public ConfiguratorRelyingPartyRegistrationRepository(@Qualifier("samlEntityID") String samlEntityID,
KeyWithCert keyWithCert,
SamlIdentityProviderConfigurator configurator) {
Assert.notNull(configurator, "configurator cannot be null");
this.configurator = configurator;
this.keyWithCert = keyWithCert;
this.samlSignRequest = samlSignRequest;
this.samlEntityID = samlEntityID;
}

Expand All @@ -45,7 +42,7 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) {
for (SamlIdentityProviderDefinition identityProviderDefinition : identityProviderDefinitions) {
if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) {
return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, identityProviderDefinition.getNameID(), samlSignRequest,
samlEntityID, identityProviderDefinition.getNameID(),
keyWithCert, identityProviderDefinition.getMetaDataLocation(), registrationId);
}
}
Expand All @@ -69,7 +66,7 @@ else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) {
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, null, samlSignRequest,
samlEntityID, null,
keyWithCert, "dummy-saml-idp-metadata.xml", null,
samlServiceUri);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ private RelyingPartyRegistrationBuilder() {
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId, boolean samlSignRequest,
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert,
String metadataLocation, String rpRegstrationId) {
return buildRelyingPartyRegistration(samlEntityID, samlSpNameId,
samlSignRequest, keyWithCert, metadataLocation, rpRegstrationId,
keyWithCert, metadataLocation, rpRegstrationId,
samlEntityID);
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId, boolean samlSignRequest,
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert, String metadataLocation,
String rpRegstrationId, String samlServiceUri) {
SamlIdentityProviderDefinition.MetadataLocation type = SamlIdentityProviderDefinition.getType(metadataLocation);
Expand Down Expand Up @@ -64,9 +64,6 @@ public static RelyingPartyRegistration buildRelyingPartyRegistration(
c.add(Saml2MessageBinding.REDIRECT);
c.add(Saml2MessageBinding.POST);
})
.assertingPartyDetails(details -> details
.wantAuthnRequestsSigned(samlSignRequest)
)
.signingX509Credentials(cred -> cred
.add(Saml2X509Credential.signing(keyWithCert.getPrivateKey(), keyWithCert.getCertificate()))
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.SamlConfig;
import org.cloudfoundry.identity.uaa.zone.ZoneAware;
import org.cloudfoundry.identity.uaa.zone.beans.IdentityZoneManager;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
Expand All @@ -11,14 +13,11 @@
import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.util.Assert;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RestController;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
Expand All @@ -27,80 +26,73 @@
@RestController
public class SamlMetadataEndpoint implements ZoneAware {
public static final String DEFAULT_REGISTRATION_ID = "example";
private static final String DEFAULT_FILE_NAME = "saml-sp.xml";
private static final String APPLICATION_XML_CHARSET_UTF_8 = "application/xml; charset=UTF-8";
private static final String CONTENT_DISPOSITION_FORMAT = "attachment; filename=\"%s\"; filename*=UTF-8''%s";

// @todo - this should be a Zone aware resolver
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
private final Saml2MetadataResolver saml2MetadataResolver;
private final IdentityZoneManager identityZoneManager;

private String fileName;
private String encodedFileName;

private final Boolean wantAssertionSigned;
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;

public SamlMetadataEndpoint(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
SamlConfigProps samlConfigProps) {
IdentityZoneManager identityZoneManager) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
this.relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository);
this.identityZoneManager = identityZoneManager;
OpenSamlMetadataResolver resolver = new OpenSamlMetadataResolver();
this.saml2MetadataResolver = resolver;
resolver.setEntityDescriptorCustomizer(new EntityDescriptorCustomizer());
this.wantAssertionSigned = samlConfigProps.getWantAssertionSigned();
setFileName(DEFAULT_FILE_NAME);
}

private class EntityDescriptorCustomizer implements Consumer<OpenSamlMetadataResolver.EntityDescriptorParameters> {
@Override
public void accept(OpenSamlMetadataResolver.EntityDescriptorParameters entityDescriptorParameters) {
SamlConfig samlConfig = identityZoneManager.getCurrentIdentityZone().getConfig().getSamlConfig();

EntityDescriptor descriptor = entityDescriptorParameters.getEntityDescriptor();
SPSSODescriptor spssodescriptor = descriptor.getSPSSODescriptor(SAMLConstants.SAML20P_NS);
spssodescriptor.setWantAssertionsSigned(wantAssertionSigned);
spssodescriptor.setAuthnRequestsSigned(entityDescriptorParameters.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned());
spssodescriptor.setWantAssertionsSigned(samlConfig.isWantAssertionSigned());
spssodescriptor.setAuthnRequestsSigned(samlConfig.isRequestSigned());
}
}

@GetMapping(value = "/saml/metadata", produces = APPLICATION_XML_CHARSET_UTF_8)
public ResponseEntity<String> legacyMetadataEndpoint(HttpServletRequest request) {
return metadataEndpoint(DEFAULT_REGISTRATION_ID, request);
public ResponseEntity<String> legacyMetadataEndpoint() {
return metadataEndpoint(DEFAULT_REGISTRATION_ID);
}

@GetMapping(value = "/saml/metadata/{registrationId}", produces = APPLICATION_XML_CHARSET_UTF_8)
public ResponseEntity<String> metadataEndpoint(@PathVariable String registrationId, HttpServletRequest request) {
public ResponseEntity<String> metadataEndpoint(@PathVariable String registrationId) {
RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (relyingPartyRegistration == null) {
return ResponseEntity.status(HttpServletResponse.SC_UNAUTHORIZED).build();
}
String metadata = saml2MetadataResolver.resolve(relyingPartyRegistration);

// @todo - fileName may need to be dynamic based on registrationID
String[] fileNames = retrieveZoneAwareFileNames();
String contentDisposition = ContentDispositionFilename.getContentDisposition(retrieveZone());
return ResponseEntity.ok()
.header(HttpHeaders.CONTENT_DISPOSITION, String.format(
CONTENT_DISPOSITION_FORMAT, fileNames[0], fileNames[1]))
.header(HttpHeaders.CONTENT_DISPOSITION, contentDisposition)
.body(metadata);
}
}

public void setFileName(String fileName) {
encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8);
this.fileName = fileName;
}
record ContentDispositionFilename(String fileName) {
private static final String CONTENT_DISPOSITION_FORMAT = "attachment; filename=\"%s\"; filename*=UTF-8''%s";
private static final String DEFAULT_FILE_NAME = "saml-sp.xml";

private String[] retrieveZoneAwareFileNames() {
IdentityZone zone = retrieveZone();
String[] fileNames = new String[2];
static ContentDispositionFilename retrieveZoneAwareContentDispositionFilename(IdentityZone zone) {
if (zone.isUaa()) {
fileNames[0] = fileName;
fileNames[1] = encodedFileName;
}
else {
fileNames[0] = "saml-" + zone.getSubdomain() + "-sp.xml";
fileNames[1] = URLEncoder.encode(fileNames[0],
StandardCharsets.UTF_8);
return new ContentDispositionFilename(DEFAULT_FILE_NAME);
}
return fileNames;
String filename = "saml-%s-sp.xml".formatted(zone.getSubdomain());
return new ContentDispositionFilename(filename);
}

static String getContentDisposition(IdentityZone zone) {
return retrieveZoneAwareContentDispositionFilename(zone).getContentDisposition();
}

String getContentDisposition() {
String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8);
return CONTENT_DISPOSITION_FORMAT.formatted(fileName, encodedFileName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,17 @@ public class SamlRelyingPartyRegistrationRepositoryConfig {
private final SamlConfigProps samlConfigProps;
private final BootstrapSamlIdentityProviderData bootstrapSamlIdentityProviderData;
private final String samlSpNameID;
private final Boolean samlSignRequest;

public SamlRelyingPartyRegistrationRepositoryConfig(@Qualifier("samlEntityID") String samlEntityID,
SamlConfigProps samlConfigProps,
BootstrapSamlIdentityProviderData bootstrapSamlIdentityProviderData,
@Value("${login.saml.nameID:urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified}")
String samlSpNameID,
@Value("${login.saml.signRequest:true}")
Boolean samlSignRequest
String samlSpNameID
) {
this.samlEntityID = samlEntityID;
this.samlConfigProps = samlConfigProps;
this.bootstrapSamlIdentityProviderData = bootstrapSamlIdentityProviderData;
this.samlSpNameID = samlSpNameID;
this.samlSignRequest = samlSignRequest;
}

@Autowired
Expand All @@ -69,20 +65,20 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti
// even when there are no SAML IDPs configured.
// See relevant issue: https://github.com/spring-projects/spring-security/issues/11369
RelyingPartyRegistration defaultRelyingPartyRegistration = RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, samlSpNameID, samlSignRequest, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID);
samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID);
relyingPartyRegistrations.add(defaultRelyingPartyRegistration);

for (SamlIdentityProviderDefinition samlIdentityProviderDefinition : bootstrapSamlIdentityProviderData.getIdentityProviderDefinitions()) {
relyingPartyRegistrations.add(
RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, samlSpNameID, samlSignRequest, keyWithCert,
samlEntityID, samlSpNameID, keyWithCert,
samlIdentityProviderDefinition.getMetaDataLocation(),
samlIdentityProviderDefinition.getIdpEntityAlias())
);
}

InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlSignRequest, samlEntityID, keyWithCert, samlIdentityProviderConfigurator);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, keyWithCert, samlIdentityProviderConfigurator);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,20 @@ void keyIdNullException() {
}

@Test
void defaultSamlKeys() throws Exception {
void samlKeysAndSigningConfigs() throws Exception {
bootstrap.setSamlSpPrivateKey(SamlTestUtils.PROVIDER_PRIVATE_KEY);
bootstrap.setSamlSpCertificate(SamlTestUtils.PROVIDER_CERTIFICATE);
bootstrap.setSamlSpPrivateKeyPassphrase(SamlTestUtils.PROVIDER_PRIVATE_KEY_PASSWORD);
bootstrap.setSamlWantAssertionSigned(false);
bootstrap.setSamlRequestSigned(false);
bootstrap.afterPropertiesSet();

IdentityZone uaa = provisioning.retrieve(IdentityZone.getUaaZoneId());
assertThat(uaa.getConfig().getSamlConfig().getPrivateKey()).isEqualTo(SamlTestUtils.PROVIDER_PRIVATE_KEY);
assertThat(uaa.getConfig().getSamlConfig().getPrivateKeyPassword()).isEqualTo(SamlTestUtils.PROVIDER_PRIVATE_KEY_PASSWORD);
assertThat(uaa.getConfig().getSamlConfig().getCertificate()).isEqualTo(SamlTestUtils.PROVIDER_CERTIFICATE);
assertThat(uaa.getConfig().getSamlConfig().isWantAssertionSigned()).isEqualTo(false);
assertThat(uaa.getConfig().getSamlConfig().isRequestSigned()).isEqualTo(false);
}

@Test
Expand Down Expand Up @@ -253,7 +258,6 @@ void logoutRedirect() throws Exception {
assertThat(config.getLinks().getLogout().isDisableRedirectParameter()).isFalse();
}


@Test
void testPrompts() throws Exception {
List<Prompt> prompts = Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest {

@BeforeEach
void setUp() {
repository = new ConfiguratorRelyingPartyRegistrationRepository(true, ENTITY_ID, mockKeyWithCert,
repository = new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, mockKeyWithCert,
mockConfigurator);
}

@Test
void constructorWithNullConfiguratorThrows() {
assertThatThrownBy(() -> new ConfiguratorRelyingPartyRegistrationRepository(
true, ENTITY_ID, mockKeyWithCert, null)
ENTITY_ID, mockKeyWithCert, null)
).isInstanceOf(IllegalArgumentException.class);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class RelyingPartyRegistrationBuilderTest {
private static final String ENTITY_ID = "entityId";
private static final String NAME_ID = "nameIdFormat";
private static final String REGISTRATION_ID = "registrationId";
private static final boolean SIGN_REQUEST = true;

@Mock
private KeyWithCert mockKeyWithCert;
Expand All @@ -42,7 +41,7 @@ void buildsRelyingPartyRegistrationFromLocation() {
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));

RelyingPartyRegistration registration = RelyingPartyRegistrationBuilder
.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, SIGN_REQUEST, mockKeyWithCert, "saml-sample-metadata.xml", REGISTRATION_ID);
.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, mockKeyWithCert, "saml-sample-metadata.xml", REGISTRATION_ID);
assertThat(registration)
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
Expand All @@ -62,7 +61,7 @@ void buildsRelyingPartyRegistrationFromXML() {

String metadataXml = loadResouceAsString("saml-sample-metadata.xml");
RelyingPartyRegistration registration = RelyingPartyRegistrationBuilder
.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, SIGN_REQUEST, mockKeyWithCert, metadataXml, REGISTRATION_ID);
.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, mockKeyWithCert, metadataXml, REGISTRATION_ID);

assertThat(registration)
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
Expand All @@ -81,7 +80,7 @@ void failsWithInvalidXML() {
String metadataXml = "<?xml version=\"1.0\"?>\n<xml>invalid xml</xml>";
assertThatThrownBy(() ->
RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID,
SIGN_REQUEST, mockKeyWithCert, metadataXml, REGISTRATION_ID))
mockKeyWithCert, metadataXml, REGISTRATION_ID))
.isInstanceOf(Saml2Exception.class)
.hasMessageContaining("Unsupported element");
}
Expand Down
Loading

0 comments on commit 745fff3

Please sign in to comment.