Skip to content

Commit

Permalink
fixes #293
Browse files Browse the repository at this point in the history
Signed-off-by: appiepollo14 <[email protected]>
  • Loading branch information
appiepollo14 committed Oct 22, 2024
1 parent 7d0da60 commit 6f90fb4
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Comparator;

import jakarta.ws.rs.ApplicationPath;
import jakarta.ws.rs.core.Application;
import jakarta.ws.rs.ext.MessageBodyReader;
import jakarta.ws.rs.ext.MessageBodyWriter;

import org.apache.cxf.jaxrs.client.JAXRSClientFactoryBean;
import org.junit.platform.commons.support.AnnotationSupport;
import org.junit.platform.commons.support.ReflectionSupport;
Expand All @@ -55,8 +55,8 @@ public class RestClientBuilder {

/**
* @param appContextRoot The protocol, hostname, port, and application root path for the REST Client
* For example, <code>http://localhost:8080/myapp/</code>. If unspecified, the app context
* root will be automatically detected by {@link ApplicationEnvironment#getApplicationURL()}
* For example, <code>http://localhost:8080/myapp/</code>. If unspecified, the app context
* root will be automatically detected by {@link ApplicationEnvironment#getApplicationURL()}
* @return The same builder instance
*/
public RestClientBuilder withAppContextRoot(String appContextRoot) {
Expand All @@ -67,9 +67,9 @@ public RestClientBuilder withAppContextRoot(String appContextRoot) {

/**
* @param jaxrsPath The portion of the path after the app context root. For example, if a JAX-RS
* endpoint is deployed at <code>http://localhost:8080/myapp/hello</code> and the app context root
* is <code>http://localhost:8080/myapp/</code>, then the jaxrsPath is <code>hello</code>. If
* unspecified, the JAX-RS path will be automatically detected by annotation scanning.
* endpoint is deployed at <code>http://localhost:8080/myapp/hello</code> and the app context root
* is <code>http://localhost:8080/myapp/</code>, then the jaxrsPath is <code>hello</code>. If
* unspecified, the JAX-RS path will be automatically detected by annotation scanning.
* @return The same builder instance
*/
public RestClientBuilder withJaxrsPath(String jaxrsPath) {
Expand All @@ -93,7 +93,7 @@ public RestClientBuilder withJwt(String jwt) {
}

/**
* @param user The username portion of the Basic auth header
* @param user The username portion of the Basic auth header
* @param password The password portion of the Basic auth header
* @return The same builder instance
*/
Expand All @@ -110,7 +110,7 @@ public RestClientBuilder withBasicAuth(String user, String password) {
}

/**
* @param key The header key
* @param key The header key
* @param value The header value
* @return The same builder instance
*/
Expand All @@ -126,8 +126,8 @@ public RestClientBuilder withHeader(String key, String value) {

/**
* @param providers One or more providers to apply. Providers typically implement
* {@link MessageBodyReader} and/or {@link MessageBodyWriter}. If unspecified,
* the {@link JsonBProvider} will be applied.
* {@link MessageBodyReader} and/or {@link MessageBodyWriter}. If unspecified,
* the {@link JsonBProvider} will be applied.
* @return The same builder instance
*/
public RestClientBuilder withProviders(Class<?>... providers) {
Expand All @@ -145,7 +145,7 @@ public <T> T build(Class<T> clazz) {
providers = Collections.singletonList(JsonBProvider.class);

JAXRSClientFactoryBean bean = new org.apache.cxf.jaxrs.client.JAXRSClientFactoryBean();
String basePath = join(appContextRoot, jaxrsPath);
String basePath = joinPaths(appContextRoot, jaxrsPath);
LOG.info("Building rest client for " + clazz + " with base path: " + basePath + " and providers: " + providers);
bean.setResourceClass(clazz);
bean.setProviders(providers);
Expand All @@ -163,10 +163,10 @@ private static String locateApplicationPath(Class<?> clazz) {

// First check for a jakarta.ws.rs.core.Application in the same package as the resource
List<Class<?>> appClasses = ReflectionSupport.findAllClassesInPackage(resourcePackage,
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
if (appClasses.size() == 0) {
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
if (appClasses.isEmpty()) {
LOG.debug("no classes implementing Application found in pkg: " + resourcePackage);
// If not found, check under the 3rd package, so com.foo.bar.*
// Classpath scanning can be expensive, so we jump straight to the 3rd package from root instead
Expand All @@ -176,39 +176,40 @@ private static String locateApplicationPath(Class<?> clazz) {
String checkPkg = pkgs[0] + '.' + pkgs[1] + '.' + pkgs[2];
LOG.debug("checking in pkg: " + checkPkg);
appClasses = ReflectionSupport.findAllClassesInPackage(checkPkg,
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
}
}

if (appClasses.size() == 0) {
if (appClasses.isEmpty()) {
LOG.info("No classes implementing 'jakarta.ws.rs.core.Application' found on classpath to set base path from " + clazz +
". Defaulting base path to '/'");
". Defaulting base path to '/'");
return "";
}

Class<?> selectedClass = appClasses.stream()
.sorted((c1, c2) -> c1.getName().compareTo(c2.getName()))
.findFirst()
.get();
.sorted(Comparator.comparing(Class::getName))
.findFirst()
.get();
ApplicationPath appPath = AnnotationSupport.findAnnotation(selectedClass, ApplicationPath.class).get();
if (appClasses.size() > 1) {
LOG.warn("Found multiple classes implementing 'jakarta.ws.rs.core.Application' on classpath: " + appClasses +
". Setting base path from the first class discovered (" + selectedClass.getCanonicalName() + ") with path: " +
appPath.value());
". Setting base path from the first class discovered (" + selectedClass.getCanonicalName() + ") with path: " +
appPath.value());
}
LOG.debug("Using base ApplicationPath of '" + appPath.value() + "'");
return appPath.value();
}

private static String join(String firstPart, String secondPart) {
if (firstPart.endsWith("/") && secondPart.startsWith("/"))
return firstPart + secondPart.substring(1);
else if (firstPart.endsWith("/") || secondPart.startsWith("/"))
return firstPart + secondPart;
private static String joinPaths(String appContextRoot, String jaxrsPath) {
if (appContextRoot.endsWith("/") && jaxrsPath.startsWith("/"))
return appContextRoot + jaxrsPath.substring(1);
else if (appContextRoot.endsWith("/") || jaxrsPath.startsWith("/"))
return appContextRoot + jaxrsPath;
else
return firstPart + "/" + secondPart;
return appContextRoot + "/" + jaxrsPath;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@
*/
package org.microshed.testing.jupiter;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Properties;

import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionConfigurationException;
import org.junit.jupiter.api.extension.ExtensionContext;
Expand All @@ -44,6 +33,12 @@
import org.microshed.testing.kafka.KafkaConsumerClient;
import org.microshed.testing.kafka.KafkaProducerClient;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.*;

/**
* JUnit Jupiter extension that is applied whenever the <code>@MicroProfileTest</code> is used on a test class.
* Currently this is tied to Testcontainers managing runtime build/deployment, but in a future version
Expand Down Expand Up @@ -90,8 +85,8 @@ private static void injectRestClients(Class<?> clazz) {

for (Field restClientField : restClientFields) {
if (!Modifier.isPublic(restClientField.getModifiers()) ||
!Modifier.isStatic(restClientField.getModifiers()) ||
Modifier.isFinal(restClientField.getModifiers())) {
!Modifier.isStatic(restClientField.getModifiers()) ||
Modifier.isFinal(restClientField.getModifiers())) {
throw new ExtensionConfigurationException("REST client field must be public, static, and non-final: " + restClientField);
}
RestClientBuilder rcBuilder = new RestClientBuilder();
Expand Down Expand Up @@ -137,10 +132,10 @@ private static void injectKafkaClients(Class<?> clazz) {
throw new ExtensionConfigurationException("Fields annotated with @KafkaProducerClient must be of the type " + KafkaProducer.getName());
}
if (!Modifier.isPublic(producerField.getModifiers()) ||
!Modifier.isStatic(producerField.getModifiers()) ||
Modifier.isFinal(producerField.getModifiers())) {
!Modifier.isStatic(producerField.getModifiers()) ||
Modifier.isFinal(producerField.getModifiers())) {
throw new ExtensionConfigurationException("The KafkaProducer field annotated with @KafkaProducerClient " +
"must be public, static, and non-final: " + producerField);
"must be public, static, and non-final: " + producerField);
}

Properties properties = kafkaProcessor.getProducerProperties(producerField);
Expand All @@ -159,10 +154,10 @@ private static void injectKafkaClients(Class<?> clazz) {
throw new ExtensionConfigurationException("Fields annotated with @KafkaConsumerClient must be of the type " + KafkaConsumer.getName());
}
if (!Modifier.isPublic(consumerField.getModifiers()) ||
!Modifier.isStatic(consumerField.getModifiers()) ||
Modifier.isFinal(consumerField.getModifiers())) {
!Modifier.isStatic(consumerField.getModifiers()) ||
Modifier.isFinal(consumerField.getModifiers())) {
throw new ExtensionConfigurationException("The KafkaProducer field annotated with @KafkaConsumerClient " +
"must be public, static, and non-final: " + consumerField);
"must be public, static, and non-final: " + consumerField);
}

Properties properties = kafkaProcessor.getConsumerProperties(consumerField);
Expand All @@ -182,7 +177,7 @@ private static void injectKafkaClients(Class<?> clazz) {
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings({"unchecked", "rawtypes"})
private static void configureRestAssured(ApplicationEnvironment config) {
if (!config.configureRestAssured())
return;
Expand Down
10 changes: 6 additions & 4 deletions core/src/main/java/org/microshed/testing/jwt/JwtConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@
*/
package org.microshed.testing.jwt;

import org.junit.jupiter.api.extension.ExtendWith;
import org.microshed.testing.jaxrs.RESTClient;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.microshed.testing.jaxrs.RESTClient;

/**
* Used to annotate a REST Client to configure MicroProfile JWT settings
* that will be applied to all of its HTTP invocations.
* In order for this annotation to have any effect, the field must also
* be annotated with {@link RESTClient}.
*/
@Target({ ElementType.FIELD })
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(JwtConfigExtension.class)
public @interface JwtConfig {

public static final String DEFAULT_ISSUER = "http://testissuer.com";
Expand All @@ -46,7 +48,7 @@
* array of claims in the following format:
* key=value
* example: {"sub=fred", "upn=fred", "kid=123"}
*
* <p>
* For arrays, separate values with a comma.
* example: {"groups=red,green,admin", "sub=fred"}
*
Expand Down
122 changes: 122 additions & 0 deletions core/src/main/java/org/microshed/testing/jwt/JwtConfigExtension.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package org.microshed.testing.jwt;

import org.jose4j.jwt.MalformedClaimException;
import org.jose4j.lang.JoseException;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionConfigurationException;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.microshed.testing.internal.InternalLogger;
import org.microshed.testing.jupiter.MicroShedTestExtension;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class JwtConfigExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {

private static final InternalLogger LOG = InternalLogger.get(JwtConfigExtension.class);

@Override
public void beforeTestExecution(ExtensionContext context) throws Exception {
configureJwt(context);
}

@Override
public void afterTestExecution(ExtensionContext context) {
removeJwt(context);
}

private void configureJwt(ExtensionContext context) throws ExtensionConfigurationException {

// Check if the test method has the @JwtConfig annotation
Method testMethod = context.getTestMethod().orElse(null);
if (testMethod != null) {

// Check if RestAssured is being used
Class<?> restAssuredClass = tryLoad("io.restassured.RestAssured");
if (restAssuredClass == null) {
LOG.debug("RESTAssured not found!");
} else {
LOG.debug("RESTAssured found!");

JwtConfig jwtConfig = testMethod.getAnnotation(JwtConfig.class);
if (jwtConfig != null) {
LOG.info("JWTConfig on method: " + testMethod.getName());

try {
// Get the RequestSpecBuilder class
Class<?> requestSpecBuilderClass = Class.forName("io.restassured.builder.RequestSpecBuilder");

// Create an instance of RequestSpecBuilder
Object requestSpecBuilder = requestSpecBuilderClass.getDeclaredConstructor().newInstance();

// Get the requestSpecification field
Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification");
requestSpecificationField.setAccessible(true);

// Get the header method of RequestSpecBuilder
Method headerMethod = requestSpecBuilderClass.getDeclaredMethod("addHeader", String.class, String.class);

// Build the JWT and add it to the header
String jwt = JwtBuilder.buildJwt(jwtConfig.subject(), jwtConfig.issuer(), jwtConfig.claims());
headerMethod.invoke(requestSpecBuilder, "Authorization", "Bearer " + jwt);
LOG.debug("Using provided JWT auth header: " + jwt);

// Set the updated requestSpecification
requestSpecificationField.set(null, requestSpecBuilderClass.getMethod("build").invoke(requestSpecBuilder));

} catch (ClassNotFoundException e) {
throw new ExtensionConfigurationException("Class 'RequestSpecBuilder' not found for method " + testMethod.getName(), e);
} catch (InstantiationException | IllegalAccessException e) {
throw new ExtensionConfigurationException("Error instantiating 'RequestSpecBuilder' for method " + testMethod.getName(), e);
} catch (NoSuchFieldException e) {
throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured for method " + testMethod.getName(), e);
} catch (NoSuchMethodException e) {
throw new ExtensionConfigurationException("Method 'addHeader' or 'build' not found in 'RequestSpecBuilder' for method " + testMethod.getName(), e);
} catch (InvocationTargetException e) {
throw new ExtensionConfigurationException("Error invoking method on 'RequestSpecBuilder' for method " + testMethod.getName(), e);
} catch (MalformedClaimException | JoseException e) {
throw new ExtensionConfigurationException("Error building JWT", e);
}
}
}
}
}

private void removeJwt(ExtensionContext context) throws ExtensionConfigurationException {
// Check if the test method has the @JwtConfig annotation
Method testMethod = context.getTestMethod().orElse(null);
if (testMethod != null) {
LOG.debug("Method was annotated with: " + testMethod.getName());

// Check if RestAssured is being used
Class<?> restAssuredClass = tryLoad("io.restassured.RestAssured");
if (restAssuredClass == null) {
LOG.debug("RESTAssured not found!");
} else {
try {
// Get the requestSpecification field
Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification");
requestSpecificationField.setAccessible(true);

// Removes all requestSpec
requestSpecificationField.set(null, null);

} catch (NoSuchFieldException e) {
throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured", e);
} catch (IllegalAccessException e) {
throw new ExtensionConfigurationException("Error accessing 'requestSpecification' field in RestAssured", e);
}
}
}
}

private static Class<?> tryLoad(String clazz) {
try {
return Class.forName(clazz, false, MicroShedTestExtension.class.getClassLoader());
} catch (ClassNotFoundException | LinkageError e) {
return null;
}
}
}
Loading

0 comments on commit 6f90fb4

Please sign in to comment.