diff --git a/core/src/main/java/org/microshed/testing/jwt/JwtConfigExtension.java b/core/src/main/java/org/microshed/testing/jwt/JwtConfigExtension.java index 9b70592e..4f2acc41 100644 --- a/core/src/main/java/org/microshed/testing/jwt/JwtConfigExtension.java +++ b/core/src/main/java/org/microshed/testing/jwt/JwtConfigExtension.java @@ -12,103 +12,108 @@ import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; public class JwtConfigExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback { private static final InternalLogger LOG = InternalLogger.get(JwtConfigExtension.class); + private final AtomicBoolean needsRemoval = new AtomicBoolean(false); + @Override public void beforeTestExecution(ExtensionContext context) throws Exception { - configureJwt(context); + // Check if the test method has the @JwtConfig annotation + context.getTestMethod().ifPresent(testMethod -> { + JwtConfig jwtConfig = testMethod.getAnnotation(JwtConfig.class); + if (Objects.isNull(jwtConfig)) { + return; + } + LOG.info("JwtConfig on method: " + testMethod.getName()); + configureJwt(testMethod, jwtConfig); + needsRemoval.set(true); + }); } @Override public void afterTestExecution(ExtensionContext context) { - removeJwt(context); - } - - private void configureJwt(ExtensionContext context) throws ExtensionConfigurationException { - // Check if the test method has the @JwtConfig annotation - Method testMethod = context.getTestMethod().orElse(null); - if (testMethod != null) { - - // Check if RestAssured is being used - Class restAssuredClass = tryLoad("io.restassured.RestAssured"); - if (restAssuredClass == null) { - LOG.debug("RESTAssured not found!"); - } else { - LOG.debug("RESTAssured found!"); - + context.getTestMethod().ifPresent(testMethod -> { + if(needsRemoval.compareAndSet(true, false)) { JwtConfig jwtConfig = testMethod.getAnnotation(JwtConfig.class); - if (jwtConfig != null) { - LOG.info("JWTConfig on method: " + testMethod.getName()); - - try { - // Get the RequestSpecBuilder class - Class requestSpecBuilderClass = Class.forName("io.restassured.builder.RequestSpecBuilder"); - - // Create an instance of RequestSpecBuilder - Object requestSpecBuilder = requestSpecBuilderClass.getDeclaredConstructor().newInstance(); - - // Get the requestSpecification field - Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification"); - requestSpecificationField.setAccessible(true); - - // Get the header method of RequestSpecBuilder - Method headerMethod = requestSpecBuilderClass.getDeclaredMethod("addHeader", String.class, String.class); - - // Build the JWT and add it to the header - String jwt = JwtBuilder.buildJwt(jwtConfig.subject(), jwtConfig.issuer(), jwtConfig.claims()); - headerMethod.invoke(requestSpecBuilder, "Authorization", "Bearer " + jwt); - LOG.debug("Using provided JWT auth header: " + jwt); - - // Set the updated requestSpecification - requestSpecificationField.set(null, requestSpecBuilderClass.getMethod("build").invoke(requestSpecBuilder)); - - } catch (ClassNotFoundException e) { - throw new ExtensionConfigurationException("Class 'RequestSpecBuilder' not found for method " + testMethod.getName(), e); - } catch (InstantiationException | IllegalAccessException e) { - throw new ExtensionConfigurationException("Error instantiating 'RequestSpecBuilder' for method " + testMethod.getName(), e); - } catch (NoSuchFieldException e) { - throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured for method " + testMethod.getName(), e); - } catch (NoSuchMethodException e) { - throw new ExtensionConfigurationException("Method 'addHeader' or 'build' not found in 'RequestSpecBuilder' for method " + testMethod.getName(), e); - } catch (InvocationTargetException e) { - throw new ExtensionConfigurationException("Error invoking method on 'RequestSpecBuilder' for method " + testMethod.getName(), e); - } catch (MalformedClaimException | JoseException e) { - throw new ExtensionConfigurationException("Error building JWT", e); - } + if (Objects.isNull(jwtConfig)) { + return; } + removeJwt(testMethod, jwtConfig); } + }); + } + + private void configureJwt(Method testMethod, JwtConfig jwtConfig) throws ExtensionConfigurationException { + // Check if RestAssured is being used + Class restAssuredClass = tryLoad("io.restassured.RestAssured"); + if (Objects.isNull(restAssuredClass)) { + LOG.debug("RESTAssured not found!"); + return; + } + + LOG.debug("RESTAssured found!"); + try { + // Get the RequestSpecBuilder class + Class requestSpecBuilderClass = Class.forName("io.restassured.builder.RequestSpecBuilder"); + + // Create an instance of RequestSpecBuilder + Object requestSpecBuilder = requestSpecBuilderClass.getDeclaredConstructor().newInstance(); + + // Get the requestSpecification field + Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification"); + requestSpecificationField.setAccessible(true); + + // Get the header method of RequestSpecBuilder + Method headerMethod = requestSpecBuilderClass.getDeclaredMethod("addHeader", String.class, String.class); + + // Build the JWT and add it to the header + String jwt = JwtBuilder.buildJwt(jwtConfig.subject(), jwtConfig.issuer(), jwtConfig.claims()); + headerMethod.invoke(requestSpecBuilder, "Authorization", "Bearer " + jwt); + LOG.debug("Using provided JWT auth header: " + jwt); + + // Set the updated requestSpecification + requestSpecificationField.set(null, requestSpecBuilderClass.getMethod("build").invoke(requestSpecBuilder)); + + } catch (ClassNotFoundException e) { + throw new ExtensionConfigurationException("Class 'RequestSpecBuilder' not found for method " + testMethod.getName(), e); + } catch (InstantiationException | IllegalAccessException e) { + throw new ExtensionConfigurationException("Error instantiating 'RequestSpecBuilder' for method " + testMethod.getName(), e); + } catch (NoSuchFieldException e) { + throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured for method " + testMethod.getName(), e); + } catch (NoSuchMethodException e) { + throw new ExtensionConfigurationException("Method 'addHeader' or 'build' not found in 'RequestSpecBuilder' for method " + testMethod.getName(), e); + } catch (InvocationTargetException e) { + throw new ExtensionConfigurationException("Error invoking method on 'RequestSpecBuilder' for method " + testMethod.getName(), e); + } catch (MalformedClaimException | JoseException e) { + throw new ExtensionConfigurationException("Error building JWT", e); } } - private void removeJwt(ExtensionContext context) throws ExtensionConfigurationException { - // Check if the test method has the @JwtConfig annotation - Method testMethod = context.getTestMethod().orElse(null); - if (testMethod != null) { - LOG.debug("Method was annotated with: " + testMethod.getName()); - - // Check if RestAssured is being used - Class restAssuredClass = tryLoad("io.restassured.RestAssured"); - if (restAssuredClass == null) { - LOG.debug("RESTAssured not found!"); - } else { - try { - // Get the requestSpecification field - Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification"); - requestSpecificationField.setAccessible(true); - - // Removes all requestSpec - requestSpecificationField.set(null, null); - - } catch (NoSuchFieldException e) { - throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured", e); - } catch (IllegalAccessException e) { - throw new ExtensionConfigurationException("Error accessing 'requestSpecification' field in RestAssured", e); - } - } + private void removeJwt(Method testMethod, JwtConfig jwtConfig) throws ExtensionConfigurationException { + // Check if RestAssured is being used + Class restAssuredClass = tryLoad("io.restassured.RestAssured"); + if (restAssuredClass == null) { + LOG.debug("RESTAssured not found!"); + return; + } + + try { + // Get the requestSpecification field + Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification"); + requestSpecificationField.setAccessible(true); + + // Removes all requestSpec + requestSpecificationField.set(null, null); + } catch (NoSuchFieldException e) { + throw new ExtensionConfigurationException("Field 'requestSpecification' not found in RestAssured", e); + } catch (IllegalAccessException e) { + throw new ExtensionConfigurationException("Error accessing 'requestSpecification' field in RestAssured", e); } }