Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow java-saml to be used in non-JavaEE containers #115

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.joda.time.DateTime;
import org.joda.time.Instant;
Expand All @@ -27,6 +24,7 @@
import com.onelogin.saml2.exception.Error;
import com.onelogin.saml2.exception.XMLEntityException;
import com.onelogin.saml2.http.HttpRequest;
import com.onelogin.saml2.http.HttpResponse;
import com.onelogin.saml2.logout.LogoutRequest;
import com.onelogin.saml2.logout.LogoutResponse;
import com.onelogin.saml2.servlet.ServletUtils;
Expand Down Expand Up @@ -57,14 +55,14 @@ public class Auth {
private Saml2Settings settings;

/**
* HttpServletRequest object to be processed (Contains GET and POST parameters, session, ...).
* HttpRequest object to be processed (Contains GET and POST parameters, session, ...).
*/
private HttpServletRequest request;
private HttpRequest request;

/**
* HttpServletResponse object to be used (For example to execute the redirections).
* HttpResponse object to be used (For example to execute the redirections).
*/
private HttpServletResponse response;
private HttpResponse response;

/**
* NameID.
Expand Down Expand Up @@ -168,15 +166,15 @@ public Auth(String filename) throws IOException, SettingsException, Error {
* Initializes the SP SAML instance.
*
* @param request
* HttpServletRequest object to be processed
* HttpRequest object to be processed
* @param response
* HttpServletResponse object to be used
* HttpResponse object to be used
*
* @throws IOException
* @throws SettingsException
* @throws Error
*/
public Auth(HttpServletRequest request, HttpServletResponse response) throws IOException, SettingsException, Error {
public Auth(HttpRequest request, HttpResponse response) throws IOException, SettingsException, Error {
this(new SettingsBuilder().fromFile("onelogin.saml.properties").build(), request, response);
}

Expand All @@ -186,15 +184,15 @@ public Auth(HttpServletRequest request, HttpServletResponse response) throws IOE
* @param filename
* String Filename with the settings
* @param request
* HttpServletRequest object to be processed
* HttpRequest object to be processed
* @param response
* HttpServletResponse object to be used
* HttpResponse object to be used
*
* @throws SettingsException
* @throws IOException
* @throws Error
*/
public Auth(String filename, HttpServletRequest request, HttpServletResponse response) throws SettingsException, IOException, Error {
public Auth(String filename, HttpRequest request, HttpResponse response) throws SettingsException, IOException, Error {
this(new SettingsBuilder().fromFile(filename).build(), request, response);
}

Expand All @@ -204,13 +202,13 @@ public Auth(String filename, HttpServletRequest request, HttpServletResponse res
* @param settings
* Saml2Settings object. Setting data
* @param request
* HttpServletRequest object to be processed
* HttpRequest object to be processed
* @param response
* HttpServletResponse object to be used
* HttpResponse object to be used
*
* @throws SettingsException
*/
public Auth(Saml2Settings settings, HttpServletRequest request, HttpServletResponse response) throws SettingsException {
public Auth(Saml2Settings settings, HttpRequest request, HttpResponse response) throws SettingsException {
this.settings = settings;
this.request = request;
this.response = response;
Expand Down Expand Up @@ -516,11 +514,10 @@ public String getSLOResponseUrl() {
*/
public void processResponse(String requestId) throws Exception {
authenticated = false;
final HttpRequest httpRequest = ServletUtils.makeHttpRequest(this.request);
final String samlResponseParameter = httpRequest.getParameter("SAMLResponse");
final String samlResponseParameter = request.getParameter("SAMLResponse");

if (samlResponseParameter != null) {
SamlResponse samlResponse = new SamlResponse(settings, httpRequest);
SamlResponse samlResponse = new SamlResponse(settings, request);
lastResponse = samlResponse.getSAMLResponseXml();

if (samlResponse.isValid(requestId)) {
Expand Down Expand Up @@ -568,13 +565,12 @@ public void processResponse() throws Exception {
* @throws Exception
*/
public void processSLO(Boolean keepLocalSession, String requestId) throws Exception {
final HttpRequest httpRequest = ServletUtils.makeHttpRequest(this.request);

final String samlRequestParameter = httpRequest.getParameter("SAMLRequest");
final String samlResponseParameter = httpRequest.getParameter("SAMLResponse");
final String samlRequestParameter = request.getParameter("SAMLRequest");
final String samlResponseParameter = request.getParameter("SAMLResponse");

if (samlResponseParameter != null) {
LogoutResponse logoutResponse = new LogoutResponse(settings, httpRequest);
LogoutResponse logoutResponse = new LogoutResponse(settings, request);
lastResponse = logoutResponse.getLogoutResponseXml();
if (!logoutResponse.isValid(requestId)) {
errors.add("invalid_logout_response");
Expand All @@ -591,12 +587,12 @@ public void processSLO(Boolean keepLocalSession, String requestId) throws Except
lastMessageId = logoutResponse.getId();
LOGGER.debug("processSLO success --> " + samlResponseParameter);
if (!keepLocalSession) {
request.getSession().invalidate();
request.invalidateSession();
}
}
}
} else if (samlRequestParameter != null) {
LogoutRequest logoutRequest = new LogoutRequest(settings, httpRequest);
LogoutRequest logoutRequest = new LogoutRequest(settings, request);
lastRequest = logoutRequest.getLogoutRequestXml();
if (!logoutRequest.isValid()) {
errors.add("invalid_logout_request");
Expand All @@ -607,11 +603,11 @@ public void processSLO(Boolean keepLocalSession, String requestId) throws Except
lastMessageId = logoutRequest.getId();
LOGGER.debug("processSLO success --> " + samlRequestParameter);
if (!keepLocalSession) {
request.getSession().invalidate();
request.invalidateSession();
}

String inResponseTo = logoutRequest.id;
LogoutResponse logoutResponseBuilder = new LogoutResponse(settings, httpRequest);
LogoutResponse logoutResponseBuilder = new LogoutResponse(settings, request);
logoutResponseBuilder.build(inResponseTo);
lastResponse = logoutResponseBuilder.getLogoutResponseXml();

Expand Down Expand Up @@ -819,7 +815,7 @@ public String buildResponseSignature(String samlResponse, String relayState, Str
/**
* Generates the Signature for a SAML Response
*
* @param samlResponse
* @param samlMessage
* The SAML Response
* @param relayState
* The RelayState
Expand Down
178 changes: 26 additions & 152 deletions core/src/main/java/com/onelogin/saml2/http/HttpRequest.java
Original file line number Diff line number Diff line change
@@ -1,151 +1,61 @@
package com.onelogin.saml2.http;

import static com.onelogin.saml2.util.Preconditions.checkNotNull;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import com.onelogin.saml2.util.Util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;

import com.onelogin.saml2.util.Util;

/**
* Framework-agnostic representation of an HTTP request.
*
* @since 2.0.0
*/
public final class HttpRequest {

public static final Map<String, List<String>> EMPTY_PARAMETERS = Collections.<String, List<String>>emptyMap();

private final String requestURL;
private final Map<String, List<String>> parameters;
private final String queryString;
public abstract class HttpRequest {

/**
* Creates a new HttpRequest.
*
* @param requestURL the request URL (up to but not including query parameters)
* @throws NullPointerException if requestURL is null
* @deprecated Not providing a queryString can cause HTTP Redirect binding to fail.
* @return true if the request is using a secure scheme (HTTPS)
*/
@Deprecated
public HttpRequest(String requestURL) {
this(requestURL, EMPTY_PARAMETERS);
}
public abstract boolean isSecure();

/**
* Creates a new HttpRequest.
*
* @param requestURL the request URL (up to but not including query parameters)
* @param queryString string that is contained in the request URL after the path
* @return the name of the request protocol (HTTP / HTTPS)
*/
public HttpRequest(String requestURL, String queryString) {
this(requestURL, EMPTY_PARAMETERS, queryString);
}
public abstract String getScheme();

/**
* Creates a new HttpRequest.
*
* @param requestURL the request URL (up to but not including query parameters)
* @param parameters the request query parameters
* @throws NullPointerException if any of the parameters is null
* @deprecated Not providing a queryString can cause HTTP Redirect binding to fail.
* @return the server name in the request e.g. www.example.com
*/
@Deprecated
public HttpRequest(String requestURL, Map<String, List<String>> parameters) {
this(requestURL, parameters, null);
}
public abstract String getServerName();

/**
* Creates a new HttpRequest.
*
* @param requestURL the request URL (up to but not including query parameters)
* @param parameters the request query parameters
* @param queryString string that is contained in the request URL after the path
* @throws NullPointerException if any of the parameters is null
* @return the port over which the request is made e.g. 80 or 443
*/
public HttpRequest(String requestURL, Map<String, List<String>> parameters, String queryString) {
this.requestURL = checkNotNull(requestURL, "requestURL");
this.parameters = unmodifiableCopyOf(checkNotNull(parameters, "queryParams"));
this.queryString = StringUtils.trimToEmpty(queryString);
}
public abstract int getServerPort();

/**
* @param name the query parameter name
* @param value the query parameter value
* @return a new HttpRequest with the given query parameter added
* @throws NullPointerException if any of the parameters is null
* @return the query string part of the URL
*/
public HttpRequest addParameter(String name, String value) {
checkNotNull(name, "name");
checkNotNull(value, "value");

final List<String> oldValues = parameters.containsKey(name) ? parameters.get(name) : new ArrayList<String>();
final List<String> newValues = new ArrayList<>(oldValues);
newValues.add(value);
final Map<String, List<String>> params = new HashMap<>(parameters);
params.put(name, newValues);

return new HttpRequest(requestURL, params, queryString);
}
public abstract String getQueryString();

/**
* @param name the query parameter name
* @return a new HttpRequest with the given query parameter removed
* @throws NullPointerException if any of the parameters is null
* @return the URI the client used to make the request - only includes
* the server path, but not the query string parameters.
*/
public HttpRequest removeParameter(String name) {
checkNotNull(name, "name");

final Map<String, List<String>> params = new HashMap<>(parameters);
params.remove(name);
public abstract String getRequestURI();

return new HttpRequest(requestURL, params, queryString);
}

/**
* The URL the client used to make the request. Includes a protocol, server name, port number, and server path, but
* not the query string parameters.
*
* @return the request URL
*/
public String getRequestURL() {
return requestURL;
}
public abstract String getRequestURL();

/**
* @param name the query parameter name
* @return the first value for the parameter, or null
*/
public String getParameter(String name) {
List<String> values = getParameters(name);
return values.isEmpty() ? null : values.get(0);
}

/**
* @param name the query parameter name
* @return a List containing all values for the parameter
*/
public List<String> getParameters(String name) {
List<String> values = parameters.get(name);
return values != null ? values : Collections.<String>emptyList();
}

/**
* @return a map of all query parameters
*/
public Map<String, List<String>> getParameters() {
return parameters;
}
public abstract String getParameter(String name);

/**
* Return an url encoded get parameter value
Expand All @@ -155,8 +65,8 @@ public Map<String, List<String>> getParameters() {
* @param name
* @return the first value for the parameter, or null
*/
public String getEncodedParameter(String name) {
Matcher matcher = Pattern.compile(Pattern.quote(name) + "=([^&#]+)").matcher(queryString);
public final String getEncodedParameter(String name) {
Matcher matcher = Pattern.compile(Pattern.quote(name) + "=([^&#]+)").matcher(getQueryString());
if (matcher.find()) {
return matcher.group(1);
} else {
Expand All @@ -173,49 +83,13 @@ public String getEncodedParameter(String name) {
* @param defaultValue
* @return the first value for the parameter, or url encoded default value
*/
public String getEncodedParameter(String name, String defaultValue) {
String value = getEncodedParameter(name);
return (value != null ? value : Util.urlEncoder(defaultValue));
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
}

HttpRequest that = (HttpRequest) o;
return Objects.equals(requestURL, that.requestURL) &&
Objects.equals(parameters, that.parameters) &&
Objects.equals(queryString, that.queryString);
public final String getEncodedParameter(String name, String defaultValue) {
String value = getEncodedParameter(name);
return (value != null ? value : Util.urlEncoder(defaultValue));
}

@Override
public int hashCode() {
return Objects.hash(requestURL, parameters, queryString);
}

@Override
public String toString() {
return "HttpRequest{" +
"requestURL='" + requestURL + '\'' +
", parameters=" + parameters +
", queryString=" + queryString +
'}';
}

private static Map<String, List<String>> unmodifiableCopyOf(Map<String, List<String>> orig) {
Map<String, List<String>> copy = new HashMap<>();
for (Map.Entry<String, List<String>> entry : orig.entrySet()) {
copy.put(entry.getKey(), unmodifiableList(new ArrayList<>(entry.getValue())));
}

return unmodifiableMap(copy);
}


/**
* Invalidate the current session
*/
public abstract void invalidateSession();
}
Loading