diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java new file mode 100644 index 000000000..8f5043b0a --- /dev/null +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -0,0 +1,186 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.HttpURLConnection; +import java.net.Socket; +import java.nio.charset.StandardCharsets; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import io.split.android.client.utils.logger.Logger; + +/** + * Establishes SSL tunnels to SSL proxies using CONNECT protocol. + */ +class SslProxyTunnelEstablisher { + + private static final String CRLF = "\r\n"; + private static final String PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization"; + + /** + * Establishes an SSL tunnel through the proxy using the CONNECT method. + * After successful tunnel establishment, extracts the underlying socket + * for use with origin server SSL connections. + * + * @param proxyHost The proxy server hostname + * @param proxyPort The proxy server port + * @param targetHost The target server hostname + * @param targetPort The target server port + * @param sslSocketFactory SSL socket factory for proxy authentication + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @return Raw socket with tunnel established (connection maintained) + * @throws IOException if tunnel establishment fails + */ + @NonNull + public Socket establishTunnel(@NonNull String proxyHost, + int proxyPort, + @NonNull String targetHost, + int targetPort, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + Socket rawSocket = null; + SSLSocket sslSocket = null; + + try { + // Step 1: Create raw TCP connection to proxy + rawSocket = new Socket(proxyHost, proxyPort); + rawSocket.setSoTimeout(10000); // 10 second timeout + + // Create a temporary SSL socket to establish the SSL session with proper trust validation + sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, false); + sslSocket.setUseClientMode(true); + sslSocket.setSoTimeout(10000); // 10 second timeout + + // Perform SSL handshake using the SSL socket with custom CA certificates + sslSocket.startHandshake(); + + // Step 3: Send CONNECT request through SSL connection + sendConnectRequest(sslSocket, targetHost, targetPort, proxyCredentialsProvider); + + // Step 4: Validate CONNECT response through SSL connection + validateConnectResponse(sslSocket); + Logger.v("SSL tunnel established successfully"); + + // Step 5: Return SSL socket for tunnel communication + return sslSocket; + + } catch (Exception e) { + Logger.e("SSL tunnel establishment failed: " + e.getMessage()); + + // Clean up resources on error + if (sslSocket != null) { + try { + sslSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } else if (rawSocket != null) { + try { + rawSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } + + if (e instanceof HttpRetryException) { + throw (HttpRetryException) e; + } else if (e instanceof IOException) { + throw (IOException) e; + } else { + throw new IOException("Failed to establish SSL tunnel", e); + } + } + } + + /** + * Sends CONNECT request through SSL connection to proxy. + */ + private void sendConnectRequest(@NonNull SSLSocket sslSocket, + @NonNull String targetHost, + int targetPort, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + Logger.v("Sending CONNECT request through SSL: CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1"); + + PrintWriter writer = new PrintWriter(new OutputStreamWriter(sslSocket.getOutputStream(), StandardCharsets.UTF_8), false); + writer.write("CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1" + CRLF); + writer.write("Host: " + targetHost + ":" + targetPort + CRLF); + + if (proxyCredentialsProvider != null) { + // Send Proxy-Authorization header if credentials are set + String bearerToken = proxyCredentialsProvider.getBearerToken(); + if (bearerToken != null && !bearerToken.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); + } + } + + // Send empty line to end headers + writer.write(CRLF); + writer.flush(); + + Logger.v("CONNECT request sent through SSL connection"); + } + + /** + * Validates CONNECT response through SSL connection. + * Only reads status line and headers, leaving the stream open for tunneling. + */ + private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOException { + + Logger.v("Reading CONNECT response through SSL connection"); + + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(sslSocket.getInputStream(), StandardCharsets.UTF_8)); + + String statusLine = reader.readLine(); + if (statusLine == null) { + throw new IOException("No CONNECT response received from proxy"); + } + + Logger.v("Received CONNECT response through SSL: " + statusLine.trim()); + + // Parse status code + String[] statusParts = statusLine.split(" "); + if (statusParts.length < 2) { + throw new IOException("Invalid CONNECT response status line: " + statusLine); + } + + int statusCode; + try { + statusCode = Integer.parseInt(statusParts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid CONNECT response status code: " + statusLine, e); + } + + // Read headers until empty line (but don't process them for CONNECT) + String headerLine; + while ((headerLine = reader.readLine()) != null && !headerLine.trim().isEmpty()) { + Logger.v("CONNECT response header: " + headerLine); + } + + // Check status code + if (statusCode != 200) { + if (statusCode == HttpURLConnection.HTTP_PROXY_AUTH) { + throw new HttpRetryException("CONNECT request failed with status " + statusCode + ": " + statusLine, HttpURLConnection.HTTP_PROXY_AUTH); + } + throw new IOException("CONNECT request failed with status " + statusCode + ": " + statusLine); + } + } catch (IOException e) { + if (e instanceof HttpRetryException) { + throw e; + } + + throw new IOException("Failed to validate CONNECT response from proxy: " + e.getMessage(), e); + } + } +} diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java new file mode 100644 index 000000000..733147e92 --- /dev/null +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -0,0 +1,390 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.Socket; +import java.security.KeyStore; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class SslProxyTunnelEstablisherTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private TestSslProxy testProxy; + private SSLSocketFactory clientSslSocketFactory; + + @Before + public void setUp() throws Exception { + // Create test certificates + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + + // Create SSL socket factory that trusts the proxy CA + File proxyCaFile = tempFolder.newFile("proxy-ca.pem"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + ProxySslSocketFactoryProvider factory = new ProxySslSocketFactoryProviderImpl(); + try (java.io.FileInputStream caInput = new java.io.FileInputStream(proxyCaFile)) { + clientSslSocketFactory = factory.create(caInput); + } + + // Start test SSL proxy + testProxy = new TestSslProxy(0, proxyServerCert); + testProxy.start(); + + // Wait for proxy to start + while (testProxy.getPort() == 0) { + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception { + if (testProxy != null) { + testProxy.stop(); + } + } + + @Test + public void establishTunnelWithValidSslProxySucceeds() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + String targetHost = "example.com"; + int targetPort = 443; + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + targetHost, + targetPort, + clientSslSocketFactory, + proxyCredentialsProvider); + + assertNotNull("Tunnel socket should not be null", tunnelSocket); + assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); + + // Verify CONNECT request was sent and successful + assertTrue("Proxy should have received CONNECT request", + testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + assertEquals("CONNECT example.com:443 HTTP/1.1", testProxy.getReceivedConnectLine()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { + SSLContext untrustedContext = SSLContext.getInstance("TLS"); + untrustedContext.init(null, null, null); + SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); + + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + try { + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + untrustedSocketFactory, + proxyCredentialsProvider); + fail("Should have thrown exception for untrusted certificate"); + } catch (IOException e) { + assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); + } + } + + @Test + public void establishTunnelWithFailingProxyConnectionThrows() { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + ProxyCredentialsProvider proxyCredentialsProvider = mock(ProxyCredentialsProvider.class); + + try { + establisher.establishTunnel( + "localhost", + -1234, + "example.com", + 443, + clientSslSocketFactory, + proxyCredentialsProvider); + fail("Should have thrown exception for connection failure"); + } catch (IOException e) { + // The implementation wraps the original exception with a descriptive message + assertTrue(e.getMessage().contains("Failed to establish SSL tunnel")); + } + } + + @Test + public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedException { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new ProxyCredentialsProvider() { + @Override + public String getBearerToken() { + return "token"; + } + }); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + } + + @Test + public void establishTunnelWithNullCredentialsProviderDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + () -> null); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithEmptyBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + () -> " "); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullStatusLineThrowsIOException() { + testProxy.setConnectResponse(null); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithMalformedStatusLineThrowsIOException() { + testProxy.setConnectResponse("HTTP/1.1"); // Malformed, missing status code + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithProxyAuthRequiredThrowsHttpRetryException() { + testProxy.setConnectResponse("HTTP/1.1 407 Proxy Authentication Required"); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + HttpRetryException exception = assertThrows(HttpRetryException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null)); + + assertEquals(407, exception.responseCode()); + } + + /** + * Test SSL proxy that accepts SSL connections and handles CONNECT requests. + */ + private static class TestSslProxy extends Thread { + private final int mPort; + private final HeldCertificate mServerCert; + private SSLServerSocket mServerSocket; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); + private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); + private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + private final AtomicReference mConnectResponse = new AtomicReference<>("HTTP/1.1 200 Connection established"); + + public TestSslProxy(int port, HeldCertificate serverCert) { + mPort = port; + mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), + new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + sslContext.init(kmf.getKeyManagers(), null, null); + + mServerSocket = (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(mPort); + mServerSocket.setWantClientAuth(false); + mServerSocket.setNeedClientAuth(false); + + while (mRunning.get()) { + try { + Socket client = mServerSocket.accept(); + handleClient(client); + } catch (IOException e) { + if (mRunning.get()) { + System.err.println("Error accepting client: " + e.getMessage()); + } + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to start test SSL proxy", e); + } + } + + private void handleClient(Socket client) { + try { + BufferedReader reader = new BufferedReader( + new InputStreamReader(client.getInputStream())); + PrintWriter writer = new PrintWriter(client.getOutputStream(), true); + + // Read CONNECT request + String line = reader.readLine(); + if (line != null && line.startsWith("CONNECT")) { + mReceivedConnectLine.set(line); + mConnectRequestReceived.countDown(); + + while((line = reader.readLine()) != null && !line.isEmpty()) { + if (line.contains("Authorization") && line.contains("Bearer")) { + mAuthorizationHeaderReceived.countDown(); + } + } + + // Send configured CONNECT response + String response = mConnectResponse.get(); + if (response != null) { + writer.println(response); + writer.println(); + writer.flush(); + } + + // Keep connection open for tunnel + Thread.sleep(100); + } + } catch (Exception e) { + System.err.println("Error handling client: " + e.getMessage()); + } finally { + try { + client.close(); + } catch (IOException e) { + // Ignore + } + } + } + + public int getPort() { + return mServerSocket != null ? mServerSocket.getLocalPort() : 0; + } + + public CountDownLatch getConnectRequestReceived() { + return mConnectRequestReceived; + } + + public CountDownLatch getAuthorizationHeaderReceived() { + return mAuthorizationHeaderReceived; + } + + public String getReceivedConnectLine() { + return mReceivedConnectLine.get(); + } + + public void setConnectResponse(String connectResponse) { + mConnectResponse.set(connectResponse); + } + } +}