Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow workaround for Proxy.HTTPS #8379

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package okhttp3.containers

import assertk.assertThat
import assertk.assertions.contains
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import okhttp3.HttpUrl.Companion.toHttpUrl
Expand Down Expand Up @@ -84,15 +85,20 @@ class BasicMockServerTest {

fun OkHttpClient.Builder.trustMockServer(): OkHttpClient.Builder =
apply {
val keyStoreFactory = KeyStoreFactory(Configuration.configuration(), MockServerLogger())

val socketFactory = keyStoreFactory.sslContext().socketFactory

val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(keyStoreFactory.loadOrCreateKeyStore())
val trustManager = trustManagerFactory.trustManagers.first() as X509TrustManager
val (socketFactory, trustManager) = trustManagerPair()

sslSocketFactory(socketFactory, trustManager)
}

fun trustManagerPair(): Pair<SSLSocketFactory, X509TrustManager> {
val keyStoreFactory = KeyStoreFactory(Configuration.configuration(), MockServerLogger())

val socketFactory = keyStoreFactory.sslContext().socketFactory

val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(keyStoreFactory.loadOrCreateKeyStore())
val trustManager = trustManagerFactory.trustManagers.first() as X509TrustManager
return Pair(socketFactory, trustManager)
}
}
}
45 changes: 45 additions & 0 deletions container-tests/src/test/java/okhttp3/containers/BasicProxyTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.containers.BasicMockServerTest.Companion.MOCKSERVER_IMAGE
import okhttp3.containers.BasicMockServerTest.Companion.trustManagerPair
import okhttp3.containers.BasicMockServerTest.Companion.trustMockServer
import okio.buffer
import okio.source
Expand Down Expand Up @@ -104,6 +105,13 @@ class BasicProxyTest {
@Test
fun testOkHttpSecureProxiedHttp1() {
testRequest {
it.withProxyConfiguration(
ProxyConfiguration.proxyConfiguration(
ProxyConfiguration.Type.HTTPS,
it.remoteAddress(),
),
)

val client =
OkHttpClient.Builder()
.trustMockServer()
Expand All @@ -121,6 +129,36 @@ class BasicProxyTest {
}
}

@Test
fun testOkHttpSecureProxiedHttp2() {
testRequest {
it.withProxyConfiguration(
ProxyConfiguration.proxyConfiguration(
ProxyConfiguration.Type.HTTPS,
it.remoteAddress(),
),
)

val (socketFactory, trustManager) = trustManagerPair()

val client =
OkHttpClient.Builder()
.sslSocketFactory(socketFactory, trustManager)
.proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress()))
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
.socketFactory(socketFactory)
.build()

val response =
client.newCall(
Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()),
).execute()

assertThat(response.body.string()).contains("Peter the person")
assertThat(response.protocol).isEqualTo(Protocol.HTTP_2)
}
}

@Test
fun testUrlConnectionDirect() {
testRequest {
Expand Down Expand Up @@ -169,6 +207,13 @@ class BasicProxyTest {
HttpsURLConnection.setDefaultSSLSocketFactory(keyStoreFactory.sslContext().socketFactory)

testRequest {
it.withProxyConfiguration(
ProxyConfiguration.proxyConfiguration(
ProxyConfiguration.Type.HTTPS,
it.remoteAddress(),
),
)

val proxy =
Proxy(
Proxy.Type.HTTP,
Expand Down
29 changes: 28 additions & 1 deletion okcurl/src/main/kotlin/okhttp3/curl/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
package okhttp3.curl

import com.github.ajalt.clikt.core.CliktCommand
import com.github.ajalt.clikt.core.UsageError
import com.github.ajalt.clikt.parameters.arguments.argument
import com.github.ajalt.clikt.parameters.options.default
import com.github.ajalt.clikt.parameters.options.flag
import com.github.ajalt.clikt.parameters.options.multiple
import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.types.int
import java.net.InetSocketAddress
import java.net.Proxy
import java.security.cert.X509Certificate
import java.util.Properties
import java.util.concurrent.TimeUnit.SECONDS
Expand Down Expand Up @@ -74,6 +77,8 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we

val sslDebug: Boolean by option(help = "Output SSL Debug").flag()

val proxy: String? by option(help = "Proxy config")

val url: String? by argument(name = "url", help = "Remote resource URL")

var client: Call.Factory? = null
Expand All @@ -98,16 +103,37 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
if (callTimeout != DEFAULT_TIMEOUT) {
builder.callTimeout(callTimeout.toLong(), SECONDS)
}
var sslSocketFactory: SSLSocketFactory? = null
if (allowInsecure) {
val trustManager = createInsecureTrustManager()
val sslSocketFactory = createInsecureSslSocketFactory(trustManager)
sslSocketFactory = createInsecureSslSocketFactory(trustManager)
builder.sslSocketFactory(sslSocketFactory, trustManager)
builder.hostnameVerifier(createInsecureHostnameVerifier())
}
if (verbose) {
val logger = HttpLoggingInterceptor.Logger(::println)
builder.eventListenerFactory(LoggingEventListener.Factory(logger))
}
proxy?.let {
val (type, host, port) = it.split(':', limit = 3)
val address = InetSocketAddress.createUnresolved(host, port.toInt())
when (type) {
"http" -> {
builder.proxy(Proxy(Proxy.Type.HTTP, address))
}

"https" -> {
builder.proxy(Proxy(Proxy.Type.HTTP, address))
.socketFactory(sslSocketFactory ?: Platform.get().newSslSocketFactory(Platform.get().platformTrustManager()))
}

"socks4" -> {
builder.proxy(Proxy(Proxy.Type.SOCKS, address))
}

else -> throw UsageError("Unknown proxy '$it'")
}
}
return builder.build()
}

Expand All @@ -129,6 +155,7 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we
return prop.getProperty("version", "dev")
}

@Suppress("TrustAllX509TrustManager", "CustomX509TrustManager")
private fun createInsecureTrustManager(): X509TrustManager =
object : X509TrustManager {
override fun checkClientTrusted(
Expand Down
6 changes: 4 additions & 2 deletions okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ open class OkHttpClient internal constructor(
checkNotNull(certificateChainCleaner) { "certificateChainCleaner == null" }
checkNotNull(x509TrustManager) { "x509TrustManager == null" }
}

if ((proxy?.type() ?: Proxy.Type.DIRECT) == Proxy.Type.DIRECT && socketFactory is SSLSocketFactory) {
Platform.get().log("socketFactory is SSLSocketFactory without Proxy", Platform.WARN)
}
}

/** Prepares the [request] to be executed at some point in the future. */
Expand Down Expand Up @@ -890,8 +894,6 @@ open class OkHttpClient internal constructor(
*/
fun socketFactory(socketFactory: SocketFactory) =
apply {
require(socketFactory !is SSLSocketFactory) { "socketFactory instanceof SSLSocketFactory" }

if (socketFactory != this.socketFactory) {
this.routeDatabase = null
}
Expand Down
10 changes: 5 additions & 5 deletions okhttp/src/test/java/okhttp3/OkHttpClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ class OkHttpClientTest {
assertThat(response.body.string()).isEqualTo("abc")
}

@Test fun sslSocketFactorySetAsSocketFactory() {
val builder = OkHttpClient.Builder()
assertFailsWith<IllegalArgumentException> {
builder.socketFactory(SSLSocketFactory.getDefault())
}
@Test
fun sslSocketFactorySetAsSocketFactory() {
OkHttpClient.Builder()
.socketFactory(SSLSocketFactory.getDefault())
.build()
}

@Test fun noSslSocketFactoryConfigured() {
Expand Down
Loading