From b52e7802b3535a99dee3926d54e9b0d276e2eabc Mon Sep 17 00:00:00 2001 From: Gaston Thea Date: Wed, 16 Jul 2025 10:14:58 -0300 Subject: [PATCH 1/3] Tunnel establishment --- .../network/SslProxyTunnelEstablisher.java | 187 +++++++++++++ .../SslProxyTunnelEstablisherTest.java | 246 ++++++++++++++++++ 2 files changed, 433 insertions(+) create mode 100644 src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java create mode 100644 src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java new file mode 100644 index 000000000..17c4a3276 --- /dev/null +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -0,0 +1,187 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.HttpURLConnection; +import java.net.Socket; +import java.nio.charset.StandardCharsets; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import io.split.android.client.utils.logger.Logger; + +/** + * Establishes SSL tunnels to SSL proxies using CONNECT protocol. + */ +class SslProxyTunnelEstablisher { + + private static final String CRLF = "\r\n"; + private static final String PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization"; + + /** + * Establishes an SSL tunnel through the proxy using the CONNECT method. + * After successful tunnel establishment, extracts the underlying socket + * for use with origin server SSL connections. + * + * @param proxyHost The proxy server hostname + * @param proxyPort The proxy server port + * @param targetHost The target server hostname + * @param targetPort The target server port + * @param sslSocketFactory SSL socket factory for proxy authentication + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @return Raw socket with tunnel established (connection maintained) + * @throws IOException if tunnel establishment fails + */ + @NonNull + public Socket establishTunnel(@NonNull String proxyHost, + int proxyPort, + @NonNull String targetHost, + int targetPort, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + Socket rawSocket = null; + SSLSocket sslSocket = null; + + try { + // Step 1: Create raw TCP connection to proxy + rawSocket = new Socket(proxyHost, proxyPort); + rawSocket.setSoTimeout(10000); // 10 second timeout + + // Create a temporary SSL socket to establish the SSL session with proper trust validation + sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, false); + sslSocket.setUseClientMode(true); + sslSocket.setSoTimeout(10000); // 10 second timeout + + // Perform SSL handshake using the SSL socket with custom CA certificates + sslSocket.startHandshake(); + + // Step 3: Send CONNECT request through SSL connection + sendConnectRequest(sslSocket, targetHost, targetPort, proxyCredentialsProvider); + + // Step 4: Validate CONNECT response through SSL connection + validateConnectResponse(sslSocket); + Logger.v("SSL tunnel established successfully"); + + // Step 5: Return SSL socket for tunnel communication + return sslSocket; + + } catch (Exception e) { + Logger.e("SSL tunnel establishment failed: " + e.getMessage()); + + // Clean up resources on error + if (sslSocket != null) { + try { + sslSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } else if (rawSocket != null) { + try { + rawSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } + + if (e instanceof HttpRetryException) { + throw (HttpRetryException) e; + } else if (e instanceof IOException) { + throw (IOException) e; + } else { + throw new IOException("Failed to establish SSL tunnel", e); + } + } + } + + /** + * Sends CONNECT request through SSL connection to proxy. + */ + private void sendConnectRequest(@NonNull SSLSocket sslSocket, + @NonNull String targetHost, + int targetPort, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + Logger.v("Sending CONNECT request through SSL: CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1"); + + PrintWriter writer = new PrintWriter(new OutputStreamWriter(sslSocket.getOutputStream(), StandardCharsets.UTF_8), false); + writer.write("CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1" + CRLF); + writer.write("Host: " + targetHost + ":" + targetPort + CRLF); + + if (proxyCredentialsProvider != null) { + // Send Proxy-Authorization header if credentials are set + String bearerToken = proxyCredentialsProvider.getBearerToken(); + if (bearerToken != null && !bearerToken.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); + } + } + + // Send empty line to end headers + writer.write(CRLF); + writer.flush(); + // Note: Don't close the writer as it would close the underlying socket + + Logger.v("CONNECT request sent through SSL connection"); + } + + /** + * Validates CONNECT response through SSL connection. + * Only reads status line and headers, leaving the stream open for tunneling. + */ + private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOException { + + Logger.v("Reading CONNECT response through SSL connection"); + + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(sslSocket.getInputStream(), StandardCharsets.UTF_8)); + + String statusLine = reader.readLine(); + if (statusLine == null) { + throw new IOException("No CONNECT response received from proxy"); + } + + Logger.v("Received CONNECT response through SSL: " + statusLine.trim()); + + // Parse status code + String[] statusParts = statusLine.split(" "); + if (statusParts.length < 2) { + throw new IOException("Invalid CONNECT response status line: " + statusLine); + } + + int statusCode; + try { + statusCode = Integer.parseInt(statusParts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid CONNECT response status code: " + statusLine, e); + } + + // Read headers until empty line (but don't process them for CONNECT) + String headerLine; + while ((headerLine = reader.readLine()) != null && !headerLine.trim().isEmpty()) { + Logger.v("CONNECT response header: " + headerLine); + } + + // Check status code + if (statusCode != 200) { + if (statusCode == HttpURLConnection.HTTP_PROXY_AUTH) { + throw new HttpRetryException("CONNECT request failed with status " + statusCode + ": " + statusLine, HttpURLConnection.HTTP_PROXY_AUTH); + } + throw new IOException("CONNECT request failed with status " + statusCode + ": " + statusLine); + } + } catch (IOException e) { + if (e instanceof HttpRetryException) { + throw e; + } + + throw new IOException("Failed to validate CONNECT response from proxy: " + e.getMessage(), e); + } + } +} diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java new file mode 100644 index 000000000..01f46df5e --- /dev/null +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -0,0 +1,246 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.net.Socket; +import java.security.KeyStore; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class SslProxyTunnelEstablisherTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private TestSslProxy testProxy; + private SSLSocketFactory clientSslSocketFactory; + private ProxyCredentialsProvider mProxyCredentialsProvider; + + @Before + public void setUp() throws Exception { + mProxyCredentialsProvider = mock(ProxyCredentialsProvider.class); + // Create test certificates + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + + // Create SSL socket factory that trusts the proxy CA + File proxyCaFile = tempFolder.newFile("proxy-ca.pem"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + ProxySslSocketFactoryProvider factory = new ProxySslSocketFactoryProviderImpl(); + try (java.io.FileInputStream caInput = new java.io.FileInputStream(proxyCaFile)) { + clientSslSocketFactory = factory.create(caInput); + } + + // Start test SSL proxy + testProxy = new TestSslProxy(0, proxyServerCert); + testProxy.start(); + + // Wait for proxy to start + while (testProxy.getPort() == 0) { + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception { + if (testProxy != null) { + testProxy.stop(); + } + } + + @Test + public void establishTunnel_withValidSslProxy_succeeds() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + String targetHost = "example.com"; + int targetPort = 443; + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + targetHost, + targetPort, + clientSslSocketFactory, + mProxyCredentialsProvider); + + assertNotNull("Tunnel socket should not be null", tunnelSocket); + assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); +// assertTrue("SSL handshake should be completed", tunnelSocket.getSession().isValid()); + + // Verify CONNECT request was sent and successful + assertTrue("Proxy should have received CONNECT request", + testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + assertEquals("CONNECT example.com:443 HTTP/1.1", testProxy.getReceivedConnectLine()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnel_withInvalidSslCertificate_throwsException() throws Exception { + SSLContext untrustedContext = SSLContext.getInstance("TLS"); + untrustedContext.init(null, null, null); // Use default trust manager (won't trust our proxy) + SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); + + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + try { + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + untrustedSocketFactory, + mProxyCredentialsProvider); + fail("Should have thrown exception for untrusted certificate"); + } catch (IOException e) { + assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); + } + } + + @Test + public void establishTunnel_withProxyConnectionFailure_throwsException() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + try { + establisher.establishTunnel( + "localhost", + -1234, + "example.com", + 443, + clientSslSocketFactory, + mProxyCredentialsProvider); + fail("Should have thrown exception for connection failure"); + } catch (IOException e) { + // The implementation wraps the original exception with a descriptive message + assertTrue(e.getMessage().contains("Connection")); + } + } + + /** + * Test SSL proxy that accepts SSL connections and handles CONNECT requests. + */ + private static class TestSslProxy extends Thread { + private final int mPort; + private final HeldCertificate mServerCert; + private SSLServerSocket mServerSocket; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); + private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + + public TestSslProxy(int port, HeldCertificate serverCert) { + this.mPort = port; + this.mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), + new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + sslContext.init(kmf.getKeyManagers(), null, null); + + mServerSocket = (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(mPort); + mServerSocket.setWantClientAuth(false); + mServerSocket.setNeedClientAuth(false); + + while (mRunning.get()) { + try { + Socket client = mServerSocket.accept(); + handleClient(client); + } catch (IOException e) { + if (mRunning.get()) { + System.err.println("Error accepting client: " + e.getMessage()); + } + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to start test SSL proxy", e); + } + } + + private void handleClient(Socket client) { + try { + java.io.BufferedReader reader = new java.io.BufferedReader( + new java.io.InputStreamReader(client.getInputStream())); + java.io.PrintWriter writer = new java.io.PrintWriter(client.getOutputStream(), true); + + // Read CONNECT request + String connectLine = reader.readLine(); + if (connectLine != null && connectLine.startsWith("CONNECT")) { + mReceivedConnectLine.set(connectLine); + mConnectRequestReceived.countDown(); + + // Send successful CONNECT response + writer.println("HTTP/1.1 200 Connection established"); + writer.println(); + writer.flush(); + + // Keep connection open for tunnel + Thread.sleep(100); // Brief pause to simulate tunnel establishment + } + } catch (Exception e) { + System.err.println("Error handling client: " + e.getMessage()); + } finally { + try { + client.close(); + } catch (IOException e) { + // Ignore + } + } + } + + public int getPort() { + return mServerSocket != null ? mServerSocket.getLocalPort() : 0; + } + + public void stopRun() throws IOException { + mRunning.set(false); + if (mServerSocket != null) { + mServerSocket.close(); + } + } + + public CountDownLatch getConnectRequestReceived() { + return mConnectRequestReceived; + } + + public String getReceivedConnectLine() { + return mReceivedConnectLine.get(); + } + } +} From dad59ecefba97401e8482d1f40015acd3b1154cc Mon Sep 17 00:00:00 2001 From: Gaston Thea Date: Wed, 16 Jul 2025 10:50:29 -0300 Subject: [PATCH 2/3] Additional test --- .../network/SslProxyTunnelEstablisher.java | 1 - .../SslProxyTunnelEstablisherTest.java | 73 +++++++++++++------ 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java index 17c4a3276..8f5043b0a 100644 --- a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -127,7 +127,6 @@ private void sendConnectRequest(@NonNull SSLSocket sslSocket, // Send empty line to end headers writer.write(CRLF); writer.flush(); - // Note: Don't close the writer as it would close the underlying socket Logger.v("CONNECT request sent through SSL connection"); } diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java index 01f46df5e..46541f087 100644 --- a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -12,9 +12,12 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; +import java.io.BufferedReader; import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintWriter; import java.net.Socket; import java.security.KeyStore; import java.util.concurrent.CountDownLatch; @@ -80,7 +83,7 @@ public void tearDown() throws Exception { } @Test - public void establishTunnel_withValidSslProxy_succeeds() throws Exception { + public void establishTunnelWithValidSslProxySucceeds() throws Exception { SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); String targetHost = "example.com"; int targetPort = 443; @@ -95,7 +98,6 @@ public void establishTunnel_withValidSslProxy_succeeds() throws Exception { assertNotNull("Tunnel socket should not be null", tunnelSocket); assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); -// assertTrue("SSL handshake should be completed", tunnelSocket.getSession().isValid()); // Verify CONNECT request was sent and successful assertTrue("Proxy should have received CONNECT request", @@ -106,9 +108,9 @@ public void establishTunnel_withValidSslProxy_succeeds() throws Exception { } @Test - public void establishTunnel_withInvalidSslCertificate_throwsException() throws Exception { + public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { SSLContext untrustedContext = SSLContext.getInstance("TLS"); - untrustedContext.init(null, null, null); // Use default trust manager (won't trust our proxy) + untrustedContext.init(null, null, null); SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); @@ -128,7 +130,7 @@ public void establishTunnel_withInvalidSslCertificate_throwsException() throws E } @Test - public void establishTunnel_withProxyConnectionFailure_throwsException() throws Exception { + public void establishTunnelWithFailingProxyConnectionThrows() { SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); try { @@ -142,7 +144,30 @@ public void establishTunnel_withProxyConnectionFailure_throwsException() throws fail("Should have thrown exception for connection failure"); } catch (IOException e) { // The implementation wraps the original exception with a descriptive message - assertTrue(e.getMessage().contains("Connection")); + assertTrue(e.getMessage().contains("Failed to establish SSL tunnel")); + } + } + + @Test + public void bearerTokenIsPassedWhenSet() { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + try { + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new ProxyCredentialsProvider() { + @Override + public String getBearerToken() { + return "token"; + } + }); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); } } @@ -155,11 +180,12 @@ private static class TestSslProxy extends Thread { private SSLServerSocket mServerSocket; private final AtomicBoolean mRunning = new AtomicBoolean(true); private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); + private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); public TestSslProxy(int port, HeldCertificate serverCert) { - this.mPort = port; - this.mServerCert = serverCert; + mPort = port; + mServerCert = serverCert; } @Override @@ -195,23 +221,29 @@ public void run() { private void handleClient(Socket client) { try { - java.io.BufferedReader reader = new java.io.BufferedReader( - new java.io.InputStreamReader(client.getInputStream())); - java.io.PrintWriter writer = new java.io.PrintWriter(client.getOutputStream(), true); + BufferedReader reader = new BufferedReader( + new InputStreamReader(client.getInputStream())); + PrintWriter writer = new PrintWriter(client.getOutputStream(), true); // Read CONNECT request - String connectLine = reader.readLine(); - if (connectLine != null && connectLine.startsWith("CONNECT")) { - mReceivedConnectLine.set(connectLine); + String line = reader.readLine(); + if (line != null && line.startsWith("CONNECT")) { + mReceivedConnectLine.set(line); mConnectRequestReceived.countDown(); + while((line = reader.readLine()) != null && !line.isEmpty()) { + if (line.contains("Authorization") && line.contains("Bearer")) { + mAuthorizationHeaderReceived.countDown(); + } + } + // Send successful CONNECT response writer.println("HTTP/1.1 200 Connection established"); writer.println(); writer.flush(); // Keep connection open for tunnel - Thread.sleep(100); // Brief pause to simulate tunnel establishment + Thread.sleep(100); } } catch (Exception e) { System.err.println("Error handling client: " + e.getMessage()); @@ -228,17 +260,14 @@ public int getPort() { return mServerSocket != null ? mServerSocket.getLocalPort() : 0; } - public void stopRun() throws IOException { - mRunning.set(false); - if (mServerSocket != null) { - mServerSocket.close(); - } - } - public CountDownLatch getConnectRequestReceived() { return mConnectRequestReceived; } + public CountDownLatch getAuthorizationHeaderReceived() { + return mAuthorizationHeaderReceived; + } + public String getReceivedConnectLine() { return mReceivedConnectLine.get(); } From ead58231048c154a347c791a0483fcdb84104774 Mon Sep 17 00:00:00 2001 From: Gaston Thea Date: Wed, 16 Jul 2025 11:23:50 -0300 Subject: [PATCH 3/3] Additional tests --- .../SslProxyTunnelEstablisherTest.java | 171 +++++++++++++++--- 1 file changed, 143 insertions(+), 28 deletions(-) diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java index 46541f087..733147e92 100644 --- a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -2,6 +2,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -18,6 +19,7 @@ import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; +import java.net.HttpRetryException; import java.net.Socket; import java.security.KeyStore; import java.util.concurrent.CountDownLatch; @@ -39,11 +41,9 @@ public class SslProxyTunnelEstablisherTest { private TestSslProxy testProxy; private SSLSocketFactory clientSslSocketFactory; - private ProxyCredentialsProvider mProxyCredentialsProvider; @Before public void setUp() throws Exception { - mProxyCredentialsProvider = mock(ProxyCredentialsProvider.class); // Create test certificates HeldCertificate proxyCa = new HeldCertificate.Builder() .commonName("Test Proxy CA") @@ -87,6 +87,7 @@ public void establishTunnelWithValidSslProxySucceeds() throws Exception { SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); String targetHost = "example.com"; int targetPort = 443; + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); Socket tunnelSocket = establisher.establishTunnel( "localhost", @@ -94,7 +95,7 @@ public void establishTunnelWithValidSslProxySucceeds() throws Exception { targetHost, targetPort, clientSslSocketFactory, - mProxyCredentialsProvider); + proxyCredentialsProvider); assertNotNull("Tunnel socket should not be null", tunnelSocket); assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); @@ -112,6 +113,7 @@ public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { SSLContext untrustedContext = SSLContext.getInstance("TLS"); untrustedContext.init(null, null, null); SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); @@ -122,7 +124,7 @@ public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { "example.com", 443, untrustedSocketFactory, - mProxyCredentialsProvider); + proxyCredentialsProvider); fail("Should have thrown exception for untrusted certificate"); } catch (IOException e) { assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); @@ -132,6 +134,7 @@ public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { @Test public void establishTunnelWithFailingProxyConnectionThrows() { SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); try { establisher.establishTunnel( @@ -140,7 +143,7 @@ public void establishTunnelWithFailingProxyConnectionThrows() { "example.com", 443, clientSslSocketFactory, - mProxyCredentialsProvider); + proxyCredentialsProvider); fail("Should have thrown exception for connection failure"); } catch (IOException e) { // The implementation wraps the original exception with a descriptive message @@ -149,26 +152,130 @@ public void establishTunnelWithFailingProxyConnectionThrows() { } @Test - public void bearerTokenIsPassedWhenSet() { + public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedException { SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); - try { - establisher.establishTunnel( - "localhost", - testProxy.getPort(), - "example.com", - 443, - clientSslSocketFactory, - new ProxyCredentialsProvider() { - @Override - public String getBearerToken() { - return "token"; - } - }); - boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); - assertTrue("Proxy should have received authorization header", await); - } catch (IOException | InterruptedException e) { - throw new RuntimeException(e); - } + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new ProxyCredentialsProvider() { + @Override + public String getBearerToken() { + return "token"; + } + }); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + } + + @Test + public void establishTunnelWithNullCredentialsProviderDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + () -> null); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithEmptyBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + () -> " "); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullStatusLineThrowsIOException() { + testProxy.setConnectResponse(null); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithMalformedStatusLineThrowsIOException() { + testProxy.setConnectResponse("HTTP/1.1"); // Malformed, missing status code + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithProxyAuthRequiredThrowsHttpRetryException() { + testProxy.setConnectResponse("HTTP/1.1 407 Proxy Authentication Required"); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + HttpRetryException exception = assertThrows(HttpRetryException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertEquals(407, exception.responseCode()); } /** @@ -182,6 +289,7 @@ private static class TestSslProxy extends Thread { private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + private final AtomicReference mConnectResponse = new AtomicReference<>("HTTP/1.1 200 Connection established"); public TestSslProxy(int port, HeldCertificate serverCert) { mPort = port; @@ -237,10 +345,13 @@ private void handleClient(Socket client) { } } - // Send successful CONNECT response - writer.println("HTTP/1.1 200 Connection established"); - writer.println(); - writer.flush(); + // Send configured CONNECT response + String response = mConnectResponse.get(); + if (response != null) { + writer.println(response); + writer.println(); + writer.flush(); + } // Keep connection open for tunnel Thread.sleep(100); @@ -271,5 +382,9 @@ public CountDownLatch getAuthorizationHeaderReceived() { public String getReceivedConnectLine() { return mReceivedConnectLine.get(); } + + public void setConnectResponse(String connectResponse) { + mConnectResponse.set(connectResponse); + } } }