diff --git a/src/androidTest/java/fake/HttpResponseMock.java b/src/androidTest/java/fake/HttpResponseMock.java index 38fbc5ba6..ba0dd982b 100644 --- a/src/androidTest/java/fake/HttpResponseMock.java +++ b/src/androidTest/java/fake/HttpResponseMock.java @@ -1,8 +1,6 @@ package fake; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; -import java.util.concurrent.BlockingQueue; +import java.security.cert.Certificate; import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -25,4 +23,9 @@ public HttpResponseMock(int status, String data) { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/fake/HttpResponseStub.java b/src/androidTest/java/fake/HttpResponseStub.java index 085ea5114..a23c08a17 100644 --- a/src/androidTest/java/fake/HttpResponseStub.java +++ b/src/androidTest/java/fake/HttpResponseStub.java @@ -1,5 +1,7 @@ package fake; +import java.security.cert.Certificate; + import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -29,4 +31,9 @@ public boolean isSuccess() { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/helper/TestableSplitConfigBuilder.java b/src/androidTest/java/helper/TestableSplitConfigBuilder.java index 34449f445..2854673cb 100644 --- a/src/androidTest/java/helper/TestableSplitConfigBuilder.java +++ b/src/androidTest/java/helper/TestableSplitConfigBuilder.java @@ -9,6 +9,7 @@ import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.network.CertificatePinningConfiguration; import io.split.android.client.network.DevelopmentSslConfig; +import io.split.android.client.network.ProxyConfiguration; import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.impressions.ImpressionsMode; @@ -66,6 +67,7 @@ public class TestableSplitConfigBuilder { private CertificatePinningConfiguration mCertificatePinningConfiguration; private long mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; private RolloutCacheConfiguration mRolloutCacheConfiguration = RolloutCacheConfiguration.builder().build(); + private ProxyConfiguration mProxyConfiguration = null; public TestableSplitConfigBuilder() { mServiceEndpoints = ServiceEndpoints.builder().build(); @@ -281,6 +283,11 @@ public TestableSplitConfigBuilder rolloutCacheConfiguration(RolloutCacheConfigur return this; } + public TestableSplitConfigBuilder logger(ProxyConfiguration proxyConfiguration) { + this.mProxyConfiguration = proxyConfiguration; + return this; + } + public SplitClientConfig build() { Constructor constructor = SplitClientConfig.class.getDeclaredConstructors()[0]; constructor.setAccessible(true); @@ -337,7 +344,8 @@ public SplitClientConfig build() { mObserverCacheExpirationPeriod, mCertificatePinningConfiguration, mImpressionsDedupeTimeInterval, - mRolloutCacheConfiguration); + mRolloutCacheConfiguration, + mProxyConfiguration); Logger.instance().setLevel(mLogLevel); return config; diff --git a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java index f733e28f9..7871f6949 100644 --- a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java +++ b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java @@ -17,7 +17,6 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.X509TrustManager; -import io.split.android.client.utils.Base64Util; import io.split.android.client.utils.logger.Logger; class CertificateCheckerImpl implements CertificateChecker { @@ -100,17 +99,4 @@ private String certificateChainInfo(List cleanCertificates) { return builder.toString(); } - - private static class DefaultBase64Encoder implements Base64Encoder { - - @Override - public String encode(String value) { - return Base64Util.encode(value); - } - - @Override - public String encode(byte[] bytes) { - return Base64Util.encode(bytes); - } - } } diff --git a/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java new file mode 100644 index 000000000..e1333ca80 --- /dev/null +++ b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java @@ -0,0 +1,16 @@ +package io.split.android.client.network; + +import io.split.android.client.utils.Base64Util; + +class DefaultBase64Encoder implements Base64Encoder { + + @Override + public String encode(String value) { + return Base64Util.encode(value); + } + + @Override + public String encode(byte[] bytes) { + return Base64Util.encode(bytes); + } +} diff --git a/src/main/java/io/split/android/client/network/HttpClientImpl.java b/src/main/java/io/split/android/client/network/HttpClientImpl.java index e66aa30c8..a517f84c9 100644 --- a/src/main/java/io/split/android/client/network/HttpClientImpl.java +++ b/src/main/java/io/split/android/client/network/HttpClientImpl.java @@ -44,6 +44,8 @@ public class HttpClientImpl implements HttpClient { private final UrlSanitizer mUrlSanitizer; @Nullable private final CertificateChecker mCertificateChecker; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpClientImpl(@Nullable HttpProxy proxy, @Nullable SplitAuthenticator proxyAuthenticator, @@ -66,6 +68,9 @@ public class HttpClientImpl implements HttpClient { mSslSocketFactory = sslSocketFactory; mUrlSanitizer = urlSanitizer; mCertificateChecker = certificateChecker; + mConnectionHandler = mHttpProxy != null && mSslSocketFactory != null && + (mHttpProxy.getCaCertStream() != null || mHttpProxy.getClientCertStream() != null) ? + new ProxyCacertConnectionHandler() : null; } @Override @@ -113,7 +118,8 @@ public HttpStreamRequest streamRequest(URI uri) { mUrlSanitizer, mCertificateChecker, mHttpProxy, - mProxyCredentialsProvider); + mProxyCredentialsProvider, + mConnectionHandler); } @Override @@ -274,7 +280,7 @@ public HttpClient build() { } if (mProxy != null) { - mSslSocketFactory = createSslSocketFactoryFromProxy(); + mSslSocketFactory = createSslSocketFactoryFromProxy(mProxy); } else { try { mSslSocketFactory = new Tls12OnlySocketFactory(); @@ -311,23 +317,23 @@ public HttpClient build() { certificateChecker); } - private SSLSocketFactory createSslSocketFactoryFromProxy() { + private SSLSocketFactory createSslSocketFactoryFromProxy(HttpProxy proxyParams) { ProxySslSocketFactoryProviderImpl factoryProvider = new ProxySslSocketFactoryProviderImpl(mBase64Decoder); try { - if (mProxy.getClientCertStream() != null && mProxy.getClientKeyStream() != null) { - try (InputStream caInput = mProxy.getCaCertStream(); - InputStream certInput = mProxy.getClientCertStream(); - InputStream keyInput = mProxy.getClientKeyStream()) { - Logger.v("Custom proxy CA cert and client cert/key loaded for proxy: " + mProxy.getHost()); + if (proxyParams.getClientCertStream() != null && proxyParams.getClientKeyStream() != null) { + try (InputStream caInput = proxyParams.getCaCertStream(); + InputStream certInput = proxyParams.getClientCertStream(); + InputStream keyInput = proxyParams.getClientKeyStream()) { + Logger.v("Custom proxy CA cert and client cert/key loaded for proxy: " + proxyParams.getHost()); return factoryProvider.create(caInput, certInput, keyInput); } - } else if (mProxy.getCaCertStream() != null) { - try (InputStream caInput = mProxy.getCaCertStream()) { + } else if (proxyParams.getCaCertStream() != null) { + try (InputStream caInput = proxyParams.getCaCertStream()) { return factoryProvider.create(caInput); } } } catch (Exception e) { - Logger.e("Failed to create SSLSocketFactory for proxy: " + mProxy.getHost() + ", error: " + e.getMessage()); + Logger.e("Failed to create SSLSocketFactory for proxy: " + proxyParams.getHost() + ", error: " + e.getMessage()); } return null; } diff --git a/src/main/java/io/split/android/client/network/HttpOverTunnelExecutor.java b/src/main/java/io/split/android/client/network/HttpOverTunnelExecutor.java index 5b5019a1f..9500f8514 100644 --- a/src/main/java/io/split/android/client/network/HttpOverTunnelExecutor.java +++ b/src/main/java/io/split/android/client/network/HttpOverTunnelExecutor.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.io.PrintWriter; import java.net.Socket; +import java.net.SocketException; import java.net.URL; import java.security.cert.Certificate; import java.util.Map; @@ -31,20 +32,8 @@ public HttpOverTunnelExecutor() { mResponseParser = new RawHttpResponseParser(); } - /** - * Executes an HTTP request through the established tunnel socket. - * - * @param tunnelSocket The SSL Socket with tunnel established (connection maintained) - * @param targetUrl The final destination URL (HTTP or HTTPS) - * @param method HTTP method for the request - * @param headers Headers to include in the request - * @param body Request body (null for GET requests) - * @param serverCertificates The server certificates from the SSL connection (null if not available) - * @return HttpResponse containing the server's response - * @throws IOException if the request execution fails - */ @NonNull - public HttpResponse executeRequest( + HttpResponse executeRequest( @NonNull Socket tunnelSocket, @NonNull URL targetUrl, @NonNull HttpMethod method, @@ -58,12 +47,41 @@ public HttpResponse executeRequest( sendHttpRequest(tunnelSocket, targetUrl, method, headers, body); return readHttpResponse(tunnelSocket, serverCertificates); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; } catch (Exception e) { + // Wrap other exceptions in IOException Logger.e("Failed to execute request through tunnel: " + e.getMessage()); throw new IOException("Failed to execute HTTP request through tunnel to " + targetUrl, e); } } + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull Socket finalSocket, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable Certificate[] serverCertificates) throws IOException { + Logger.v("Executing stream request through tunnel to: " + targetUrl); + + try { + sendHttpRequest(finalSocket, targetUrl, method, headers, null); + return readHttpStreamResponse(finalSocket, originSocket); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; + } catch (Exception e) { + // Wrap other exceptions in IOException + Logger.e("Failed to execute stream request through tunnel: " + e.getMessage()); + throw new IOException("Failed to execute HTTP stream request through tunnel to " + targetUrl, e); + } + } + /** * Sends the HTTP request through the tunnel socket. */ @@ -97,7 +115,6 @@ private void sendHttpRequest( host += ":" + port; } - Logger.v("Sending Host header: 'Host: " + host + "'"); writer.write("Host: " + host + CRLF); // 3. Send custom headers (excluding Host and Content-Length) @@ -151,6 +168,10 @@ private HttpResponse readHttpResponse(@NonNull Socket tunnelSocket, @Nullable Ce return mResponseParser.parseHttpResponse(tunnelSocket.getInputStream(), serverCertificates); } + private HttpStreamResponse readHttpStreamResponse(@NonNull Socket tunnelSocket, @Nullable Socket originSocket) throws IOException { + return mResponseParser.parseHttpStreamResponse(tunnelSocket.getInputStream(), tunnelSocket, originSocket); + } + /** * Gets the target port from URL, defaulting based on protocol. */ diff --git a/src/main/java/io/split/android/client/network/HttpRequestImpl.java b/src/main/java/io/split/android/client/network/HttpRequestImpl.java index a9e850fc0..1f2a0c402 100644 --- a/src/main/java/io/split/android/client/network/HttpRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpRequestImpl.java @@ -190,23 +190,12 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti HttpURLConnection connection; try { - connection = createConnection( - url, - mProxy, - mHttpProxy, - mProxyAuthenticator, - mHttpMethod, - mHeaders, - authenticate, - mSslSocketFactory, - mProxyCredentialsProvider, - mBody - ); + connection = getConnection(authenticate, url); } catch (HttpRetryException e) { if (mProxyAuthenticator == null) { throw e; } - connection = createConnection(url, mProxy, mHttpProxy, mProxyAuthenticator, mHttpMethod, mHeaders, authenticate, null, null, null); + connection = getConnection(authenticate, url); } applyTimeouts(mReadTimeout, mConnectionTimeout, connection); applySslConfig(mSslSocketFactory, mDevelopmentSslConfig, connection); @@ -226,6 +215,21 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti return connection; } + @NonNull + private HttpURLConnection getConnection(boolean authenticate, URL url) throws IOException { + return createConnection( + url, + mProxy, + mHttpProxy, + mProxyAuthenticator, + mHttpMethod, + mHeaders, + authenticate, + mSslSocketFactory, + mProxyCredentialsProvider, + mBody); + } + private static HttpResponse buildResponse(HttpURLConnection connection) throws IOException { int responseCode = connection.getResponseCode(); diff --git a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java index 367af1b02..fb269cbdf 100644 --- a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java +++ b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java @@ -32,7 +32,7 @@ class HttpResponseConnectionAdapter extends HttpsURLConnection { private final HttpResponse mResponse; private final URL mUrl; private final Certificate[] mServerCertificates; - private OutputStream mOutputStream; + private final OutputStream mOutputStream; private InputStream mInputStream; private InputStream mErrorStream; private boolean mDoOutput = false; @@ -57,7 +57,7 @@ class HttpResponseConnectionAdapter extends HttpsURLConnection { @NonNull OutputStream outputStream) { this(url, response, serverCertificates, outputStream, null, null); } - + @VisibleForTesting HttpResponseConnectionAdapter(@NonNull URL url, @NonNull HttpResponse response, diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequest.java b/src/main/java/io/split/android/client/network/HttpStreamRequest.java index 0b11d6844..f0bb28c67 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequest.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequest.java @@ -1,7 +1,9 @@ package io.split.android.client.network; +import java.io.IOException; + public interface HttpStreamRequest { void addHeader(String name, String value); - HttpStreamResponse execute() throws HttpException; + HttpStreamResponse execute() throws HttpException, IOException; void close(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java index 1203c6467..3a010c04f 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java @@ -18,6 +18,7 @@ import java.net.MalformedURLException; import java.net.ProtocolException; import java.net.Proxy; +import java.net.SocketException; import java.net.URI; import java.net.URL; import java.util.HashMap; @@ -56,6 +57,8 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { private final HttpProxy mHttpProxy; @Nullable private final ProxyCredentialsProvider mProxyCredentialsProvider; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpStreamRequestImpl(@NonNull URI uri, @NonNull Map headers, @@ -67,7 +70,8 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { @NonNull UrlSanitizer urlSanitizer, @Nullable CertificateChecker certificateChecker, @Nullable HttpProxy httpProxy, - @Nullable ProxyCredentialsProvider proxyCredentialsProvider) { + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + @Nullable ProxyCacertConnectionHandler proxyCacertConnectionHandler) { mUri = checkNotNull(uri); mHttpMethod = HttpMethod.GET; mProxy = proxy; @@ -80,10 +84,11 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { mCertificateChecker = certificateChecker; mHttpProxy = httpProxy; mProxyCredentialsProvider = proxyCredentialsProvider; + mConnectionHandler = proxyCacertConnectionHandler; } @Override - public HttpStreamResponse execute() throws HttpException { + public HttpStreamResponse execute() throws HttpException, IOException { return getRequest(); } @@ -115,14 +120,18 @@ private void closeBufferedReader() { } } - private HttpStreamResponse getRequest() throws HttpException { + private HttpStreamResponse getRequest() throws HttpException, IOException { HttpStreamResponse response; try { - mConnection = setUpConnection(false); - response = buildResponse(mConnection); + if (mConnectionHandler != null && mHttpProxy != null && mSslSocketFactory != null && (mHttpProxy.getCaCertStream() != null || mHttpProxy.getClientCertStream() != null)) { + response = mConnectionHandler.executeStreamRequest(mHttpProxy, getUrl(), mHttpMethod, mHeaders, mSslSocketFactory, mProxyCredentialsProvider); + } else { + mConnection = setUpConnection(false); + response = buildResponse(mConnection); - if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { - response = handleAuthentication(response); + if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { + response = handleAuthentication(response); + } } } catch (MalformedURLException e) { disconnect(); @@ -133,6 +142,11 @@ private HttpStreamResponse getRequest() throws HttpException { } catch (SSLPeerUnverifiedException e) { disconnect(); throw new HttpException("SSL peer not verified: " + e.getLocalizedMessage(), HttpStatus.INTERNAL_NON_RETRYABLE.getCode()); + } catch (SocketException e) { + disconnect(); + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + // This ensures socket closures are treated the same in both direct and proxy flows + throw e; } catch (IOException e) { disconnect(); throw new HttpException("Something happened while retrieving data: " + e.getLocalizedMessage()); @@ -142,10 +156,7 @@ private HttpStreamResponse getRequest() throws HttpException { } private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws IOException { - URL url = mUrlSanitizer.getUrl(mUri); - if (url == null) { - throw new IOException("Error parsing URL"); - } + URL url = getUrl(); HttpURLConnection connection = createConnection( url, @@ -167,6 +178,15 @@ private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws return connection; } + @NonNull + private URL getUrl() throws IOException { + URL url = mUrlSanitizer.getUrl(mUri); + if (url == null) { + throw new IOException("Error parsing URL"); + } + return url; + } + private HttpStreamResponse handleAuthentication(HttpStreamResponse response) throws HttpException { if (!mWasRetried.getAndSet(true)) { try { @@ -190,11 +210,11 @@ private HttpStreamResponse buildResponse(HttpURLConnection connection) throws IO } mBufferedReader = new BufferedReader(new InputStreamReader(inputStream)); - return new HttpStreamResponseImpl(responseCode, mBufferedReader); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, mBufferedReader); } } - return new HttpStreamResponseImpl(responseCode); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, null); } private void disconnect() { diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponse.java b/src/main/java/io/split/android/client/network/HttpStreamResponse.java index e20096041..f733bd3bc 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponse.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponse.java @@ -3,8 +3,9 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.Closeable; -public interface HttpStreamResponse extends BaseHttpResponse { +public interface HttpStreamResponse extends BaseHttpResponse, Closeable { @Nullable BufferedReader getBufferedReader(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java index 175433c44..bf24d0e74 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java @@ -3,18 +3,39 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +import io.split.android.client.utils.logger.Logger; public class HttpStreamResponseImpl extends BaseHttpResponseImpl implements HttpStreamResponse { private final BufferedReader mData; - HttpStreamResponseImpl(int httpStatus) { - this(httpStatus, null); - } + // Sockets are referenced when using Proxy tunneling, in order to close them + @Nullable + private final Socket mTunnelSocket; + @Nullable + private final Socket mOriginSocket; - public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { + private HttpStreamResponseImpl(int httpStatus, BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { super(httpStatus); mData = data; + mTunnelSocket = tunnelSocket; + mOriginSocket = originSocket; + } + + static HttpStreamResponseImpl createFromTunnelSocket(int httpStatus, + BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { + return new HttpStreamResponseImpl(httpStatus, data, tunnelSocket, originSocket); + } + + static HttpStreamResponseImpl createFromHttpUrlConnection(int httpStatus, BufferedReader data) { + return new HttpStreamResponseImpl(httpStatus, data, null, null); } @Override @@ -22,4 +43,37 @@ public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { public BufferedReader getBufferedReader() { return mData; } + + @Override + public void close() throws IOException { + + // Close the BufferedReader first + if (mData != null) { + try { + mData.close(); + } catch (IOException e) { + Logger.w("Failed to close BufferedReader: " + e.getMessage()); + } + } + + // Close origin socket if it exists and is different from tunnel socket + if (mOriginSocket != null && mOriginSocket != mTunnelSocket) { + try { + mOriginSocket.close(); + Logger.v("Origin socket closed"); + } catch (IOException e) { + Logger.w("Failed to close origin socket: " + e.getMessage()); + } + } + + // Close tunnel socket + if (mTunnelSocket != null) { + try { + mTunnelSocket.close(); + Logger.v("Tunnel socket closed"); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } } diff --git a/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java index 44ff3e102..65cdcb6e4 100644 --- a/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java +++ b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java @@ -4,8 +4,8 @@ import androidx.annotation.Nullable; import java.io.IOException; -import java.net.HttpRetryException; import java.net.Socket; +import java.net.SocketException; import java.net.URL; import java.security.cert.Certificate; import java.util.Map; @@ -33,8 +33,21 @@ public ProxyCacertConnectionHandler() { mTunnelExecutor = new HttpOverTunnelExecutor(); } + /** + * Executes an HTTP request through an SSL proxy tunnel. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param method The HTTP method to use + * @param headers The HTTP headers to include + * @param body The request body (if any) + * @param sslSocketFactory The SSL socket factory for proxy and origin connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @return The HTTP response + * @throws IOException if the request fails + */ @NonNull - public HttpResponse executeRequest(@NonNull HttpProxy httpProxy, + HttpResponse executeRequest(@NonNull HttpProxy httpProxy, @NonNull URL targetUrl, @NonNull HttpMethod method, @NonNull Map headers, @@ -43,87 +56,58 @@ public HttpResponse executeRequest(@NonNull HttpProxy httpProxy, @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { try { - SslProxyTunnelEstablisher tunnelEstablisher = new SslProxyTunnelEstablisher(); - Socket tunnelSocket = null; - Socket finalSocket = null; - Certificate[] serverCertificates = null; - + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, false); + try { - tunnelSocket = tunnelEstablisher.establishTunnel( - httpProxy.getHost(), - httpProxy.getPort(), - targetUrl.getHost(), - getTargetPort(targetUrl), - sslSocketFactory, - proxyCredentialsProvider - ); - - Logger.v("SSL tunnel established successfully"); - - finalSocket = tunnelSocket; - - // If the origin is HTTPS, wrap the tunnel socket with a new SSLSocket (system CA) - if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { - Logger.v("Wrapping tunnel socket with new SSLSocket for origin server handshake"); - try { - // Use the provided SSLSocketFactory, which is configured to trust the origin's CA - finalSocket = sslSocketFactory.createSocket( - tunnelSocket, - targetUrl.getHost(), - getTargetPort(targetUrl), - true // autoClose - ); - if (finalSocket instanceof SSLSocket) { - SSLSocket originSslSocket = (SSLSocket) finalSocket; - originSslSocket.setUseClientMode(true); - originSslSocket.startHandshake(); - - // Capture server certificates after successful handshake - try { - serverCertificates = originSslSocket.getSession().getPeerCertificates(); - } catch (Exception certEx) { - Logger.w("Could not capture origin server certificates: " + certEx.getMessage()); - } - } else { - throw new IOException("Failed to create SSLSocket to origin"); - } - Logger.v("SSL handshake with origin server completed"); - } catch (Exception sslEx) { - Logger.e("Failed to establish SSL connection to origin: " + sslEx.getMessage()); - throw new IOException("Failed to establish SSL connection to origin server", sslEx); - } - } - return mTunnelExecutor.executeRequest( - finalSocket, + connection.finalSocket, targetUrl, method, headers, body, - serverCertificates - ); + connection.serverCertificates); } finally { - // If we have are tunelling, finalSocket is the tunnel socket - if (finalSocket != null && finalSocket != tunnelSocket) { - try { - finalSocket.close(); - } catch (IOException e) { - Logger.w("Failed to close origin SSL socket: " + e.getMessage()); - } - } - - if (tunnelSocket != null) { - try { - tunnelSocket.close(); - } catch (IOException e) { - Logger.w("Failed to close tunnel socket: " + e.getMessage()); - } - } + // Close all sockets for non-streaming requests + closeConnection(connection); } + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; + } catch (Exception e) { + Logger.e("Failed to execute request through custom tunnel: " + e.getMessage()); + throw new IOException("Failed to execute request through custom tunnel", e); + } + } + + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + try { + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, true); + + // For streaming requests, pass socket references to the response for later cleanup + Socket originSocket = (connection.finalSocket != connection.tunnelSocket) ? connection.finalSocket : null; + return mTunnelExecutor.executeStreamRequest( + connection.finalSocket, + connection.tunnelSocket, + originSocket, + targetUrl, + method, + headers, + connection.serverCertificates); + // For streaming requests, sockets are NOT closed here + // They will be closed when the HttpStreamResponse.close() is called + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; } catch (Exception e) { - if (e instanceof HttpRetryException) { - throw (HttpRetryException) e; - } throw new IOException("Failed to execute request through custom tunnel", e); } } @@ -139,4 +123,120 @@ private static int getTargetPort(@NonNull URL targetUrl) { } return port; } + + /** + * Represents a connection through an SSL tunnel. + */ + private static class TunnelConnection { + final Socket tunnelSocket; + final Socket finalSocket; + final Certificate[] serverCertificates; + + TunnelConnection(Socket tunnelSocket, Socket finalSocket, Certificate[] serverCertificates) { + this.tunnelSocket = tunnelSocket; + this.finalSocket = finalSocket; + this.serverCertificates = serverCertificates; + } + } + + /** + * Establishes a tunnel connection to the target through the proxy. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param sslSocketFactory SSL socket factory for connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this is a streaming connection + * @return A TunnelConnection object containing the established sockets + * @throws IOException if connection establishment fails + */ + private TunnelConnection establishTunnelConnection( + @NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { + + SslProxyTunnelEstablisher tunnelEstablisher = new SslProxyTunnelEstablisher(); + Socket tunnelSocket = null; + Socket finalSocket = null; + Certificate[] serverCertificates = null; + + try { + tunnelSocket = tunnelEstablisher.establishTunnel( + httpProxy.getHost(), + httpProxy.getPort(), + targetUrl.getHost(), + getTargetPort(targetUrl), + sslSocketFactory, + proxyCredentialsProvider, + isStreaming + ); + + finalSocket = tunnelSocket; + + // If the origin is HTTPS, wrap the tunnel socket with a new SSLSocket (system CA) + if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { + try { + // Use the provided SSLSocketFactory, which is configured to trust the origin's CA + finalSocket = sslSocketFactory.createSocket( + tunnelSocket, + targetUrl.getHost(), + getTargetPort(targetUrl), + true // autoClose + ); + if (finalSocket instanceof SSLSocket) { + SSLSocket originSslSocket = (SSLSocket) finalSocket; + originSslSocket.setUseClientMode(true); + originSslSocket.startHandshake(); + + // Capture server certificates after successful handshake + try { + serverCertificates = originSslSocket.getSession().getPeerCertificates(); + } catch (Exception certEx) { + Logger.w("Could not capture origin server certificates: " + certEx.getMessage()); + } + } else { + throw new IOException("Failed to create SSLSocket to origin"); + } + } catch (Exception sslEx) { + Logger.e("Failed to establish SSL connection to origin: " + sslEx.getMessage()); + throw new IOException("Failed to establish SSL connection to origin server", sslEx); + } + } + + return new TunnelConnection(tunnelSocket, finalSocket, serverCertificates); + } catch (Exception e) { + // Clean up resources on error + closeSockets(finalSocket, tunnelSocket); + throw e; + } + } + + private void closeConnection(TunnelConnection connection) { + if (connection == null) { + return; + } + + closeSockets(connection.finalSocket, connection.tunnelSocket); + } + + private void closeSockets(Socket finalSocket, Socket tunnelSocket) { + // If we are tunnelling, finalSocket is the tunnel socket + if (finalSocket != null && finalSocket != tunnelSocket) { + try { + finalSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close origin SSL socket: " + e.getMessage()); + } + } + + if (tunnelSocket != null) { + try { + tunnelSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } } diff --git a/src/main/java/io/split/android/client/network/RawHttpResponseParser.java b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java index b121b4126..5968419cc 100644 --- a/src/main/java/io/split/android/client/network/RawHttpResponseParser.java +++ b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java @@ -3,9 +3,12 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import java.io.BufferedReader; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.Socket; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.security.cert.Certificate; @@ -30,14 +33,13 @@ class RawHttpResponseParser { * @throws IOException if parsing fails or the response is malformed */ @NonNull - public HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certificate[] serverCertificates) throws IOException { + HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certificate[] serverCertificates) throws IOException { // 1. Read and parse status line String statusLine = readLineFromStream(inputStream); if (statusLine == null) { throw new IOException("No HTTP response received from server"); } - Logger.v("Parsing HTTP status line: " + statusLine); int statusCode = parseStatusCode(statusLine); // 2. Read and parse response headers directly @@ -45,7 +47,7 @@ public HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certific // 3. Determine charset from Content-Type header Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); - + // 4. Read response body using the same InputStream String responseBody = readResponseBody(inputStream, responseHeaders.mIsChunked, bodyCharset, responseHeaders.mContentLength, responseHeaders.mConnectionClose); @@ -57,6 +59,30 @@ public HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certific } } + @NonNull + HttpStreamResponse parseHttpStreamResponse(@NonNull InputStream inputStream, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) throws IOException { + // 1. Read and parse status line + String statusLine = readLineFromStream(inputStream); + if (statusLine == null) { + throw new IOException("No HTTP response received from server"); + } + + int statusCode = parseStatusCode(statusLine); + + // 2. Read and parse response headers directly + ParsedResponseHeaders responseHeaders = parseHeadersDirectly(inputStream); + + // 3. Determine charset from Content-Type header + Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); + + return HttpStreamResponseImpl.createFromTunnelSocket(statusCode, + new BufferedReader(new InputStreamReader(inputStream, bodyCharset)), + tunnelSocket, + originSocket); + } + @NonNull private ParsedResponseHeaders parseHeadersDirectly(@NonNull InputStream inputStream) throws IOException { int contentLength = -1; @@ -174,7 +200,7 @@ private String readChunkedBodyWithCharset(InputStream inputStream, Charset chars // Read trailing headers until empty line String trailerLine; while ((trailerLine = readLineFromStream(inputStream)) != null && !trailerLine.trim().isEmpty()) { - Logger.v("Chunked trailer: " + trailerLine); + // no-op } break; } diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java index 194c6bd75..52617f686 100644 --- a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -2,6 +2,7 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; import java.io.BufferedReader; import java.io.IOException; @@ -13,11 +14,12 @@ import java.net.Socket; import java.nio.charset.StandardCharsets; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; -import io.split.android.client.utils.logger.Logger; - /** * Establishes SSL tunnels to SSL proxies using CONNECT protocol. */ @@ -25,9 +27,19 @@ class SslProxyTunnelEstablisher { private static final String CRLF = "\r\n"; private static final String PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization"; - + private final Base64Encoder mBase64Encoder; + // Default timeout for regular connections (10 seconds) - private static final int DEFAULT_SOCKET_TIMEOUT = 10000; + private static final int DEFAULT_SOCKET_TIMEOUT = 20000; + + SslProxyTunnelEstablisher() { + this(new DefaultBase64Encoder()); + } + + @VisibleForTesting + SslProxyTunnelEstablisher(Base64Encoder base64Encoder) { + mBase64Encoder = base64Encoder; + } /** * Establishes an SSL tunnel through the proxy using the CONNECT method. @@ -40,63 +52,56 @@ class SslProxyTunnelEstablisher { * @param targetPort The target server port * @param sslSocketFactory SSL socket factory for proxy authentication * @param proxyCredentialsProvider Credentials provider for proxy authentication - * @return Raw socket with tunnel established (connection maintained) - * @throws IOException if tunnel establishment fails - */ - /** - * Establishes an SSL tunnel through the proxy using the CONNECT method. - * After successful tunnel establishment, extracts the underlying socket - * for use with origin server SSL connections. - * - * @param proxyHost The proxy server hostname - * @param proxyPort The proxy server port - * @param targetHost The target server hostname - * @param targetPort The target server port - * @param sslSocketFactory SSL socket factory for proxy authentication - * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this connection is for streaming (uses longer timeout) * @return Raw socket with tunnel established (connection maintained) * @throws IOException if tunnel establishment fails */ @NonNull Socket establishTunnel(@NonNull String proxyHost, - int proxyPort, - @NonNull String targetHost, - int targetPort, - @NonNull SSLSocketFactory sslSocketFactory, - @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + int proxyPort, + @NonNull String targetHost, + int targetPort, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { Socket rawSocket = null; SSLSocket sslSocket = null; try { - // Determine which timeout to use based on connection type int timeout = DEFAULT_SOCKET_TIMEOUT; - // Step 1: Create raw TCP connection to proxy rawSocket = new Socket(proxyHost, proxyPort); rawSocket.setSoTimeout(timeout); // Create a temporary SSL socket to establish the SSL session with proper trust validation - sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, false); + sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, true); sslSocket.setUseClientMode(true); - sslSocket.setSoTimeout(timeout); + if (isStreaming) { + sslSocket.setSoTimeout(0); // no timeout for streaming + } else { + sslSocket.setSoTimeout(timeout); + } // Perform SSL handshake using the SSL socket with custom CA certificates sslSocket.startHandshake(); + // Validate the proxy hostname + HostnameVerifier verifier = HttpsURLConnection.getDefaultHostnameVerifier(); + if (!verifier.verify(proxyHost, sslSocket.getSession())) { + throw new SSLHandshakeException("Proxy hostname verification failed"); + } + // Step 3: Send CONNECT request through SSL connection sendConnectRequest(sslSocket, targetHost, targetPort, proxyCredentialsProvider); // Step 4: Validate CONNECT response through SSL connection validateConnectResponse(sslSocket); - Logger.v("SSL tunnel established successfully"); // Step 5: Return SSL socket for tunnel communication return sslSocket; } catch (Exception e) { - Logger.e("SSL tunnel establishment failed: " + e.getMessage()); - // Clean up resources on error if (sslSocket != null) { try { @@ -130,27 +135,34 @@ private void sendConnectRequest(@NonNull SSLSocket sslSocket, int targetPort, @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { - Logger.v("Sending CONNECT request through SSL: CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1"); - PrintWriter writer = new PrintWriter(new OutputStreamWriter(sslSocket.getOutputStream(), StandardCharsets.UTF_8), false); writer.write("CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1" + CRLF); writer.write("Host: " + targetHost + ":" + targetPort + CRLF); if (proxyCredentialsProvider != null) { - if (proxyCredentialsProvider instanceof BearerCredentialsProvider) { - // Send Proxy-Authorization header if credentials are set - String bearerToken = ((BearerCredentialsProvider) proxyCredentialsProvider).getToken(); - if (bearerToken != null && !bearerToken.trim().isEmpty()) { - writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); - } - } + addProxyAuthHeader(proxyCredentialsProvider, writer); } // Send empty line to end headers writer.write(CRLF); writer.flush(); + } - Logger.v("CONNECT request sent through SSL connection"); + private void addProxyAuthHeader(@NonNull ProxyCredentialsProvider proxyCredentialsProvider, PrintWriter writer) { + if (proxyCredentialsProvider instanceof BearerCredentialsProvider) { + // Send Proxy-Authorization header if credentials are set + String bearerToken = ((BearerCredentialsProvider) proxyCredentialsProvider).getToken(); + if (bearerToken != null && !bearerToken.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); + } + } else if (proxyCredentialsProvider instanceof BasicCredentialsProvider) { + BasicCredentialsProvider basicCredentialsProvider = (BasicCredentialsProvider) proxyCredentialsProvider; + String userName = basicCredentialsProvider.getUserName(); + String password = basicCredentialsProvider.getPassword(); + if (userName != null && !userName.trim().isEmpty() && password != null && !password.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Basic " + mBase64Encoder.encode(userName + ":" + password) + CRLF); + } + } } /** @@ -159,8 +171,6 @@ private void sendConnectRequest(@NonNull SSLSocket sslSocket, */ private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOException { - Logger.v("Reading CONNECT response through SSL connection"); - try { BufferedReader reader = new BufferedReader(new InputStreamReader(sslSocket.getInputStream(), StandardCharsets.UTF_8)); @@ -169,8 +179,6 @@ private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOExce throw new IOException("No CONNECT response received from proxy"); } - Logger.v("Received CONNECT response through SSL: " + statusLine.trim()); - // Parse status code String[] statusParts = statusLine.split(" "); if (statusParts.length < 2) { @@ -187,7 +195,7 @@ private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOExce // Read headers until empty line (but don't process them for CONNECT) String headerLine; while ((headerLine = reader.readLine()) != null && !headerLine.trim().isEmpty()) { - Logger.v("CONNECT response header: " + headerLine); + // no-op } // Check status code diff --git a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java index 9e8f11f21..78a8f316b 100644 --- a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java +++ b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java @@ -36,6 +36,7 @@ public class SseClientImpl implements SseClient { private final StringHelper mStringHelper; private HttpStreamRequest mHttpStreamRequest = null; + private HttpStreamResponse mHttpStreamResponse = null; private static final String PUSH_NOTIFICATION_CHANNELS_PARAM = "channel"; private static final String PUSH_NOTIFICATION_TOKEN_PARAM = "accessToken"; @@ -71,8 +72,21 @@ public void disconnect() { private void close() { Logger.d("Disconnecting SSE client"); if (mStatus.getAndSet(DISCONNECTED) != DISCONNECTED) { + // Close the HttpStreamResponse first to clean up sockets + if (mHttpStreamResponse != null) { + try { + mHttpStreamResponse.close(); + Logger.v("HttpStreamResponse closed successfully"); + } catch (IOException e) { + Logger.w("Failed to close HttpStreamResponse: " + e.getMessage()); + } + mHttpStreamResponse = null; + } + + // Close the HttpStreamRequest if (mHttpStreamRequest != null) { mHttpStreamRequest.close(); + mHttpStreamRequest = null; } Logger.d("SSE client disconnected"); } @@ -94,9 +108,9 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { .addParameter(PUSH_NOTIFICATION_TOKEN_PARAM, rawToken) .build(); mHttpStreamRequest = mHttpClient.streamRequest(url); - HttpStreamResponse response = mHttpStreamRequest.execute(); - if (response.isSuccess()) { - bufferedReader = response.getBufferedReader(); + mHttpStreamResponse = mHttpStreamRequest.execute(); + if (mHttpStreamResponse.isSuccess()) { + bufferedReader = mHttpStreamResponse.getBufferedReader(); if (bufferedReader != null) { Logger.d("Streaming connection opened"); mStatus.set(CONNECTED); @@ -126,8 +140,8 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { throw (new IOException("Buffer is null")); } } else { - Logger.e("Streaming connection error. Http return code " + response.getHttpStatus()); - isErrorRetryable = !response.isClientRelatedError(); + Logger.e("Streaming connection error. Http return code " + mHttpStreamResponse.getHttpStatus()); + isErrorRetryable = !mHttpStreamResponse.isClientRelatedError(); } } catch (URISyntaxException e) { logError("An error has occurred while creating stream Url ", e); @@ -149,8 +163,7 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { } } - private void logError(String message, Exception e) { + private static void logError(String message, Exception e) { Logger.e(message + " : " + e.getLocalizedMessage()); } - } diff --git a/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java new file mode 100644 index 000000000..738300ce7 --- /dev/null +++ b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java @@ -0,0 +1,47 @@ +package io.split.android.client.network; + +import static org.mockito.Mockito.mockStatic; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockedStatic; + +import java.nio.charset.StandardCharsets; + +import io.split.android.client.utils.Base64Util; + +public class DefaultBase64EncoderTest { + + private DefaultBase64Encoder encoder; + private MockedStatic mockedBase64Util; + + @Before + public void setUp() { + encoder = new DefaultBase64Encoder(); + mockedBase64Util = mockStatic(Base64Util.class); + } + + @After + public void tearDown() { + mockedBase64Util.close(); + } + + @Test + public void encodeStringUsesBase64Util() { + String input = "test string"; + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } + + @Test + public void encodeByteArrayUsesBase64Util() { + byte[] input = "test bytes".getBytes(StandardCharsets.UTF_8); + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpClientTest.java b/src/test/java/io/split/android/client/network/HttpClientTest.java index 1ad68cd2f..c7f53124f 100644 --- a/src/test/java/io/split/android/client/network/HttpClientTest.java +++ b/src/test/java/io/split/android/client/network/HttpClientTest.java @@ -219,7 +219,7 @@ public void addHeaders() throws InterruptedException, URISyntaxException, HttpEx } @Test - public void addStreamingHeaders() throws InterruptedException, URISyntaxException, HttpException { + public void addStreamingHeaders() throws InterruptedException, HttpException, IOException { client.addStreamingHeaders(Collections.singletonMap("my_header", "my_header_value")); HttpUrl url = mWebServer.url("/test1/"); diff --git a/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java index 75e43ef1c..2f3cb17ce 100644 --- a/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java +++ b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java @@ -29,8 +29,11 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocketFactory; import okhttp3.mockwebserver.Dispatcher; @@ -47,6 +50,13 @@ public class HttpClientTunnellingProxyTest { @Before public void setUp() { + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); mUrlSanitizerMock = mock(UrlSanitizer.class); when(mUrlSanitizerMock.getUrl(any())).thenAnswer(new Answer() { @Override diff --git a/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java index b4f74d166..2fdc64e30 100644 --- a/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java +++ b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java @@ -10,6 +10,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -142,4 +143,70 @@ public void tearDown() throws IOException { mOutputStream.close(); mInputStream.close(); } + + @Test + public void executeStreamRequestTest() throws IOException { + // Prepare HTTP response with headers and body + String httpResponse = "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json; charset=utf-8\r\n" + + "Content-Length: 16\r\n" + + "\r\n" + + "{\"data\":\"test\"}"; + + // Set up input stream with the HTTP response + ByteArrayInputStream inputStream = new ByteArrayInputStream(httpResponse.getBytes()); + when(mSocket.getInputStream()).thenReturn(inputStream); + + // Execute the stream request + URL url = new URL("https://test.com/stream"); + HttpStreamResponse response = mExecutor.executeStreamRequest( + mSocket, // finalSocket + mSocket, // tunnelSocket (using same mock for simplicity) + null, // originSocket + url, + HttpMethod.GET, + Collections.emptyMap(), + null // serverCertificates + ); + + // Verify the request was sent correctly + String expectedRequest = "GET /stream HTTP/1.1\r\n" + + "Host: test.com\r\n" + + "Connection: close\r\n" + + "\r\n"; + assertEquals(expectedRequest, mOutputStream.toString()); + + // Verify the response properties + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + + // Verify we can read from the response stream + BufferedReader reader = response.getBufferedReader(); + assertNotNull(reader); + StringBuilder responseBody = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + responseBody.append(line); + } + assertEquals("{\"data\":\"test\"}", responseBody.toString()); + + // Close the response + response.close(); + } + + @Test(expected = IOException.class) + public void executeStreamRequestWithSocketException() throws IOException { + URL url = new URL("http://test.com/stream"); + when(mSocket.getOutputStream()).thenThrow(new IOException("Socket error")); + + mExecutor.executeStreamRequest( + mSocket, + mSocket, + null, + url, + HttpMethod.GET, + Collections.emptyMap(), + null + ); + } } diff --git a/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java new file mode 100644 index 000000000..aa60be76e --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java @@ -0,0 +1,148 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +public class HttpStreamResponseTest { + + private static final int HTTP_STATUS_OK = 200; + private static final int HTTP_STATUS_BAD_REQUEST = 400; + + @Mock + private BufferedReader mockBufferedReader; + + @Mock + private Socket mockTunnelSocket; + + @Mock + private Socket mockOriginSocket; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void createFromTunnelSocketReturnsValidResponse() { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_OK, response.getHttpStatus()); + } + + @Test + public void createFromTunnelSocketWithNullSocketsReturnsValidResponse() { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_BAD_REQUEST, + mockBufferedReader, + null, + null + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_BAD_REQUEST, response.getHttpStatus()); + } + + @Test + public void closeSuccessfullyClosesAllResources() throws IOException { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response + response.close(); + + // Verify all resources were closed in the correct order + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithNullSocketsOnlyClosesBufferedReader() throws IOException { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + null, + null + ); + + // Close the response + response.close(); + + // Verify only the BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + verifyNoMoreInteractions(mockTunnelSocket, mockOriginSocket); + } + + @Test + public void closeWithSameTunnelAndOriginSocketClosesSocketOnce() throws IOException { + // Create response with the same socket for tunnel and origin + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockTunnelSocket + ); + + // Close the response + response.close(); + + // Verify BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + + // Verify tunnel socket was closed only once (since it's the same as origin socket) + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithExceptionsSucceeds() throws IOException { + // Setup mocks to throw exceptions when closed + doThrow(new IOException("BufferedReader close error")).when(mockBufferedReader).close(); + doThrow(new IOException("Origin socket close error")).when(mockOriginSocket).close(); + doThrow(new IOException("Tunnel socket close error")).when(mockTunnelSocket).close(); + + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response - should not throw exceptions + response.close(); + + // Verify all resources were attempted to be closed despite exceptions + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } +} diff --git a/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java index fb1eadc54..203e50570 100644 --- a/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java +++ b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java @@ -5,12 +5,14 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.net.Socket; import java.security.cert.Certificate; import java.util.Objects; @@ -158,4 +160,91 @@ public void responseWithChunkedEncodingHandlesCorrectly() throws Exception { assertTrue("Response data should contain expected content", response.getData().contains("This is chunked data!")); } + + // Tests for parseHttpStreamResponse method + + @Test + public void parseHttpStreamResponseWithValidInputReturnsCorrectResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json\r\n" + + "Content-Length: 25\r\n" + + "\r\n" + + "{\"message\":\"Hello World\"}"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + Socket mockTunnelSocket = mock(Socket.class); + Socket mockOriginSocket = mock(Socket.class); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, mockTunnelSocket, mockOriginSocket); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithNullSocketsReturnsValidResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Hello, World!"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithDifferentContentTypeUsesCorrectCharset() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/html; charset=ISO-8859-1\r\n" + + "Content-Length: 20\r\n" + + "\r\n" + + "Test Page"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("ISO-8859-1")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithEmptyStreamThrowsException() throws Exception { + InputStream inputStream = new ByteArrayInputStream(new byte[0]); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for empty stream"); + } catch (IOException e) { + assertTrue("Exception should mention no response", + e.getMessage().contains("No HTTP response")); + } + } + + @Test + public void parseHttpStreamResponseWithInvalidStatusLineThrowsException() throws Exception { + String rawHttpResponse = "INVALID STATUS LINE\r\n\r\n"; + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for invalid status line"); + } catch (IOException e) { + assertTrue("Exception should mention invalid status", + Objects.requireNonNull(e.getMessage()).contains("Invalid HTTP status")); + } + } } diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java index 33b877d2b..994e9d590 100644 --- a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -6,6 +6,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import org.junit.After; import org.junit.Before; @@ -27,9 +28,12 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocketFactory; import okhttp3.tls.HeldCertificate; @@ -44,6 +48,14 @@ public class SslProxyTunnelEstablisherTest { @Before public void setUp() throws Exception { + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); + // Create test certificates HeldCertificate proxyCa = new HeldCertificate.Builder() .commonName("Test Proxy CA") @@ -95,7 +107,8 @@ public void establishTunnelWithValidSslProxySucceeds() throws Exception { targetHost, targetPort, clientSslSocketFactory, - proxyCredentialsProvider); + proxyCredentialsProvider, + false); assertNotNull("Tunnel socket should not be null", tunnelSocket); assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); @@ -124,7 +137,8 @@ public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { "example.com", 443, untrustedSocketFactory, - proxyCredentialsProvider); + proxyCredentialsProvider, + false); fail("Should have thrown exception for untrusted certificate"); } catch (IOException e) { assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); @@ -143,7 +157,8 @@ public void establishTunnelWithFailingProxyConnectionThrows() { "example.com", 443, clientSslSocketFactory, - proxyCredentialsProvider); + proxyCredentialsProvider, + false); fail("Should have thrown exception for connection failure"); } catch (IOException e) { // The implementation wraps the original exception with a descriptive message @@ -153,6 +168,7 @@ public void establishTunnelWithFailingProxyConnectionThrows() { @Test public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedException { + // For Bearer token, we don't need to mock the Base64Encoder since it's not used SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); establisher.establishTunnel( "localhost", @@ -165,9 +181,47 @@ public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedExceptio public String getToken() { return "token"; } - }); + }, + false); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + assertEquals("Proxy-Authorization: Bearer token", testProxy.getReceivedAuthHeader()); + } + + @Test + public void basicAuthIsPassedWhenSet() throws IOException, InterruptedException { + // Create a mock Base64Encoder + Base64Encoder mockEncoder = mock(Base64Encoder.class); + String mockEncodedCredentials = "MOCK_ENCODED_CREDENTIALS"; + when(mockEncoder.encode("username:password")).thenReturn(mockEncodedCredentials); + + // Create SslProxyTunnelEstablisher with the mock encoder + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(mockEncoder); + + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BasicCredentialsProvider() { + @Override + public String getUserName() { + return "username"; + } + + @Override + public String getPassword() { + return "password"; + } + }, + false); boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); assertTrue("Proxy should have received authorization header", await); + + // The expected header should contain the mock encoded credentials + String expectedHeader = "Proxy-Authorization: Basic " + mockEncodedCredentials; + assertEquals(expectedHeader, testProxy.getReceivedAuthHeader()); } @Test @@ -180,7 +234,8 @@ public void establishTunnelWithNullCredentialsProviderDoesNotAddAuthHeader() thr "example.com", 443, clientSslSocketFactory, - null); + null, + false); assertNotNull(tunnelSocket); assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); @@ -205,7 +260,8 @@ public void establishTunnelWithNullBearerTokenDoesNotAddAuthHeader() throws Exce public String getToken() { return null; } - }); + }, + false); assertNotNull(tunnelSocket); assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); @@ -230,7 +286,8 @@ public void establishTunnelWithEmptyBearerTokenDoesNotAddAuthHeader() throws Exc public String getToken() { return ""; } - }); + }, + false); assertNotNull(tunnelSocket); assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); @@ -251,7 +308,7 @@ public void establishTunnelWithNullStatusLineThrowsIOException() { "example.com", 443, clientSslSocketFactory, - null)); + null, false)); assertNotNull(exception); } @@ -267,7 +324,8 @@ public void establishTunnelWithMalformedStatusLineThrowsIOException() { "example.com", 443, clientSslSocketFactory, - null)); + null, + false)); assertNotNull(exception); } @@ -283,7 +341,8 @@ public void establishTunnelWithProxyAuthRequiredThrowsHttpRetryException() { "example.com", 443, clientSslSocketFactory, - null)); + null, + false)); assertEquals(407, exception.responseCode()); } @@ -299,6 +358,7 @@ private static class TestSslProxy extends Thread { private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + private final AtomicReference mReceivedAuthHeader = new AtomicReference<>(); private final AtomicReference mConnectResponse = new AtomicReference<>("HTTP/1.1 200 Connection established"); public TestSslProxy(int port, HeldCertificate serverCert) { @@ -350,8 +410,9 @@ private void handleClient(Socket client) { mConnectRequestReceived.countDown(); while((line = reader.readLine()) != null && !line.isEmpty()) { - if (line.contains("Authorization") && line.contains("Bearer")) { + if (line.contains("Authorization") && (line.contains("Bearer") || line.contains("Basic"))) { mAuthorizationHeaderReceived.countDown(); + mReceivedAuthHeader.set(line); } } @@ -393,6 +454,10 @@ public String getReceivedConnectLine() { return mReceivedConnectLine.get(); } + public String getReceivedAuthHeader() { + return mReceivedAuthHeader.get(); + } + public void setConnectResponse(String connectResponse) { mConnectResponse.set(connectResponse); } diff --git a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java index 1ab36656e..eeb53f2e1 100644 --- a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java +++ b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java @@ -1,5 +1,13 @@ package io.split.android.client.service.sseclient; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -29,14 +37,6 @@ import io.split.android.client.service.sseclient.sseclient.SseHandler; import io.split.sharedtest.fake.HttpStreamResponseMock; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class SseClientTest { @Mock @@ -66,7 +66,7 @@ public void setup() throws URISyntaxException { } @Test - public void onConnect() throws InterruptedException, HttpException { + public void onConnect() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -87,7 +87,7 @@ public void onConnect() throws InterruptedException, HttpException { } @Test - public void onConnectNotConfirmed() throws InterruptedException, HttpException { + public void onConnectNotConfirmed() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -268,7 +268,7 @@ public void onConnectionSuccess() { } @Test - public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException { + public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class); @@ -286,7 +286,7 @@ public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() thr } @Test - public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException { + public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class);