Skip to content

Commit

Permalink
Merge pull request #74 from lomigmegard/feature/settings_from_actor_s…
Browse files Browse the repository at this point in the history
…ystem

Load settings from the current Actor System
  • Loading branch information
lomigmegard authored May 23, 2020
2 parents 61c07b0 + bebbe90 commit ac7bd83
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import akka.http.scaladsl.model.{HttpMethods, HttpRequest}
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.unmarshalling.Unmarshal
import ch.megard.akka.http.cors.scaladsl.CorsDirectives
import ch.megard.akka.http.cors.scaladsl.settings.CorsSettings
import com.typesafe.config.ConfigFactory
import org.openjdk.jmh.annotations._

Expand All @@ -25,7 +26,8 @@ class CorsBenchmark extends Directives with CorsDirectives {
implicit private val system: ActorSystem = ActorSystem("CorsBenchmark", config)
implicit private val ec: ExecutionContext = scala.concurrent.ExecutionContext.global

private val http = Http()
private val http = Http()
private val corsSettings = CorsSettings.default

private var binding: ServerBinding = _
private var request: HttpRequest = _
Expand All @@ -40,7 +42,7 @@ class CorsBenchmark extends Directives with CorsDirectives {
complete("ok")
}
} ~ path("cors") {
cors() {
cors(corsSettings) {
get {
complete("ok")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ protected Route routes() {

// Note how rejections and exceptions are handled *before* the CORS directive (in the inner route).
// This is required to have the correct CORS headers in the response even when an error occurs.
return handleErrors.apply(() -> cors(() -> handleErrors.apply(() -> route(
path("ping", () ->
complete("pong")),
path("pong", () ->
failWith(new NoSuchElementException("pong not found, try with ping")))
return handleErrors.apply(() -> cors(() -> handleErrors.apply(() -> concat(
path("ping", () -> complete("pong")),
path("pong", () -> failWith(new NoSuchElementException("pong not found, try with ping")))
))));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import ch.megard.akka.http.cors.javadsl.settings.CorsSettings
import ch.megard.akka.http.cors.scaladsl

object CorsDirectives {

def cors(inner: Supplier[Route]): Route =
RouteAdapter {
scaladsl.CorsDirectives.cors() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import akka.http.javadsl.model.HttpMethod
import ch.megard.akka.http.cors.javadsl.model.{HttpHeaderRange, HttpOriginMatcher}
import ch.megard.akka.http.cors.scaladsl
import ch.megard.akka.http.cors.scaladsl.settings.CorsSettingsImpl
import com.typesafe.config.Config
import com.typesafe.config.{Config, ConfigFactory}

/**
* Public API but not intended for subclassing
Expand Down Expand Up @@ -36,7 +36,7 @@ abstract class CorsSettings { self: CorsSettingsImpl =>
object CorsSettings {
def create(config: Config): CorsSettings = scaladsl.settings.CorsSettings(config)
def create(configOverrides: String): CorsSettings = scaladsl.settings.CorsSettings(configOverrides)
def create(system: ActorSystem): CorsSettings = create(system.settings.config)

def defaultSettings: CorsSettings = scaladsl.settings.CorsSettings.defaultSettings
def create(system: ActorSystem): CorsSettings = scaladsl.settings.CorsSettings(system)
@deprecated("Use `CorsSetting.create` instead", "1.0.0")
def defaultSettings: CorsSettings = create(ConfigFactory.load(getClass.getClassLoader))
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ trait CorsDirectives {
import BasicDirectives._
import RouteDirectives._

/**
* Wraps its inner route with support for the CORS mechanism, enabling cross origin requests.
*
* In particular the recommendation written by the W3C in https://www.w3.org/TR/cors/ is
* implemented by this directive.
*
* The settings are loaded from the Actor System configuration.
*/
def cors(): Directive0 = {
extractActorSystem.flatMap { system =>
cors(CorsSettings(system))
}
}

/**
* Wraps its inner route with support for the CORS mechanism, enabling cross origin requests.
*
Expand All @@ -30,7 +44,7 @@ trait CorsDirectives {
*
* @param settings the settings used by the CORS filter
*/
def cors(settings: CorsSettings = CorsSettings.defaultSettings): Directive0 = {
def cors(settings: CorsSettings): Directive0 = {
import settings._

/** Return the invalid origins, or `Nil` if one is valid. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import com.typesafe.config.ConfigException.{Missing, WrongType}
import com.typesafe.config.{Config, ConfigFactory}

import scala.collection.JavaConverters._
import scala.collection.immutable.Seq
import scala.collection.immutable.{ListMap, Seq}
import scala.compat.java8.OptionConverters
import scala.util.Try

Expand Down Expand Up @@ -159,15 +159,48 @@ abstract class CorsSettings private[akka] () extends javadsl.settings.CorsSettin
}

object CorsSettings {
private val prefix = "akka-http-cors"
private val prefix = "akka-http-cors"
final private val MaxCached = 8
private[this] var cache = ListMap.empty[ActorSystem, CorsSettings]

def apply(system: ActorSystem): CorsSettings = apply(system.settings.config)
def apply(config: Config): CorsSettings = fromSubConfig(config.getConfig(prefix))
/**
* Creates an instance of settings using the given Config.
*/
def apply(config: Config): CorsSettings =
fromSubConfig(config.getConfig(prefix))

/**
* Create an instance of settings using the given String of config overrides to override
* settings set in the class loader of this class (i.e. by application.conf or reference.conf files in
* the class loader of this class).
*/
def apply(configOverrides: String): CorsSettings =
apply(
ConfigFactory.parseString(configOverrides).withFallback(ConfigFactory.defaultReference(getClass.getClassLoader))
)

/**
* Creates an instance of CorsSettings using the configuration provided by the given ActorSystem.
*/
def apply(system: ActorSystem): CorsSettings =
// From private akka.http.impl.util.SettingsCompanionImpl implementation
cache.getOrElse(
system, {
val settings = apply(system.settings.config)
val c = if (cache.size < MaxCached) cache else cache.tail
cache = c.updated(system, settings)
settings
}
)

/**
* Creates an instance of CorsSettings using the configuration provided by the given ActorSystem.
*/
implicit def default(implicit system: ActorSystem): CorsSettings = apply(system)

@deprecated("Use either `CorsSetting.default` or `CorsSettings.apply` instead", "1.0.0")
def defaultSettings: CorsSettings = apply(ConfigFactory.load(getClass.getClassLoader))

def fromSubConfig(config: Config): CorsSettings = {
def parseStringList(path: String): List[String] =
Try(config.getStringList(path).asScala.toList).recover {
Expand Down Expand Up @@ -198,5 +231,4 @@ object CorsSettings {
)
}

val defaultSettings = apply(ConfigFactory.load(getClass.getClassLoader))
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,40 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
val exampleOrigin = HttpOrigin("http://example.com")
val exampleStatus = StatusCodes.Created

val referenceSettings = CorsSettings("")

// override config for the ActorSystem to test `cors()`
override def testConfigSource: String =
"""
|akka-http-cors {
| allow-credentials = false
|}
|""".stripMargin

def route(settings: CorsSettings, responseHeaders: Seq[HttpHeader] = Nil): Route =
cors(settings) {
complete(HttpResponse(exampleStatus, responseHeaders, HttpEntity(actual)))
}

"The cors directive" should {
"The cors() directive" should {
"extract its settings from the actor system" in {
val route = cors() {
complete(HttpResponse(exampleStatus, Nil, HttpEntity(actual)))
}

Get() ~> Origin(exampleOrigin) ~> route ~> check {
responseAs[String] shouldBe actual
response.status shouldBe exampleStatus
response.headers should contain theSameElementsAs Seq(
`Access-Control-Allow-Origin`.*
)
}
}
}

"The cors(settings) directive" should {
"not affect actual requests when not strict" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
val responseHeaders = Seq(Host("my-host"), `Access-Control-Max-Age`(60))
Get() ~> {
route(settings, responseHeaders)
Expand All @@ -40,7 +66,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject requests without Origin header when strict" in {
val settings = CorsSettings.defaultSettings.withAllowGenericHttpRequests(false)
val settings = referenceSettings.withAllowGenericHttpRequests(false)
Get() ~> {
route(settings)
} ~> check {
Expand All @@ -49,7 +75,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"accept actual requests with a single Origin" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
Get() ~> Origin(exampleOrigin) ~> {
route(settings)
} ~> check {
Expand All @@ -63,7 +89,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"accept pre-flight requests with a null origin when allowed-origins = `*`" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> {
route(settings)
} ~> check {
Expand All @@ -78,7 +104,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject pre-flight requests with a null origin when allowed-origins != `*`" in {
val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
Options() ~> Origin(Seq.empty) ~> `Access-Control-Request-Method`(GET) ~> {
route(settings)
} ~> check {
Expand All @@ -87,7 +113,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"accept actual requests with a null Origin" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
Get() ~> Origin(Seq.empty) ~> {
route(settings)
} ~> check {
Expand All @@ -104,7 +130,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
val subdomainMatcher = HttpOriginMatcher(HttpOrigin("http://*.example.com"))
val subdomainOrigin = HttpOrigin("http://sub.example.com")

val settings = CorsSettings.defaultSettings.withAllowedOrigins(subdomainMatcher)
val settings = referenceSettings.withAllowedOrigins(subdomainMatcher)
Get() ~> Origin(subdomainOrigin) ~> {
route(settings)
} ~> check {
Expand All @@ -118,7 +144,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"return `Access-Control-Allow-Origin: *` to actual request only when credentials are not allowed" in {
val settings = CorsSettings.defaultSettings.withAllowCredentials(false)
val settings = referenceSettings.withAllowCredentials(false)
Get() ~> Origin(exampleOrigin) ~> {
route(settings)
} ~> check {
Expand All @@ -132,7 +158,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with

"return `Access-Control-Expose-Headers` to actual request with all the exposed headers in the settings" in {
val exposedHeaders = Seq("X-a", "X-b", "X-c")
val settings = CorsSettings.defaultSettings.withExposedHeaders(exposedHeaders)
val settings = referenceSettings.withExposedHeaders(exposedHeaders)
Get() ~> Origin(exampleOrigin) ~> {
route(settings)
} ~> check {
Expand All @@ -147,7 +173,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"remove CORS-related headers from the original response before adding the new ones" in {
val settings = CorsSettings.defaultSettings.withExposedHeaders(Seq("X-good"))
val settings = referenceSettings.withExposedHeaders(Seq("X-good"))
val responseHeaders = Seq(
Host("my-host"), // untouched
`Access-Control-Allow-Origin`("http://bad.com"), // replaced
Expand All @@ -172,7 +198,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"accept valid pre-flight requests" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> {
route(settings)
} ~> check {
Expand All @@ -188,7 +214,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"accept actual requests with OPTION method" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
Options() ~> Origin(exampleOrigin) ~> {
route(settings)
} ~> check {
Expand All @@ -203,15 +229,15 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with

"reject actual requests with invalid origin" when {
"the origin is null" in {
val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
Get() ~> Origin(Seq.empty) ~> {
route(settings)
} ~> check {
rejection shouldBe CorsRejection(CorsRejection.InvalidOrigin(Seq.empty))
}
}
"there is one origin" in {
val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val invalidOrigin = HttpOrigin("http://invalid.com")
Get() ~> Origin(invalidOrigin) ~> {
route(settings)
Expand All @@ -222,7 +248,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject pre-flight requests with invalid origin" in {
val settings = CorsSettings.defaultSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val settings = referenceSettings.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
val invalidOrigin = HttpOrigin("http://invalid.com")
Options() ~> Origin(invalidOrigin) ~> `Access-Control-Request-Method`(GET) ~> {
route(settings)
Expand All @@ -232,7 +258,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject pre-flight requests with invalid method" in {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
val invalidMethod = PATCH
Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(invalidMethod) ~> {
route(settings)
Expand All @@ -242,7 +268,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject pre-flight requests with invalid header" in {
val settings = CorsSettings.defaultSettings.withAllowedHeaders(HttpHeaderRange())
val settings = referenceSettings.withAllowedHeaders(HttpHeaderRange())
val invalidHeader = "X-header"
Options() ~> Origin(exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~>
`Access-Control-Request-Headers`(invalidHeader) ~> {
Expand All @@ -253,7 +279,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"reject pre-flight requests with multiple origins" in {
val settings = CorsSettings.defaultSettings.withAllowGenericHttpRequests(false)
val settings = referenceSettings.withAllowGenericHttpRequests(false)
Options() ~> Origin(exampleOrigin, exampleOrigin) ~> `Access-Control-Request-Method`(GET) ~> {
route(settings)
} ~> check {
Expand All @@ -263,7 +289,7 @@ class CorsDirectivesSpec extends AnyWordSpec with Matchers with Directives with
}

"the default rejection handler" should {
val settings = CorsSettings.defaultSettings
val settings = referenceSettings
.withAllowGenericHttpRequests(false)
.withAllowedOrigins(HttpOriginMatcher(exampleOrigin))
.withAllowedHeaders(HttpHeaderRange())
Expand Down
Loading

0 comments on commit ac7bd83

Please sign in to comment.