Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a socket abstraction #8410

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mockwebserver-deprecated/api/mockwebserver.api
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public final class okhttp3/mockwebserver/MockWebServer : org/junit/rules/Externa
public final fun -deprecated_protocols ()Ljava/util/List;
public final fun -deprecated_protocols (Ljava/util/List;)V
public final fun -deprecated_requestCount ()I
public final fun -deprecated_serverSocketFactory (Ljavax/net/ServerSocketFactory;)V
public final fun -deprecated_serverSocketFactory (Lokhttp3/internal/socket/OkioServerSocketFactory;)V
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we hide this in this deprecated API, keep it still in terms of javax.net?

public fun <init> ()V
public fun close ()V
public final fun enqueue (Lokhttp3/mockwebserver/MockResponse;)V
Expand All @@ -80,7 +80,7 @@ public final class okhttp3/mockwebserver/MockWebServer : org/junit/rules/Externa
public final fun getPort ()I
public final fun getProtocolNegotiationEnabled ()Z
public final fun getRequestCount ()I
public final fun getServerSocketFactory ()Ljavax/net/ServerSocketFactory;
public final fun getServerSocketFactory ()Lokhttp3/internal/socket/OkioServerSocketFactory;
public final fun noClientAuth ()V
public final fun protocols ()Ljava/util/List;
public final fun requestClientAuth ()V
Expand All @@ -89,7 +89,7 @@ public final class okhttp3/mockwebserver/MockWebServer : org/junit/rules/Externa
public final fun setDispatcher (Lokhttp3/mockwebserver/Dispatcher;)V
public final fun setProtocolNegotiationEnabled (Z)V
public final fun setProtocols (Ljava/util/List;)V
public final fun setServerSocketFactory (Ljavax/net/ServerSocketFactory;)V
public final fun setServerSocketFactory (Lokhttp3/internal/socket/OkioServerSocketFactory;)V
public final fun shutdown ()V
public final fun start ()V
public final fun start (I)V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import java.net.Proxy
import java.util.concurrent.TimeUnit
import java.util.logging.Level
import java.util.logging.Logger
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.ExperimentalOkHttpApi
import okhttp3.HttpUrl
import okhttp3.Protocol
import okhttp3.internal.socket.OkioServerSocketFactory
import org.junit.rules.ExternalResource

class MockWebServer : ExternalResource(), Closeable {
Expand All @@ -37,7 +37,7 @@ class MockWebServer : ExternalResource(), Closeable {

var bodyLimit: Long by delegate::bodyLimit

var serverSocketFactory: ServerSocketFactory? by delegate::serverSocketFactory
var serverSocketFactory: OkioServerSocketFactory? by delegate::serverSocketFactory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm wondering whether dropping down to the native implementation should be a feature, and so here we could grab the real ServerSocketFactory?


var dispatcher: Dispatcher = QueueDispatcher()
set(value) {
Expand Down Expand Up @@ -103,7 +103,7 @@ class MockWebServer : ExternalResource(), Closeable {
),
level = DeprecationLevel.ERROR,
)
fun setServerSocketFactory(serverSocketFactory: ServerSocketFactory) {
fun setServerSocketFactory(serverSocketFactory: OkioServerSocketFactory) {
delegate.serverSocketFactory = serverSocketFactory
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ package okhttp3.mockwebserver
import java.io.IOException
import java.net.Inet6Address
import java.net.Socket
import javax.net.ssl.SSLSocket
import okhttp3.Handshake
import okhttp3.Handshake.Companion.handshake
import okhttp3.Headers
import okhttp3.HttpUrl
import okhttp3.HttpUrl.Companion.toHttpUrlOrNull
import okhttp3.TlsVersion
import okhttp3.internal.socket.OkioSslSocket
import okio.Buffer

class RecordedRequest {
Expand Down Expand Up @@ -97,9 +97,9 @@ class RecordedRequest {
this.sequenceNumber = sequenceNumber
this.failure = failure

if (socket is SSLSocket) {
if (socket is OkioSslSocket) {
try {
this.handshake = socket.session.handshake()
this.handshake = socket.session?.handshake()
} catch (e: IOException) {
throw IllegalArgumentException(e)
}
Expand All @@ -117,7 +117,7 @@ class RecordedRequest {
}
this.path = path

val scheme = if (socket is SSLSocket) "https" else "http"
val scheme = if (socket is OkioSslSocket) "https" else "http"
val inetAddress = socket.localAddress

var hostname = inetAddress.hostName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import java.net.InetAddress
import java.net.Proxy
import java.net.Socket
import java.util.concurrent.TimeUnit
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLSocketFactory
import okhttp3.Handshake
import okhttp3.Headers
Expand All @@ -29,6 +28,7 @@ import okhttp3.Protocol
import okhttp3.TlsVersion
import okhttp3.WebSocketListener
import okhttp3.internal.http2.Settings
import okhttp3.internal.socket.RealOkioServerSocketFactory
import okio.Buffer
import org.junit.Ignore
import org.junit.Test
Expand Down Expand Up @@ -115,7 +115,7 @@ class KotlinSourceModernTest {
var hostName: String = mockWebServer.hostName
hostName = mockWebServer.hostName
val toProxyAddress: Proxy = mockWebServer.toProxyAddress()
mockWebServer.serverSocketFactory = ServerSocketFactory.getDefault()
mockWebServer.serverSocketFactory = RealOkioServerSocketFactory()
val url: HttpUrl = mockWebServer.url("")
mockWebServer.bodyLimit = 0L
mockWebServer.protocolNegotiationEnabled = false
Expand Down
8 changes: 4 additions & 4 deletions mockwebserver/api/mockwebserver3.api
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public final class mockwebserver3/MockWebServer : java/io/Closeable {
public final fun getProtocolNegotiationEnabled ()Z
public final fun getProtocols ()Ljava/util/List;
public final fun getRequestCount ()I
public final fun getServerSocketFactory ()Ljavax/net/ServerSocketFactory;
public final fun getServerSocketFactory ()Lokhttp3/internal/socket/OkioServerSocketFactory;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we feel about breaking changes to MockWebServer public methods?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this new mockwebserver seems fine, we should avoid changes to okhttp3/mockwebserver/* in the deprecated package.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I guess these APIs in the non deprecated API become experimental APIs, is that ok?

public final fun getStarted ()Z
public final fun noClientAuth ()V
public final fun requestClientAuth ()V
Expand All @@ -116,7 +116,7 @@ public final class mockwebserver3/MockWebServer : java/io/Closeable {
public final fun setDispatcher (Lmockwebserver3/Dispatcher;)V
public final fun setProtocolNegotiationEnabled (Z)V
public final fun setProtocols (Ljava/util/List;)V
public final fun setServerSocketFactory (Ljavax/net/ServerSocketFactory;)V
public final fun setServerSocketFactory (Lokhttp3/internal/socket/OkioServerSocketFactory;)V
public final fun setStarted (Z)V
public final fun shutdown ()V
public final fun start ()V
Expand Down Expand Up @@ -159,8 +159,8 @@ public final class mockwebserver3/QueueDispatcher$Companion {
}

public final class mockwebserver3/RecordedRequest {
public fun <init> (Ljava/lang/String;Lokhttp3/Headers;Ljava/util/List;JLokio/Buffer;ILjava/net/Socket;Ljava/io/IOException;)V
public synthetic fun <init> (Ljava/lang/String;Lokhttp3/Headers;Ljava/util/List;JLokio/Buffer;ILjava/net/Socket;Ljava/io/IOException;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Ljava/lang/String;Lokhttp3/Headers;Ljava/util/List;JLokio/Buffer;ILokhttp3/internal/socket/OkioSocket;Ljava/io/IOException;)V
public synthetic fun <init> (Ljava/lang/String;Lokhttp3/Headers;Ljava/util/List;JLokio/Buffer;ILokhttp3/internal/socket/OkioSocket;Ljava/io/IOException;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getBody ()Lokio/Buffer;
public final fun getBodySize ()J
public final fun getChunkSizes ()Ljava/util/List;
Expand Down
79 changes: 32 additions & 47 deletions mockwebserver/src/main/kotlin/mockwebserver3/MockWebServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.ProtocolException
import java.net.Proxy
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketException
import java.security.SecureRandom
import java.security.cert.CertificateException
Expand All @@ -39,7 +37,6 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Level
import java.util.logging.Logger
import javax.net.ServerSocketFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory
Expand Down Expand Up @@ -80,6 +77,13 @@ import okhttp3.internal.http2.Http2Connection
import okhttp3.internal.http2.Http2Stream
import okhttp3.internal.immutableListOf
import okhttp3.internal.platform.Platform
import okhttp3.internal.socket.OkioServerSocket
import okhttp3.internal.socket.OkioServerSocketFactory
import okhttp3.internal.socket.OkioSocket
import okhttp3.internal.socket.OkioSslSocketFactory
import okhttp3.internal.socket.RealOkioServerSocketFactory
import okhttp3.internal.socket.RealOkioSslSocket
import okhttp3.internal.socket.RealOkioSslSocketFactory
import okhttp3.internal.threadFactory
import okhttp3.internal.toImmutableList
import okhttp3.internal.ws.RealWebSocket
Expand All @@ -91,8 +95,6 @@ import okio.BufferedSource
import okio.Sink
import okio.Timeout
import okio.buffer
import okio.sink
import okio.source

/**
* A scriptable web server. Callers supply canned responses and the server replays them upon request
Expand All @@ -107,7 +109,7 @@ class MockWebServer : Closeable {
private val taskRunner = TaskRunner(taskRunnerBackend)
private val requestQueue = LinkedBlockingQueue<RecordedRequest>()
private val openClientSockets =
Collections.newSetFromMap(ConcurrentHashMap<Socket, Boolean>())
Collections.newSetFromMap(ConcurrentHashMap<OkioSocket, Boolean>())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given OkioSocket is an interface, does this weaken our safe assumptions about equals/hashcode being object reference ones?

Not sure if this means

a) documenting the equals/hashcode behaviour
b) introducing some Socket identifier?
c) changing this to a IdentityHashMap?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great question. My instinct is to document that equals/hashcode are referenced-based, because I think that's the only reasonable way to think about socket equality. I think I would be surprised if sockets were considered equal based on anything than their unique state, at least by default.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, although a socket is a 4-tuple of host:port, host:port.

And equality could break with decoration used this way.

Probably documenting reference equality is right, just bouncing ideas.

private val openConnections =
Collections.newSetFromMap(ConcurrentHashMap<Http2Connection, Boolean>())

Expand All @@ -123,10 +125,10 @@ class MockWebServer : Closeable {
/** The number of bytes of the POST body to keep in memory to the given limit. */
var bodyLimit: Long = Long.MAX_VALUE

var serverSocketFactory: ServerSocketFactory? = null
var serverSocketFactory: OkioServerSocketFactory? = null
@Synchronized get() {
if (field == null && started) {
field = ServerSocketFactory.getDefault() // Build the default value lazily.
field = RealOkioServerSocketFactory()
}
return field
}
Expand All @@ -136,8 +138,8 @@ class MockWebServer : Closeable {
field = value
}

private var serverSocket: ServerSocket? = null
private var sslSocketFactory: SSLSocketFactory? = null
private var serverSocket: OkioServerSocket? = null
private var sslSocketFactory: OkioSslSocketFactory? = null
private var clientAuth = CLIENT_AUTH_NONE

/**
Expand Down Expand Up @@ -232,7 +234,7 @@ class MockWebServer : Closeable {
* Serve requests with HTTPS rather than otherwise.
*/
fun useHttps(sslSocketFactory: SSLSocketFactory) {
this.sslSocketFactory = sslSocketFactory
this.sslSocketFactory = RealOkioSslSocketFactory(sslSocketFactory)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it worth us having a more opinionated extenion method for these?

fun SslSocketFactory.asOkioSocketFactory(): OkioSslSocketFactory = ...

}

/**
Expand Down Expand Up @@ -337,7 +339,7 @@ class MockWebServer : Closeable {

this._inetSocketAddress = inetSocketAddress

serverSocket = serverSocketFactory!!.createServerSocket()
serverSocket = serverSocketFactory!!.newServerSocket()

// Reuse if the user specified a port
serverSocket!!.reuseAddress = inetSocketAddress.port != 0
Expand Down Expand Up @@ -374,7 +376,7 @@ class MockWebServer : Closeable {
@Throws(Exception::class)
private fun acceptConnections() {
while (true) {
val socket: Socket
val socket: OkioSocket
try {
socket = serverSocket!!.accept()
} catch (e: SocketException) {
Expand Down Expand Up @@ -414,7 +416,7 @@ class MockWebServer : Closeable {
taskRunnerBackend.shutdown()
}

private fun serveConnection(raw: Socket) {
private fun serveConnection(raw: OkioSocket) {
taskRunner.newQueue().execute("MockWebServer ${raw.remoteSocketAddress}", cancelable = false) {
try {
SocketHandler(raw).handle()
Expand All @@ -426,7 +428,7 @@ class MockWebServer : Closeable {
}
}

internal inner class SocketHandler(private val raw: Socket) {
internal inner class SocketHandler(private val raw: OkioSocket) {
private var sequenceNumber = 0

@Throws(Exception::class)
Expand All @@ -435,22 +437,16 @@ class MockWebServer : Closeable {

val socketPolicy = dispatcher.peek().socketPolicy
val protocol: Protocol
val socket: Socket
val socket: OkioSocket
when {
sslSocketFactory != null -> {
if (socketPolicy === FailHandshake) {
dispatchBookkeepingRequest(sequenceNumber, raw)
processHandshakeFailure(raw)
return
}
socket =
sslSocketFactory!!.createSocket(
raw,
raw.inetAddress.hostAddress,
raw.port,
true,
)
val sslSocket = socket as SSLSocket
socket = sslSocketFactory!!.createSocket(raw)
val sslSocket = (socket as RealOkioSslSocket).delegate
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it deserves a TODO, eventually I assume we want this to be supported without dropping down to the platform SslSocket?

sslSocket.useClientMode = false
if (clientAuth == CLIENT_AUTH_REQUIRED) {
sslSocket.needClientAuth = true
Expand Down Expand Up @@ -508,10 +504,7 @@ class MockWebServer : Closeable {
throw AssertionError()
}

val source = socket.source().buffer()
val sink = socket.sink().buffer()

while (processOneRequest(socket, source, sink)) {
while (processOneRequest(socket, socket.source, socket.sink)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you simplify this to just passing socket here?

}

if (sequenceNumber == 0) {
Expand All @@ -532,10 +525,8 @@ class MockWebServer : Closeable {
private fun processTunnelRequests(): Boolean {
if (!dispatcher.peek().inTunnel) return true // No tunnel requests.

val source = raw.source().buffer()
val sink = raw.sink().buffer()
while (true) {
val socketStillGood = processOneRequest(raw, source, sink)
val socketStillGood = processOneRequest(raw, raw.source, raw.sink)

// Clean up after the last exchange on a socket.
if (!socketStillGood) {
Expand All @@ -554,7 +545,7 @@ class MockWebServer : Closeable {
*/
@Throws(IOException::class, InterruptedException::class)
private fun processOneRequest(
socket: Socket,
socket: OkioSocket,
source: BufferedSource,
sink: BufferedSink,
): Boolean {
Expand Down Expand Up @@ -621,17 +612,11 @@ class MockWebServer : Closeable {
}

@Throws(Exception::class)
private fun processHandshakeFailure(raw: Socket) {
private fun processHandshakeFailure(raw: OkioSocket) {
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf<TrustManager>(UNTRUSTED_TRUST_MANAGER), SecureRandom())
val sslSocketFactory = context.socketFactory
val socket =
sslSocketFactory.createSocket(
raw,
raw.inetAddress.hostAddress,
raw.port,
true,
) as SSLSocket
val sslSocketFactory = RealOkioSslSocketFactory(context.socketFactory)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this worth an extension method?

context.okioSslSocketFactory

val socket = sslSocketFactory.createSocket(raw)
try {
socket.startHandshake() // we're testing a handshake failure
throw AssertionError()
Expand All @@ -643,7 +628,7 @@ class MockWebServer : Closeable {
@Throws(InterruptedException::class)
private fun dispatchBookkeepingRequest(
sequenceNumber: Int,
socket: Socket,
socket: OkioSocket,
) {
val request =
RecordedRequest(
Expand All @@ -663,7 +648,7 @@ class MockWebServer : Closeable {
/** @param sequenceNumber the index of this request on this connection.*/
@Throws(IOException::class)
private fun readRequest(
socket: Socket,
socket: OkioSocket,
source: BufferedSource,
sink: BufferedSink,
sequenceNumber: Int,
Expand Down Expand Up @@ -764,7 +749,7 @@ class MockWebServer : Closeable {

@Throws(IOException::class)
private fun handleWebSocketUpgrade(
socket: Socket,
socket: OkioSocket,
source: BufferedSource,
sink: BufferedSink,
request: RecordedRequest,
Expand Down Expand Up @@ -829,7 +814,7 @@ class MockWebServer : Closeable {

@Throws(IOException::class)
private fun writeHttpResponse(
socket: Socket,
socket: OkioSocket,
sink: BufferedSink,
response: MockResponse,
) {
Expand Down Expand Up @@ -876,7 +861,7 @@ class MockWebServer : Closeable {
policy: MockResponse,
disconnectHalfway: Boolean,
expectedByteCount: Long,
socket: Socket,
socket: OkioSocket,
): Sink {
var result: Sink = this

Expand Down Expand Up @@ -956,7 +941,7 @@ class MockWebServer : Closeable {

/** Processes HTTP requests layered over HTTP/2. */
private inner class Http2SocketHandler constructor(
private val socket: Socket,
private val socket: OkioSocket,
private val protocol: Protocol,
) : Http2Connection.Listener() {
private val sequenceNumber = AtomicInteger()
Expand Down
Loading