Skip to content

Commit

Permalink
[test] Add artificial delay for handshake headers frame to workaround…
Browse files Browse the repository at this point in the history
… Ktor bug
  • Loading branch information
joffrey-bion committed Mar 29, 2024
1 parent 6cac898 commit 5b29483
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import io.ktor.client.engine.*
import io.ktor.client.plugins.websocket.*
import org.hildan.krossbow.websocket.*
import org.hildan.krossbow.websocket.test.*
import kotlin.time.*

abstract class KtorClientTestSuite(supportsStatusCodes: Boolean) : WebSocketClientTestSuite(supportsStatusCodes) {
abstract class KtorClientTestSuite(
supportsStatusCodes: Boolean,
headersTestDelay: Duration? = null,
) : WebSocketClientTestSuite(supportsStatusCodes, headersTestDelay) {

override fun provideClient(): WebSocketClient = KtorWebSocketClient(
HttpClient(provideEngine()) { install(WebSockets) },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.ktor.client.*
import io.ktor.client.plugins.websocket.*
import org.hildan.krossbow.websocket.WebSocketClient
import org.hildan.krossbow.websocket.test.*
import kotlin.time.Duration.Companion.milliseconds

// WinHttp: error is too generic and doesn't differ per status code
// JS node: error is too generic and doesn't differ per status code (ECONNREFUSED, unlike 'ws')
Expand All @@ -17,6 +18,8 @@ private val Platform.supportsStatusCodes: Boolean
// Also, it covers cases of dynamically-selected implementations.
class KtorMppWebSocketClientTest : WebSocketClientTestSuite(
supportsStatusCodes = currentPlatform().supportsStatusCodes,
// workaround for https://youtrack.jetbrains.com/issue/KTOR-6883
headersTestDelay = 200.milliseconds.takeIf { currentPlatform() == Platform.Js.NodeJs },
) {
override fun provideClient(): WebSocketClient = KtorWebSocketClient(
HttpClient {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ package org.hildan.krossbow.websocket.ktor

import io.ktor.client.engine.*
import io.ktor.client.engine.js.*
import org.hildan.krossbow.websocket.test.*
import kotlin.time.Duration.Companion.milliseconds

class KtorJsWebSocketClientTest : KtorClientTestSuite(
// JS node: error is too generic and doesn't differ per status code (ECONNREFUSED, unlike 'ws')
// JS browser: cannot support status codes for security reasons
supportsStatusCodes = false,
// workaround for https://youtrack.jetbrains.com/issue/KTOR-6883
headersTestDelay = 200.milliseconds.takeIf { currentPlatform() == Platform.Js.NodeJs },
) {
override fun provideEngine(): HttpClientEngineFactory<*> = Js
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import kotlin.time.*
import kotlin.time.Duration.Companion.seconds

abstract class WebSocketClientTestSuite(
val supportsStatusCodes: Boolean = true,
private val supportsStatusCodes: Boolean = true,
private val headersTestDelay: Duration? = null,
) {
abstract fun provideClient(): WebSocketClient

Expand All @@ -28,16 +29,22 @@ abstract class WebSocketClientTestSuite(
wsClient = provideClient()
}

private fun testUrl(path: String, testCaseName: String? = null): String =
"${testServerConfig.wsUrl}$path?${testUrlQuery(testCaseName)}"
private fun testUrl(
path: String,
testCaseName: String? = null,
otherParams: Map<String, String> = emptyMap(),
): String = "${testServerConfig.wsUrl}$path?${testUrlQuery(testCaseName, otherParams)}"

private fun testUrlQuery(testCaseName: String? = null): String {
private fun testUrlQuery(testCaseName: String? = null, otherParams: Map<String, String> = emptyMap()): String {
val params = buildMap {
put("agent", agent)
put("testClass", this@WebSocketClientTestSuite::class.simpleName)
if (testCaseName != null) {
put("testCase", testCaseName)
}
otherParams.forEach { (key, value) ->
put(key, value)
}
}
return params.entries.joinToString("&") { "${it.key}=${it.value}" }
}
Expand Down Expand Up @@ -186,8 +193,10 @@ abstract class WebSocketClientTestSuite(
fun testHandshakeHeaders() = runTestRealTime {
if (wsClient.supportsCustomHeaders) {
println("Connecting with agent $agent to ${testServerConfig.wsUrl}/sendHandshakeHeaders")
// workaround for https://youtrack.jetbrains.com/issue/KTOR-6883
val extraParams = if (headersTestDelay != null) mapOf("scheduleDelay" to headersTestDelay.toString()) else emptyMap()
val session = wsClient.connect(
url = testUrl(path = "/sendHandshakeHeaders"),
url = testUrl(path = "/sendHandshakeHeaders", otherParams = extraParams),
headers = mapOf("My-Header-1" to "my-value-1", "My-Header-2" to "my-value-2"),
)
println("Connected with agent $agent to ${testServerConfig.wsUrl}/sendHandshakeHeaders")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,46 @@ import org.java_websocket.handshake.*
import org.java_websocket.server.*
import java.net.*
import java.nio.*
import kotlin.time.*

internal class EchoWebSocketServer(port: Int = 0) : WebSocketServer(InetSocketAddress(port)) {

private val delayedHeadersScope = CoroutineScope(SupervisorJob() + Dispatchers.Default)

override fun onStart() {
}

override fun onOpen(conn: WebSocket, handshake: ClientHandshake) {
val uri = URI.create(handshake.resourceDescriptor)
println("Connection to URI $uri")

if (uri.path == "/sendHandshakeHeaders") {
val headerNames = handshake.iterateHttpFields().asSequence().toList()
val headersData = headerNames.joinToString("\n") { "$it=${handshake.getFieldValue(it)}" }
println("Sending message with headers...")
conn.send(headersData)
println("Headers frame sent!")
val queryParams = uri.queryAsMap()
val scheduleDelay = queryParams["scheduleDelay"]?.let(Duration::parse)
conn.sendMessageWithHeaders(handshake, scheduleDelay)
} else {
println("Not sending headers frame for URI $uri")
}
}

private fun WebSocket.sendMessageWithHeaders(handshake: ClientHandshake, scheduleDelay: Duration? = null) {
val headerNames = handshake.iterateHttpFields().asSequence().toList()
val headersData = headerNames.joinToString("\n") { "$it=${handshake.getFieldValue(it)}" }
if (scheduleDelay != null) {
// necessary due to https://youtrack.jetbrains.com/issue/KTOR-6883
println("Scheduling message with headers in $scheduleDelay")
delayedHeadersScope.launch {
delay(scheduleDelay)
send(headersData)
println("Headers frame sent!")
}
} else {
println("Sending message with headers...")
send(headersData)
println("Headers frame sent!")
}
}

override fun onMessage(conn: WebSocket, message: String?) {
conn.send(message)
}
Expand Down Expand Up @@ -54,3 +74,8 @@ internal class EchoWebSocketServer(port: Int = 0) : WebSocketServer(InetSocketAd
port
}
}

private fun URI.queryAsMap() = query.split("&")
.map { it.split("=") }
.associate { it[0] to it[1] }

0 comments on commit 5b29483

Please sign in to comment.