Skip to content

Commit ec52ef5

Browse files
authored
KTOR-8144 KTOR-7299 KTOR-7392 Fix various socket issues (#4643)
* KTOR-8144 Fix socket close with open read/write channels on Native * KTOR-7299 Fix bad descriptor error when socket is closed immediately after a select call * Fix `awaitClosed` not working on Native * Socket error fixes * KTOR-7392 Fix socket descriptor not closed properly on Windows * Fix tests * Fix select issue on Windows causing sockets to get stuck
1 parent 9eef7dd commit ec52ef5

File tree

40 files changed

+725
-467
lines changed

40 files changed

+725
-467
lines changed

ktor-client/ktor-client-cio/common/src/io/ktor/client/engine/cio/Endpoint.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ internal class Endpoint(
223223
}
224224
}
225225

226-
val connection = socket.connection()
227-
if (!secure) return@connect address to connection
228-
229226
try {
227+
val connection = socket.connection()
228+
if (!secure) return@connect address to connection
229+
230230
if (proxy?.type == ProxyType.HTTP) {
231231
startTunnel(requestData, connection.output, connection.input)
232232
}

ktor-client/ktor-client-cio/common/test/CIOEngineTest.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import io.ktor.network.sockets.*
1616
import io.ktor.utils.io.*
1717
import io.ktor.websocket.*
1818
import kotlinx.coroutines.CoroutineScope
19+
import kotlinx.coroutines.coroutineScope
1920
import kotlinx.coroutines.delay
2021
import kotlinx.coroutines.flow.single
2122
import kotlinx.coroutines.launch
@@ -24,8 +25,6 @@ import kotlin.time.Duration.Companion.seconds
2425

2526
class CIOEngineTest : ClientEngineTest<CIOEngineConfig>(CIO) {
2627

27-
private val selectorManager = SelectorManager()
28-
2928
@Test
3029
fun testRequestTimeoutIgnoredWithWebSocket() = testClient {
3130
config {
@@ -241,9 +240,12 @@ class CIOEngineTest : ClientEngineTest<CIOEngineConfig>(CIO) {
241240
private fun TestClientBuilder<*>.withServerSocket(
242241
block: suspend CoroutineScope.(HttpClient, ServerSocket) -> Unit,
243242
) = test { client ->
243+
val selectorManager = SelectorManager()
244244
selectorManager.use {
245245
aSocket(it).tcp().bind(TEST_SERVER_SOCKET_HOST, 0).use { socket ->
246-
block(client, socket)
246+
coroutineScope {
247+
block(client, socket)
248+
}
247249
}
248250
}
249251
}

ktor-client/ktor-client-cio/common/test/ConnectionFactoryTest.kt

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@ class ConnectionFactoryTest {
3232
connectionsLimit = 2,
3333
addressConnectionsLimit = 1,
3434
)
35+
val sockets = mutableListOf<Socket>()
3536
withServerSocket { socket0 ->
3637
withServerSocket { socket1 ->
3738
withServerSocket { socket2 ->
38-
connectionFactory.connect(socket0.localAddress as InetSocketAddress)
39-
connectionFactory.connect(socket1.localAddress as InetSocketAddress)
39+
sockets += connectionFactory.connect(socket0.localAddress as InetSocketAddress)
40+
sockets += connectionFactory.connect(socket1.localAddress as InetSocketAddress)
4041

4142
assertTimeout {
42-
connectionFactory.connect(socket2.localAddress as InetSocketAddress)
43+
sockets += connectionFactory.connect(socket2.localAddress as InetSocketAddress)
4344
}
4445
}
4546
}
4647
}
48+
49+
sockets.forEach { it.close() }
4750
}
4851

4952
@Test
@@ -53,20 +56,23 @@ class ConnectionFactoryTest {
5356
connectionsLimit = 2,
5457
addressConnectionsLimit = 1,
5558
)
59+
val sockets = mutableListOf<Socket>()
5660
withServerSocket { socket0 ->
5761

5862
withServerSocket { socket1 ->
59-
connectionFactory.connect(socket0.localAddress as InetSocketAddress)
63+
sockets += connectionFactory.connect(socket0.localAddress as InetSocketAddress)
6064
assertTimeout {
61-
connectionFactory.connect(socket0.localAddress as InetSocketAddress)
65+
sockets += connectionFactory.connect(socket0.localAddress as InetSocketAddress)
6266
}
6367

64-
connectionFactory.connect(socket1.localAddress as InetSocketAddress)
68+
sockets += connectionFactory.connect(socket1.localAddress as InetSocketAddress)
6569
assertTimeout {
66-
connectionFactory.connect(socket1.localAddress as InetSocketAddress)
70+
sockets += connectionFactory.connect(socket1.localAddress as InetSocketAddress)
6771
}
6872
}
6973
}
74+
75+
sockets.forEach { it.close() }
7076
}
7177

7278
@Test
@@ -76,18 +82,21 @@ class ConnectionFactoryTest {
7682
connectionsLimit = 2,
7783
addressConnectionsLimit = 1,
7884
)
85+
val sockets = mutableListOf<Socket>()
7986
withServerSocket { socket0 ->
8087
withServerSocket { socket1 ->
81-
connectionFactory.connect(socket0.localAddress as InetSocketAddress)
88+
sockets += connectionFactory.connect(socket0.localAddress as InetSocketAddress)
8289

8390
// Release the `limit` semaphore when it fails to acquire the address semaphore.
8491
assertTimeout {
85-
connectionFactory.connect(socket0.localAddress as InetSocketAddress)
92+
sockets += connectionFactory.connect(socket0.localAddress as InetSocketAddress)
8693
}
8794

88-
connectionFactory.connect(socket1.localAddress as InetSocketAddress)
95+
sockets += connectionFactory.connect(socket1.localAddress as InetSocketAddress)
8996
}
9097
}
98+
99+
sockets.forEach { it.close() }
91100
}
92101

93102
private suspend fun assertTimeout(timeoutMillis: Long = 500, block: suspend () -> Unit) {

ktor-network/api/ktor-network.klib.api

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ final enum class io.ktor.network.selector/SelectInterest : kotlin/Enum<io.ktor.n
2323
final val AllInterests // io.ktor.network.selector/SelectInterest.Companion.AllInterests|{}AllInterests[0]
2424
final fun <get-AllInterests>(): kotlin/Array<io.ktor.network.selector/SelectInterest> // io.ktor.network.selector/SelectInterest.Companion.AllInterests.<get-AllInterests>|<get-AllInterests>(){}[0]
2525
}
26+
27+
// Targets: [native]
28+
enum entry CLOSE // io.ktor.network.selector/SelectInterest.CLOSE|null[0]
2629
}
2730

2831
abstract interface <#A: out io.ktor.network.sockets/ASocket> io.ktor.network.sockets/Acceptable : io.ktor.network.sockets/ASocket { // io.ktor.network.sockets/Acceptable|null[0]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/*
2+
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package io.ktor.network.selector
6+
7+
internal expect abstract class SelectableBase() : Selectable
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package io.ktor.network.sockets
6+
7+
import io.ktor.network.selector.*
8+
import io.ktor.utils.io.*
9+
import kotlinx.atomicfu.*
10+
import kotlinx.coroutines.*
11+
import kotlinx.io.*
12+
import kotlin.coroutines.*
13+
14+
internal abstract class SocketBase(
15+
parent: CoroutineContext
16+
) : ReadWriteSocket, SelectableBase(), CoroutineScope {
17+
18+
private val closeFlag = atomic(false)
19+
20+
private val readerJob = atomic<ReaderJob?>(null)
21+
22+
private val writerJob = atomic<WriterJob?>(null)
23+
24+
override val socketContext: CompletableJob = Job(parent[Job])
25+
26+
override val coroutineContext: CoroutineContext
27+
get() = socketContext
28+
29+
override fun dispose() {
30+
close()
31+
}
32+
33+
override fun close() {
34+
if (!closeFlag.compareAndSet(false, true)) return
35+
36+
readerJob.value?.channel?.close()
37+
writerJob.value?.cancel()
38+
checkChannels()
39+
}
40+
41+
final override fun attachForReading(channel: ByteChannel): WriterJob {
42+
return attachFor("reading", channel, writerJob) {
43+
attachForReadingImpl(channel)
44+
}
45+
}
46+
47+
final override fun attachForWriting(channel: ByteChannel): ReaderJob {
48+
return attachFor("writing", channel, readerJob) {
49+
attachForWritingImpl(channel)
50+
}
51+
}
52+
53+
abstract fun attachForReadingImpl(channel: ByteChannel): WriterJob
54+
abstract fun attachForWritingImpl(channel: ByteChannel): ReaderJob
55+
56+
private inline fun <J : ChannelJob> attachFor(
57+
name: String,
58+
channel: ByteChannel,
59+
ref: AtomicRef<J?>,
60+
producer: () -> J
61+
): J {
62+
if (closeFlag.value) {
63+
val e = IOException("Socket closed")
64+
channel.close(e)
65+
throw e
66+
}
67+
68+
val j = producer()
69+
70+
if (!ref.compareAndSet(null, j)) {
71+
val e = IllegalStateException("$name channel has already been set")
72+
j.cancel()
73+
throw e
74+
}
75+
if (closeFlag.value) {
76+
val e = IOException("Socket closed")
77+
j.cancel()
78+
channel.close(e)
79+
throw e
80+
}
81+
82+
channel.attachJob(j)
83+
84+
j.invokeOnCompletion {
85+
checkChannels()
86+
}
87+
88+
return j
89+
}
90+
91+
internal abstract fun actualClose(): Throwable?
92+
93+
private fun checkChannels() {
94+
if (closeFlag.value && readerJob.completedOrNotStarted && writerJob.completedOrNotStarted) {
95+
val e1 = readerJob.exception
96+
val e2 = writerJob.exception
97+
val e3 = actualClose()
98+
99+
val combined = combine(combine(e1, e2), e3)
100+
101+
if (combined == null) socketContext.complete() else socketContext.completeExceptionally(combined)
102+
}
103+
}
104+
105+
private fun combine(e1: Throwable?, e2: Throwable?): Throwable? = when {
106+
e1 == null -> e2
107+
e2 == null -> e1
108+
e1 === e2 -> e1
109+
else -> {
110+
e1.addSuppressed(e2)
111+
e1
112+
}
113+
}
114+
115+
private inline val AtomicRef<out ChannelJob?>.completedOrNotStarted: Boolean
116+
get() = value.let { it == null || it.isCompleted }
117+
118+
private inline val AtomicRef<out ChannelJob?>.exception: Throwable?
119+
get() = value?.takeIf { it.isCancelled }
120+
?.getCancellationException()?.cause // TODO it should be completable deferred or provide its own exception
121+
}

0 commit comments

Comments
 (0)