diff --git a/akka-http-cors-bench-jmh/src/main/scala/ch/megard/akka/http/cors/CorsBenchmark.scala b/akka-http-cors-bench-jmh/src/main/scala/ch/megard/akka/http/cors/CorsBenchmark.scala index f282364..2174859 100644 --- a/akka-http-cors-bench-jmh/src/main/scala/ch/megard/akka/http/cors/CorsBenchmark.scala +++ b/akka-http-cors-bench-jmh/src/main/scala/ch/megard/akka/http/cors/CorsBenchmark.scala @@ -10,6 +10,7 @@ import akka.http.scaladsl.model.{HttpMethods, HttpRequest} import akka.http.scaladsl.server.Directives import akka.http.scaladsl.unmarshalling.Unmarshal import ch.megard.akka.http.cors.scaladsl.CorsDirectives +import ch.megard.akka.http.cors.scaladsl.settings.CorsSettings import com.typesafe.config.ConfigFactory import org.openjdk.jmh.annotations._ @@ -25,7 +26,8 @@ class CorsBenchmark extends Directives with CorsDirectives { implicit private val system: ActorSystem = ActorSystem("CorsBenchmark", config) implicit private val ec: ExecutionContext = scala.concurrent.ExecutionContext.global - private val http = Http() + private val http = Http() + private val corsSettings = CorsSettings.default private var binding: ServerBinding = _ private var request: HttpRequest = _ @@ -40,7 +42,7 @@ class CorsBenchmark extends Directives with CorsDirectives { complete("ok") } } ~ path("cors") { - cors() { + cors(corsSettings) { get { complete("ok") } diff --git a/akka-http-cors-example/src/main/java/ch/megard/akka/http/cors/javadsl/CorsServer.java b/akka-http-cors-example/src/main/java/ch/megard/akka/http/cors/javadsl/CorsServer.java index 893449a..81b9d2c 100644 --- a/akka-http-cors-example/src/main/java/ch/megard/akka/http/cors/javadsl/CorsServer.java +++ b/akka-http-cors-example/src/main/java/ch/megard/akka/http/cors/javadsl/CorsServer.java @@ -43,11 +43,9 @@ protected Route routes() { // Note how rejections and exceptions are handled *before* the CORS directive (in the inner route). // This is required to have the correct CORS headers in the response even when an error occurs. - return handleErrors.apply(() -> cors(() -> handleErrors.apply(() -> route( - path("ping", () -> - complete("pong")), - path("pong", () -> - failWith(new NoSuchElementException("pong not found, try with ping"))) + return handleErrors.apply(() -> cors(() -> handleErrors.apply(() -> concat( + path("ping", () -> complete("pong")), + path("pong", () -> failWith(new NoSuchElementException("pong not found, try with ping"))) )))); } diff --git a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/CorsDirectives.scala b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/CorsDirectives.scala index 44c3509..8f75076 100644 --- a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/CorsDirectives.scala +++ b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/CorsDirectives.scala @@ -8,6 +8,7 @@ import ch.megard.akka.http.cors.javadsl.settings.CorsSettings import ch.megard.akka.http.cors.scaladsl object CorsDirectives { + def cors(inner: Supplier[Route]): Route = RouteAdapter { scaladsl.CorsDirectives.cors() { diff --git a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/settings/CorsSettings.scala b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/settings/CorsSettings.scala index f238505..9249fc4 100644 --- a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/settings/CorsSettings.scala +++ b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/javadsl/settings/CorsSettings.scala @@ -8,7 +8,7 @@ import akka.http.javadsl.model.HttpMethod import ch.megard.akka.http.cors.javadsl.model.{HttpHeaderRange, HttpOriginMatcher} import ch.megard.akka.http.cors.scaladsl import ch.megard.akka.http.cors.scaladsl.settings.CorsSettingsImpl -import com.typesafe.config.Config +import com.typesafe.config.{Config, ConfigFactory} /** * Public API but not intended for subclassing @@ -36,7 +36,7 @@ abstract class CorsSettings { self: CorsSettingsImpl => object CorsSettings { def create(config: Config): CorsSettings = scaladsl.settings.CorsSettings(config) def create(configOverrides: String): CorsSettings = scaladsl.settings.CorsSettings(configOverrides) - def create(system: ActorSystem): CorsSettings = create(system.settings.config) - - def defaultSettings: CorsSettings = scaladsl.settings.CorsSettings.defaultSettings + def create(system: ActorSystem): CorsSettings = scaladsl.settings.CorsSettings(system) + @deprecated("Use `CorsSetting.create` instead", "1.0.0") + def defaultSettings: CorsSettings = create(ConfigFactory.load(getClass.getClassLoader)) } diff --git a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/CorsDirectives.scala b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/CorsDirectives.scala index eea7947..8ee50d0 100644 --- a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/CorsDirectives.scala +++ b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/CorsDirectives.scala @@ -22,6 +22,20 @@ trait CorsDirectives { import BasicDirectives._ import RouteDirectives._ + /** + * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests. + * + * In particular the recommendation written by the W3C in https://www.w3.org/TR/cors/ is + * implemented by this directive. + * + * The settings are loaded from the Actor System configuration. + */ + def cors(): Directive0 = { + extractActorSystem.flatMap { system => + cors(CorsSettings(system)) + } + } + /** * Wraps its inner route with support for the CORS mechanism, enabling cross origin requests. * @@ -30,7 +44,7 @@ trait CorsDirectives { * * @param settings the settings used by the CORS filter */ - def cors(settings: CorsSettings = CorsSettings.defaultSettings): Directive0 = { + def cors(settings: CorsSettings): Directive0 = { import settings._ /** Return the invalid origins, or `Nil` if one is valid. */ diff --git a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettings.scala b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettings.scala index 15cde1c..0d41b5c 100644 --- a/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettings.scala +++ b/akka-http-cors/src/main/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettings.scala @@ -13,7 +13,7 @@ import com.typesafe.config.ConfigException.{Missing, WrongType} import com.typesafe.config.{Config, ConfigFactory} import scala.collection.JavaConverters._ -import scala.collection.immutable.Seq +import scala.collection.immutable.{ListMap, Seq} import scala.compat.java8.OptionConverters import scala.util.Try @@ -159,15 +159,48 @@ abstract class CorsSettings private[akka] () extends javadsl.settings.CorsSettin } object CorsSettings { - private val prefix = "akka-http-cors" + private val prefix = "akka-http-cors" + final private val MaxCached = 8 + private[this] var cache = ListMap.empty[ActorSystem, CorsSettings] - def apply(system: ActorSystem): CorsSettings = apply(system.settings.config) - def apply(config: Config): CorsSettings = fromSubConfig(config.getConfig(prefix)) + /** + * Creates an instance of settings using the given Config. + */ + def apply(config: Config): CorsSettings = + fromSubConfig(config.getConfig(prefix)) + + /** + * Create an instance of settings using the given String of config overrides to override + * settings set in the class loader of this class (i.e. by application.conf or reference.conf files in + * the class loader of this class). + */ def apply(configOverrides: String): CorsSettings = apply( ConfigFactory.parseString(configOverrides).withFallback(ConfigFactory.defaultReference(getClass.getClassLoader)) ) + /** + * Creates an instance of CorsSettings using the configuration provided by the given ActorSystem. + */ + def apply(system: ActorSystem): CorsSettings = + // From private akka.http.impl.util.SettingsCompanionImpl implementation + cache.getOrElse( + system, { + val settings = apply(system.settings.config) + val c = if (cache.size < MaxCached) cache else cache.tail + cache = c.updated(system, settings) + settings + } + ) + + /** + * Creates an instance of CorsSettings using the configuration provided by the given ActorSystem. + */ + implicit def default(implicit system: ActorSystem): CorsSettings = apply(system) + + @deprecated("Use either `CorsSetting.default` or `CorsSettings.apply` instead", "1.0.0") + def defaultSettings: CorsSettings = apply(ConfigFactory.load(getClass.getClassLoader)) + def fromSubConfig(config: Config): CorsSettings = { def parseStringList(path: String): List[String] = Try(config.getStringList(path).asScala.toList).recover { @@ -198,5 +231,4 @@ object CorsSettings { ) } - val defaultSettings = apply(ConfigFactory.load(getClass.getClassLoader)) } diff --git a/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/CorsDirectivesSpec.scala b/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/CorsDirectivesSpec.scala index 0ed396a..7e30935 100644 --- a/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/CorsDirectivesSpec.scala +++ b/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/CorsDirectivesSpec.scala @@ -20,14 +20,40 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with val exampleOrigin = HttpOrigin("http://example.com") val exampleStatus = StatusCodes.Created + val referenceSettings = CorsSettings("") + + // override config for the ActorSystem to test `cors()` + override def testConfigSource: String = + """ + |akka-http-cors { + | allow-credentials = false + |} + |""".stripMargin + def route(settings: CorsSettings, responseHeaders: Seq[HttpHeader] = Nil): Route = cors(settings) { complete(HttpResponse(exampleStatus, responseHeaders, HttpEntity(actual))) } - "The cors directive" should { + "The cors() directive" should { + "extract its settings from the actor system" in { + val route = cors() { + complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual))) + } + + Get() ~> Origin(exampleOrigin) ~> route ~> check { + responseAs[String] shouldBe actual + response.status shouldBe exampleStatus + response.headers should contain theSameElementsAs Seq( + `Access-Control-Allow-Origin`.* + ) + } + } + } + + "The cors(settings) directive" should { "not affect actual requests when not strict" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings val responseHeaders = Seq(Host("my-host"), `Access-Control-Max-Age`(60)) Get() ~> { route(settings, responseHeaders) @@ -40,7 +66,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject requests without Origin header when strict" in { - val settings = CorsSettings.defaultSettings.withAllowGenericHttpRequests(false) + val settings = referenceSettings.withAllowGenericHttpRequests(false) Get() ~> { route(settings) } ~> check { @@ -49,7 +75,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "accept actual requests with a single Origin" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings Get() ~> Origin(exampleOrigin) ~> { route(settings) } ~> check { @@ -63,7 +89,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "accept pre-flight requests with a null origin when allowed-origins = `*`" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> { route(settings) } ~> check { @@ -78,7 +104,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject pre-flight requests with a null origin when allowed-origins != `*`" in { - val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) + val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> { route(settings) } ~> check { @@ -87,7 +113,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "accept actual requests with a null Origin" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings Get() ~> Origin(Seq.empty) ~> { route(settings) } ~> check { @@ -104,7 +130,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with val subdomainMatcher = HttpOriginMatcher(HttpOrigin("http://*.example.com")) val subdomainOrigin = HttpOrigin("http://sub.example.com") - val settings = CorsSettings.defaultSettings.withAllowedOrigins(subdomainMatcher) + val settings = referenceSettings.withAllowedOrigins(subdomainMatcher) Get() ~> Origin(subdomainOrigin) ~> { route(settings) } ~> check { @@ -118,7 +144,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "return `Access-Control-Allow-Origin: *` to actual request only when credentials are not allowed" in { - val settings = CorsSettings.defaultSettings.withAllowCredentials(false) + val settings = referenceSettings.withAllowCredentials(false) Get() ~> Origin(exampleOrigin) ~> { route(settings) } ~> check { @@ -132,7 +158,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with "return `Access-Control-Expose-Headers` to actual request with all the exposed headers in the settings" in { val exposedHeaders = Seq("X-a", "X-b", "X-c") - val settings = CorsSettings.defaultSettings.withExposedHeaders(exposedHeaders) + val settings = referenceSettings.withExposedHeaders(exposedHeaders) Get() ~> Origin(exampleOrigin) ~> { route(settings) } ~> check { @@ -147,7 +173,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "remove CORS-related headers from the original response before adding the new ones" in { - val settings = CorsSettings.defaultSettings.withExposedHeaders(Seq("X-good")) + val settings = referenceSettings.withExposedHeaders(Seq("X-good")) val responseHeaders = Seq( Host("my-host"), // untouched `Access-Control-Allow-Origin`("http://bad.com"), // replaced @@ -172,7 +198,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "accept valid pre-flight requests" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> { route(settings) } ~> check { @@ -188,7 +214,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "accept actual requests with OPTION method" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings Options() ~> Origin(exampleOrigin) ~> { route(settings) } ~> check { @@ -203,7 +229,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with "reject actual requests with invalid origin" when { "the origin is null" in { - val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) + val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) Get() ~> Origin(Seq.empty) ~> { route(settings) } ~> check { @@ -211,7 +237,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } } "there is one origin" in { - val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) + val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) val invalidOrigin = HttpOrigin("http://invalid.com") Get() ~> Origin(invalidOrigin) ~> { route(settings) @@ -222,7 +248,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject pre-flight requests with invalid origin" in { - val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) + val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) val invalidOrigin = HttpOrigin("http://invalid.com") Options() ~> Origin(invalidOrigin) ~> `Access-Control-Request-Method`(GET) ~> { route(settings) @@ -232,7 +258,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject pre-flight requests with invalid method" in { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings val invalidMethod = PATCH Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(invalidMethod) ~> { route(settings) @@ -242,7 +268,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject pre-flight requests with invalid header" in { - val settings = CorsSettings.defaultSettings.withAllowedHeaders(HttpHeaderRange()) + val settings = referenceSettings.withAllowedHeaders(HttpHeaderRange()) val invalidHeader = "X-header" Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> `Access-Control-Request-Headers`(invalidHeader) ~> { @@ -253,7 +279,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "reject pre-flight requests with multiple origins" in { - val settings = CorsSettings.defaultSettings.withAllowGenericHttpRequests(false) + val settings = referenceSettings.withAllowGenericHttpRequests(false) Options() ~> Origin(exampleOrigin, exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> { route(settings) } ~> check { @@ -263,7 +289,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with } "the default rejection handler" should { - val settings = CorsSettings.defaultSettings + val settings = referenceSettings .withAllowGenericHttpRequests(false) .withAllowedOrigins(HttpOriginMatcher(exampleOrigin)) .withAllowedHeaders(HttpHeaderRange()) diff --git a/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettingsSpec.scala b/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettingsSpec.scala index ad03b0d..2a92202 100644 --- a/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettingsSpec.scala +++ b/akka-http-cors/src/test/scala/ch/megard/akka/http/cors/scaladsl/settings/CorsSettingsSpec.scala @@ -2,30 +2,59 @@ package ch.megard.akka.http.cors.scaladsl.settings import akka.http.scaladsl.model.headers.HttpOrigin import akka.http.scaladsl.model.{HttpMethod, HttpMethods} +import akka.http.scaladsl.testkit.ScalatestRouteTest import ch.megard.akka.http.cors.scaladsl.model.{HttpHeaderRange, HttpOriginMatcher} import com.typesafe.config.{ConfigFactory, ConfigValueFactory} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec -class CorsSettingsSpec extends AnyWordSpec with Matchers { +class CorsSettingsSpec extends AnyWordSpec with Matchers with ScalatestRouteTest { import HttpMethods._ - private val validConfigStr = + // Override some configs loaded through the Actor system + override def testConfigSource = """ - |akka-http-cors { - | allow-generic-http-requests = true - | allow-credentials = true - | allowed-origins = "*" - | allowed-headers = "*" - | allowed-methods = ["GET", "OPTIONS", "XXX"] - | exposed-headers = [] - | max-age = 30 minutes - |} - """.stripMargin - - private val validConfig = ConfigFactory.parseString(validConfigStr) - - "The CorsSettings object" should { + akka-http-cors { + allow-credentials = false + } + """ + + val validConfig = ConfigFactory.parseString( + """ + akka-http-cors { + allow-generic-http-requests = true + allow-credentials = true + allowed-origins = "*" + allowed-headers = "*" + allowed-methods = ["GET", "OPTIONS", "XXX"] + exposed-headers = [] + max-age = 30 minutes + } + """ + ) + + val referenceSettings = CorsSettings("") + + "CorsSettings" should { + + "load settings from the actor system by default" in { + val settings1 = CorsSettings.default + val settings2 = CorsSettings(system) + + settings1 should not be referenceSettings + settings1 shouldBe settings2 + + referenceSettings.allowCredentials shouldBe true + settings1.allowCredentials shouldBe false + } + + "cache the settings from the actor system" in { + val settings1 = CorsSettings(system) + val settings2 = CorsSettings(system) + + settings1 shouldBe theSameInstanceAs(settings2) + } + "return valid cors settings from a valid config object" in { val corsSettings = CorsSettings(validConfig) corsSettings.allowGenericHttpRequests shouldBe true diff --git a/build.sbt b/build.sbt index cd747ab..1fea552 100644 --- a/build.sbt +++ b/build.sbt @@ -8,7 +8,6 @@ lazy val commonSettings = Seq( "-deprecation", "-target:jvm-1.8", "-encoding", "utf8", - "-Xfuture", "-Ywarn-dead-code", "-Ywarn-numeric-widen", "-Ywarn-unused",