Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid expensive SSL init on main #8248

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 69 additions & 46 deletions okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,15 @@ open class OkHttpClient internal constructor(
@get:JvmName("socketFactory")
val socketFactory: SocketFactory = builder.socketFactory

private val sslSocketFactoryOrNull: SSLSocketFactory?
private val sslInitializedFields: Lazy<SSLInitializedFields>?

@get:JvmName("sslSocketFactory")
val sslSocketFactory: SSLSocketFactory
get() = sslSocketFactoryOrNull ?: throw IllegalStateException("CLEARTEXT-only client")
get() = sslInitializedFields?.value?.sslSocketFactory ?: throw IllegalStateException("CLEARTEXT-only client")

@get:JvmName("x509TrustManager")
val x509TrustManager: X509TrustManager?
get() = sslInitializedFields?.value?.x509TrustManager

@get:JvmName("connectionSpecs")
val connectionSpecs: List<ConnectionSpec> =
Expand All @@ -219,11 +220,18 @@ open class OkHttpClient internal constructor(
@get:JvmName("hostnameVerifier")
val hostnameVerifier: HostnameVerifier = builder.hostnameVerifier

private lateinit var _certificatePinner: CertificatePinner

@get:JvmName("certificatePinner")
val certificatePinner: CertificatePinner
val certificatePinner: CertificatePinner by lazy {
certificateChainCleaner?.let {
_certificatePinner.withCertificateChainCleaner(it)
} ?: _certificatePinner
}

@get:JvmName("certificateChainCleaner")
val certificateChainCleaner: CertificateChainCleaner?
get() = sslInitializedFields?.value?.certificateChainCleaner

/**
* Default call timeout (in milliseconds). By default there is no timeout for complete calls, but
Expand Down Expand Up @@ -284,24 +292,24 @@ open class OkHttpClient internal constructor(

init {
if (connectionSpecs.none { it.isTls }) {
this.sslSocketFactoryOrNull = null
this.certificateChainCleaner = null
this.x509TrustManager = null
this.certificatePinner = CertificatePinner.DEFAULT
} else if (builder.sslSocketFactoryOrNull != null) {
this.sslSocketFactoryOrNull = builder.sslSocketFactoryOrNull
this.certificateChainCleaner = builder.certificateChainCleaner!!
this.x509TrustManager = builder.x509TrustManagerOrNull!!
this.certificatePinner =
builder.certificatePinner
.withCertificateChainCleaner(certificateChainCleaner!!)
this.sslInitializedFields = null
this._certificatePinner = CertificatePinner.DEFAULT
} else if (builder.sslInitializedFields != null) {
this.sslInitializedFields = builder.sslInitializedFields
this._certificatePinner = builder.certificatePinner
} else {
this.x509TrustManager = Platform.get().platformTrustManager()
this.sslSocketFactoryOrNull = Platform.get().newSslSocketFactory(x509TrustManager!!)
this.certificateChainCleaner = CertificateChainCleaner.get(x509TrustManager!!)
this.certificatePinner =
builder.certificatePinner
.withCertificateChainCleaner(certificateChainCleaner!!)
this.sslInitializedFields =
lazy {
val platform = Platform.get()
val trustManager = platform.platformTrustManager()
val certificateChainCleaner = CertificateChainCleaner.get(trustManager)
SSLInitializedFields(
trustManager,
platform.newSslSocketFactory(trustManager),
certificateChainCleaner,
)
}
this._certificatePinner = builder.certificatePinner
}

verifyClientState()
Expand Down Expand Up @@ -337,6 +345,12 @@ open class OkHttpClient internal constructor(
)
}

internal data class SSLInitializedFields(
val x509TrustManager: X509TrustManager,
val sslSocketFactory: SSLSocketFactory,
val certificateChainCleaner: CertificateChainCleaner,
)

private fun verifyClientState() {
check(null !in (interceptors as List<Interceptor?>)) {
"Null interceptor: $interceptors"
Expand All @@ -346,14 +360,10 @@ open class OkHttpClient internal constructor(
}

if (connectionSpecs.none { it.isTls }) {
check(sslSocketFactoryOrNull == null)
check(certificateChainCleaner == null)
check(x509TrustManager == null)
check(sslInitializedFields == null) { "ssl initialized for plaintext client" }
check(certificatePinner == CertificatePinner.DEFAULT)
} else {
checkNotNull(sslSocketFactoryOrNull) { "sslSocketFactory == null" }
checkNotNull(certificateChainCleaner) { "certificateChainCleaner == null" }
checkNotNull(x509TrustManager) { "x509TrustManager == null" }
checkNotNull(sslInitializedFields) { "ssl not initialized for client" }
}
}

Expand Down Expand Up @@ -609,13 +619,11 @@ open class OkHttpClient internal constructor(
internal var proxySelector: ProxySelector? = null
internal var proxyAuthenticator: Authenticator = Authenticator.NONE
internal var socketFactory: SocketFactory = SocketFactory.getDefault()
internal var sslSocketFactoryOrNull: SSLSocketFactory? = null
internal var x509TrustManagerOrNull: X509TrustManager? = null
internal var sslInitializedFields: Lazy<SSLInitializedFields>? = null
internal var connectionSpecs: List<ConnectionSpec> = DEFAULT_CONNECTION_SPECS
internal var protocols: List<Protocol> = DEFAULT_PROTOCOLS
internal var hostnameVerifier: HostnameVerifier = OkHostnameVerifier
internal var certificatePinner: CertificatePinner = CertificatePinner.DEFAULT
internal var certificateChainCleaner: CertificateChainCleaner? = null
internal var callTimeout = 0
internal var connectTimeout = 10_000
internal var readTimeout = 10_000
Expand Down Expand Up @@ -644,13 +652,11 @@ open class OkHttpClient internal constructor(
this.proxySelector = okHttpClient.proxySelector
this.proxyAuthenticator = okHttpClient.proxyAuthenticator
this.socketFactory = okHttpClient.socketFactory
this.sslSocketFactoryOrNull = okHttpClient.sslSocketFactoryOrNull
this.x509TrustManagerOrNull = okHttpClient.x509TrustManager
this.sslInitializedFields = okHttpClient.sslInitializedFields
this.connectionSpecs = okHttpClient.connectionSpecs
this.protocols = okHttpClient.protocols
this.hostnameVerifier = okHttpClient.hostnameVerifier
this.certificatePinner = okHttpClient.certificatePinner
this.certificateChainCleaner = okHttpClient.certificateChainCleaner
this.certificatePinner = okHttpClient._certificatePinner
this.callTimeout = okHttpClient.callTimeoutMillis
this.connectTimeout = okHttpClient.connectTimeoutMillis
this.readTimeout = okHttpClient.readTimeoutMillis
Expand Down Expand Up @@ -913,18 +919,25 @@ open class OkHttpClient internal constructor(
)
fun sslSocketFactory(sslSocketFactory: SSLSocketFactory) =
apply {
if (sslSocketFactory != this.sslSocketFactoryOrNull) {
if (sslSocketFactory != sslInitializedFields?.value?.sslSocketFactory) {
this.routeDatabase = null
}

this.sslSocketFactoryOrNull = sslSocketFactory
this.x509TrustManagerOrNull =
Platform.get().trustManager(sslSocketFactory) ?: throw IllegalStateException(
"Unable to extract the trust manager on ${Platform.get()}, " +
val platform = Platform.get()
val trustManager =
platform.trustManager(sslSocketFactory) ?: throw IllegalStateException(
"Unable to extract the trust manager on $platform, " +
"sslSocketFactory is ${sslSocketFactory.javaClass}",
)
this.certificateChainCleaner =
Platform.get().buildCertificateChainCleaner(x509TrustManagerOrNull!!)
// Expensive copy assuming SSL already initialized
sslInitializedFields =
lazyOf(
SSLInitializedFields(
trustManager,
sslSocketFactory = sslSocketFactory,
certificateChainCleaner = platform.buildCertificateChainCleaner(trustManager),
),
)
}

/**
Expand Down Expand Up @@ -976,13 +989,21 @@ open class OkHttpClient internal constructor(
sslSocketFactory: SSLSocketFactory,
trustManager: X509TrustManager,
) = apply {
if (sslSocketFactory != this.sslSocketFactoryOrNull || trustManager != this.x509TrustManagerOrNull) {
val existingSsl = sslInitializedFields?.value

if (sslSocketFactory != existingSsl?.sslSocketFactory || trustManager != existingSsl?.x509TrustManager) {
this.routeDatabase = null
}

this.sslSocketFactoryOrNull = sslSocketFactory
this.certificateChainCleaner = CertificateChainCleaner.get(trustManager)
this.x509TrustManagerOrNull = trustManager
// Expensive copy assuming SSL already initialized
sslInitializedFields =
lazyOf(
SSLInitializedFields(
trustManager,
sslSocketFactory = sslSocketFactory,
certificateChainCleaner = CertificateChainCleaner.get(trustManager),
),
)
}

fun connectionSpecs(connectionSpecs: List<ConnectionSpec>) =
Expand Down Expand Up @@ -1078,11 +1099,13 @@ open class OkHttpClient internal constructor(
*/
fun certificatePinner(certificatePinner: CertificatePinner) =
apply {
if (certificatePinner != this.certificatePinner) {
val cleanCertificatePinner = CertificatePinner(certificatePinner.pins)

if (cleanCertificatePinner != this.certificatePinner) {
this.routeDatabase = null
}

this.certificatePinner = certificatePinner
this.certificatePinner = cleanCertificatePinner
}

/**
Expand Down
78 changes: 78 additions & 0 deletions okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientConstructionTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okhttp3

import java.security.NoSuchAlgorithmException
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.X509TrustManager
import okhttp3.HttpUrl.Companion.toHttpUrl
import okhttp3.internal.platform.Platform
import okhttp3.testing.PlatformRule
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.extension.RegisterExtension

class OkHttpClientConstructionTest {
@RegisterExtension
var platform = PlatformRule()

@Test fun constructionDoesntTriggerPlatformOrSSL() {
Platform.resetForTests(platform = ExplosivePlatform { TODO("Avoid call") })

val client = OkHttpClient()

assertNotNull(client.toString())

client.newCall(Request("https://example.org/robots.txt".toHttpUrl()))
}

@Test fun cloneDoesntTriggerPlatformOrSSL() {
Platform.resetForTests(platform = ExplosivePlatform { TODO("Avoid call") })

val client = OkHttpClient()

val client2 = client.newBuilder().build()
assertNotNull(client2.toString())
}

@Test fun triggersOnExecute() {
Platform.resetForTests(platform = ExplosivePlatform { throw NoSuchAlgorithmException() })

val client = OkHttpClient()

val call = client.newCall(Request("https://example.org/robots.txt".toHttpUrl()))

assertThrows<NoSuchAlgorithmException> {
call.execute()
}
}

class ExplosivePlatform(private val explode: () -> Nothing) : Platform() {
override fun newSSLContext(): SSLContext {
explode()
}

override fun newSslSocketFactory(trustManager: X509TrustManager): SSLSocketFactory {
explode()
}

override fun platformTrustManager(): X509TrustManager {
explode()
}
}
}
15 changes: 14 additions & 1 deletion okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,15 @@ class OkHttpClientTest {
.routeDatabase,
)

// identical CertificatePinner
assertSame(
client.routeDatabase,
client.newBuilder()
.certificatePinner(CertificatePinner.Builder().build())
.build()
.routeDatabase,
)

// logically different scope of client for route db
assertNotSame(
client.routeDatabase,
Expand Down Expand Up @@ -423,7 +432,11 @@ class OkHttpClientTest {
assertNotSame(
client.routeDatabase,
client.newBuilder()
.certificatePinner(CertificatePinner.Builder().build())
.certificatePinner(
CertificatePinner.Builder()
.add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=")
.build(),
)
.build()
.routeDatabase,
)
Expand Down
Loading