diff --git a/build.gradle b/build.gradle index 4df504395..3ec95d825 100644 --- a/build.gradle +++ b/build.gradle @@ -19,7 +19,7 @@ apply from: 'spec.gradle' apply from: 'jacoco.gradle' ext { - splitVersion = '5.3.2' + splitVersion = '5.4.0-rc1' jacocoVersion = '0.8.8' } diff --git a/src/androidTest/java/fake/HttpResponseMock.java b/src/androidTest/java/fake/HttpResponseMock.java index 38fbc5ba6..ba0dd982b 100644 --- a/src/androidTest/java/fake/HttpResponseMock.java +++ b/src/androidTest/java/fake/HttpResponseMock.java @@ -1,8 +1,6 @@ package fake; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; -import java.util.concurrent.BlockingQueue; +import java.security.cert.Certificate; import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -25,4 +23,9 @@ public HttpResponseMock(int status, String data) { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/fake/HttpResponseStub.java b/src/androidTest/java/fake/HttpResponseStub.java index 085ea5114..a23c08a17 100644 --- a/src/androidTest/java/fake/HttpResponseStub.java +++ b/src/androidTest/java/fake/HttpResponseStub.java @@ -1,5 +1,7 @@ package fake; +import java.security.cert.Certificate; + import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -29,4 +31,9 @@ public boolean isSuccess() { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/helper/TestableSplitConfigBuilder.java b/src/androidTest/java/helper/TestableSplitConfigBuilder.java index 34449f445..2854673cb 100644 --- a/src/androidTest/java/helper/TestableSplitConfigBuilder.java +++ b/src/androidTest/java/helper/TestableSplitConfigBuilder.java @@ -9,6 +9,7 @@ import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.network.CertificatePinningConfiguration; import io.split.android.client.network.DevelopmentSslConfig; +import io.split.android.client.network.ProxyConfiguration; import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.impressions.ImpressionsMode; @@ -66,6 +67,7 @@ public class TestableSplitConfigBuilder { private CertificatePinningConfiguration mCertificatePinningConfiguration; private long mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; private RolloutCacheConfiguration mRolloutCacheConfiguration = RolloutCacheConfiguration.builder().build(); + private ProxyConfiguration mProxyConfiguration = null; public TestableSplitConfigBuilder() { mServiceEndpoints = ServiceEndpoints.builder().build(); @@ -281,6 +283,11 @@ public TestableSplitConfigBuilder rolloutCacheConfiguration(RolloutCacheConfigur return this; } + public TestableSplitConfigBuilder logger(ProxyConfiguration proxyConfiguration) { + this.mProxyConfiguration = proxyConfiguration; + return this; + } + public SplitClientConfig build() { Constructor constructor = SplitClientConfig.class.getDeclaredConstructors()[0]; constructor.setAccessible(true); @@ -337,7 +344,8 @@ public SplitClientConfig build() { mObserverCacheExpirationPeriod, mCertificatePinningConfiguration, mImpressionsDedupeTimeInterval, - mRolloutCacheConfiguration); + mRolloutCacheConfiguration, + mProxyConfiguration); Logger.instance().setLevel(mLogLevel); return config; diff --git a/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java b/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java index 91361df4c..f214b523d 100644 --- a/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java +++ b/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java @@ -21,7 +21,7 @@ public class GeneralInfoStorageTest { @Before public void setUp() { mDb = DatabaseHelper.getTestDatabase(InstrumentationRegistry.getInstrumentation().getContext()); - mGeneralInfoStorage = new GeneralInfoStorageImpl(mDb.generalInfoDao()); + mGeneralInfoStorage = new GeneralInfoStorageImpl(mDb.generalInfoDao(), null); } @After diff --git a/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java b/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java index a98e6a0e2..8f3539423 100644 --- a/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java +++ b/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java @@ -295,6 +295,7 @@ private Data buildInputData(Data customData) { dataBuilder.putString("databaseName", "test_database_name"); dataBuilder.putString("apiKey", "api_key"); dataBuilder.putBoolean("encryptionEnabled", false); + dataBuilder.putBoolean("usesProxy", false); if (customData != null) { dataBuilder.putAll(customData); } diff --git a/src/main/java/io/split/android/client/SplitClientConfig.java b/src/main/java/io/split/android/client/SplitClientConfig.java index 00af53115..004db0ad0 100644 --- a/src/main/java/io/split/android/client/SplitClientConfig.java +++ b/src/main/java/io/split/android/client/SplitClientConfig.java @@ -4,6 +4,7 @@ import static io.split.android.client.utils.Utils.checkNotNull; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; import java.net.URI; import java.util.concurrent.TimeUnit; @@ -17,6 +18,7 @@ import io.split.android.client.network.CertificatePinningConfiguration; import io.split.android.client.network.DevelopmentSslConfig; import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyConfiguration; import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.impressions.ImpressionsMode; @@ -132,6 +134,8 @@ public class SplitClientConfig { private final long mImpressionsDedupeTimeInterval; @NonNull private final RolloutCacheConfiguration mRolloutCacheConfiguration; + @Nullable + private final ProxyConfiguration mProxyConfiguration; public static Builder builder() { return new Builder(); @@ -187,7 +191,8 @@ private SplitClientConfig(String endpoint, long observerCacheExpirationPeriod, CertificatePinningConfiguration certificatePinningConfiguration, long impressionsDedupeTimeInterval, - RolloutCacheConfiguration rolloutCacheConfiguration) { + @NonNull RolloutCacheConfiguration rolloutCacheConfiguration, + @Nullable ProxyConfiguration proxyConfiguration) { mEndpoint = endpoint; mEventsEndpoint = eventsEndpoint; mTelemetryEndpoint = telemetryEndpoint; @@ -246,6 +251,7 @@ private SplitClientConfig(String endpoint, mCertificatePinningConfiguration = certificatePinningConfiguration; mImpressionsDedupeTimeInterval = impressionsDedupeTimeInterval; mRolloutCacheConfiguration = rolloutCacheConfiguration; + mProxyConfiguration = proxyConfiguration; } public String trafficType() { @@ -436,7 +442,9 @@ public boolean persistentAttributesEnabled() { return mIsPersistentAttributesEnabled; } - public int offlineRefreshRate() { return mOfflineRefreshRate; } + public int offlineRefreshRate() { + return mOfflineRefreshRate; + } public boolean shouldRecordTelemetry() { return mShouldRecordTelemetry; @@ -446,7 +454,9 @@ public long telemetryRefreshRate() { return mTelemetryRefreshRate; } - public boolean syncEnabled() { return mSyncEnabled; } + public boolean syncEnabled() { + return mSyncEnabled; + } public int mtkPerPush() { return mMtkPerPush; @@ -476,7 +486,9 @@ public int sseDisconnectionDelay() { return mSSEDisconnectionDelayInSecs; } - private void enableTelemetry() { mShouldRecordTelemetry = true; } + private void enableTelemetry() { + mShouldRecordTelemetry = true; + } public long observerCacheExpirationPeriod() { return Math.max(mImpressionsDedupeTimeInterval, mObserverCacheExpirationPeriod); @@ -572,6 +584,8 @@ public static final class Builder { private RolloutCacheConfiguration mRolloutCacheConfiguration = RolloutCacheConfiguration.builder().build(); + private ProxyConfiguration mProxyConfiguration = null; + public Builder() { mServiceEndpoints = ServiceEndpoints.builder().build(); } @@ -806,7 +820,9 @@ public Builder ready(int milliseconds) { * * @param proxyHost proxy URI * @return this builder + * @deprecated use {@link #proxyConfiguration(ProxyConfiguration)} */ + @Deprecated public Builder proxyHost(String proxyHost) { if (proxyHost != null && proxyHost.endsWith("/")) { mProxyHost = proxyHost.substring(0, proxyHost.length() - 1); @@ -823,6 +839,7 @@ public Builder proxyHost(String proxyHost) { * @param proxyAuthenticator * @return this builder */ + @Deprecated public Builder proxyAuthenticator(SplitAuthenticator proxyAuthenticator) { mProxyAuthenticator = proxyAuthenticator; return this; @@ -1030,6 +1047,7 @@ public Builder offlineRefreshRate(int offlineRefreshRate) { *

* This is an ADVANCED parameter *

+ * * @param telemetryRefreshRate Rate in seconds for telemetry refresh. * @return This builder * @default 3600 seconds @@ -1101,10 +1119,9 @@ public Builder certificatePinningConfiguration(CertificatePinningConfiguration c /** * This configuration is used to control the size of the impressions deduplication window. * + * @param impressionsDedupeTimeInterval The time interval in milliseconds. * @Experimental This method is experimental and may change or be removed in future versions. * To be used upon Split team recommendation. - * - * @param impressionsDedupeTimeInterval The time interval in milliseconds. */ @Deprecated public Builder impressionsDedupeTimeInterval(long impressionsDedupeTimeInterval) { @@ -1128,6 +1145,17 @@ public Builder rolloutCacheConfiguration(@NonNull RolloutCacheConfiguration roll return this; } + /** + * Sets the proxy configuration + * + * @param proxyConfiguration + * @return this builder + */ + public Builder proxyConfiguration(ProxyConfiguration proxyConfiguration) { + mProxyConfiguration = proxyConfiguration; + return this; + } + public SplitClientConfig build() { Logger.instance().setLevel(mLogLevel); @@ -1207,7 +1235,7 @@ public SplitClientConfig build() { mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; } - HttpProxy proxy = parseProxyHost(mProxyHost); + HttpProxy proxy = parseProxyHost(mProxyHost, mProxyConfiguration); return new SplitClientConfig( mServiceEndpoints.getSdkEndpoint(), @@ -1260,10 +1288,33 @@ public SplitClientConfig build() { mObserverCacheExpirationPeriod, mCertificatePinningConfiguration, mImpressionsDedupeTimeInterval, - mRolloutCacheConfiguration); + mRolloutCacheConfiguration, + mProxyConfiguration); + } + + private HttpProxy parseProxyHost(String proxyUri, ProxyConfiguration proxyConfiguration) { + // Use legacy proxy behavior if proxyConfiguration is null + if (proxyConfiguration == null) { + return legacyProxyBehavior(proxyUri); + } + + if (mProxyHost != null || mProxyAuthenticator != null) { + Logger.w("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence."); + } + + // Initialize internal config with null url. This will be verified when building the factory. + HttpProxy.Builder builder = HttpProxy.newBuilder(null, -1); + if (proxyConfiguration.getUrl() != null) { + builder = HttpProxy.newBuilder(proxyConfiguration.getUrl().getHost(), proxyConfiguration.getUrl().getPort()) + .mtls(proxyConfiguration.getClientCert(), proxyConfiguration.getClientPk()) + .proxyCacert(proxyConfiguration.getCaCert()) + .credentialsProvider(proxyConfiguration.getCredentialsProvider()); + } + return builder.build(); } - private HttpProxy parseProxyHost(String proxyUri) { + @Nullable + private HttpProxy legacyProxyBehavior(String proxyUri) { if (!Utils.isNullOrEmpty(proxyUri)) { try { String username = null; @@ -1271,15 +1322,19 @@ private HttpProxy parseProxyHost(String proxyUri) { URI uri = URI.create(proxyUri); int port = uri.getPort() != -1 ? uri.getPort() : PROXY_PORT_DEFAULT; String userInfo = uri.getUserInfo(); - if(!Utils.isNullOrEmpty(userInfo)) { + if (!Utils.isNullOrEmpty(userInfo)) { String[] userInfoComponents = userInfo.split(":"); - if(userInfoComponents.length > 1) { + if (userInfoComponents.length > 1) { username = userInfoComponents[0]; password = userInfoComponents[1]; } } String host = String.format("%s%s", uri.getHost(), uri.getPath()); - return new HttpProxy(host, port, username, password); + if (username != null && password != null) { + return HttpProxy.newBuilder(host, port).basicAuth(username, password).buildLegacy(); + } else { + return HttpProxy.newBuilder(host, port).buildLegacy(); + } } catch (IllegalArgumentException e) { Logger.e("Proxy URI not valid: " + e.getLocalizedMessage()); throw new IllegalArgumentException(); diff --git a/src/main/java/io/split/android/client/SplitFactoryBuilder.java b/src/main/java/io/split/android/client/SplitFactoryBuilder.java index ca31ee935..ab2bc3108 100644 --- a/src/main/java/io/split/android/client/SplitFactoryBuilder.java +++ b/src/main/java/io/split/android/client/SplitFactoryBuilder.java @@ -67,6 +67,9 @@ public static synchronized SplitFactory build(@NonNull String sdkKey, @NonNull K return new SplitFactoryImpl(sdkKey, key, config, context); } } catch (Exception ex) { + if (ex instanceof SplitInstantiationException) { + throw (SplitInstantiationException) ex; + } throw new SplitInstantiationException("Could not instantiate SplitFactory", ex); } } @@ -97,5 +100,9 @@ private static void checkPreconditions(@NonNull String sdkKey, @NonNull Key key, if (context == null) { throw new SplitInstantiationException("Could not instantiate SplitFactory. Context cannot be null"); } + + if (config.proxy() != null && config.proxy().getHost() == null) { + throw new SplitInstantiationException("Could not instantiate SplitFactory. When configured, proxy host cannot be null"); + } } } diff --git a/src/main/java/io/split/android/client/SplitFactoryHelper.java b/src/main/java/io/split/android/client/SplitFactoryHelper.java index 2c1c33d95..fa1831a09 100644 --- a/src/main/java/io/split/android/client/SplitFactoryHelper.java +++ b/src/main/java/io/split/android/client/SplitFactoryHelper.java @@ -49,7 +49,6 @@ import io.split.android.client.service.sseclient.notifications.MySegmentsV2PayloadDecoder; import io.split.android.client.service.sseclient.notifications.NotificationParser; import io.split.android.client.service.sseclient.notifications.NotificationProcessor; -import io.split.android.client.service.sseclient.notifications.SplitsChangeNotification; import io.split.android.client.service.sseclient.notifications.mysegments.MembershipsNotificationProcessorFactory; import io.split.android.client.service.sseclient.notifications.mysegments.MembershipsNotificationProcessorFactoryImpl; import io.split.android.client.service.sseclient.reactor.MySegmentsUpdateWorkerRegistry; @@ -94,6 +93,7 @@ import io.split.android.client.telemetry.TelemetrySynchronizerStub; import io.split.android.client.telemetry.storage.TelemetryRuntimeProducer; import io.split.android.client.telemetry.storage.TelemetryStorage; +import io.split.android.client.utils.HttpProxySerializer; import io.split.android.client.utils.Utils; import io.split.android.client.utils.logger.Logger; @@ -165,14 +165,15 @@ SplitStorageContainer buildStorageContainer(UserConsent userConsentStatus, TelemetryStorage telemetryStorage, long observerCacheExpirationPeriod, ScheduledThreadPoolExecutor impressionsObserverExecutor, - SplitsStorage splitsStorage) { + SplitsStorage splitsStorage, + SplitCipher alwaysEncryptedSplitCipher) { boolean isPersistenceEnabled = userConsentStatus == UserConsent.GRANTED; PersistentEventsStorage persistentEventsStorage = StorageFactory.getPersistentEventsStorage(splitRoomDatabase, splitCipher); PersistentImpressionsStorage persistentImpressionsStorage = StorageFactory.getPersistentImpressionsStorage(splitRoomDatabase, splitCipher); - GeneralInfoStorage generalInfoStorage = StorageFactory.getGeneralInfoStorage(splitRoomDatabase); + GeneralInfoStorage generalInfoStorage = StorageFactory.getGeneralInfoStorage(splitRoomDatabase, alwaysEncryptedSplitCipher); return new SplitStorageContainer( splitsStorage, StorageFactory.getMySegmentsStorage(splitRoomDatabase, splitCipher), @@ -228,6 +229,29 @@ WorkManagerWrapper buildWorkManagerWrapper(Context context, SplitClientConfig sp } + static void setupProxyForBackgroundSync(@NonNull SplitClientConfig config, Runnable proxyConfigSaveTask) { + if (config.proxy() != null && !config.proxy().isLegacy() && config.synchronizeInBackground()) { + // Store proxy config for background sync usage + new Thread(proxyConfigSaveTask).start(); + } + } + + // Visible to inject for testing + @NonNull + static Runnable getProxyConfigSaveTask(@NonNull SplitClientConfig config, WorkManagerWrapper workManagerWrapper, GeneralInfoStorage generalInfoStorage) { + return new Runnable() { + @Override + public void run() { + try { + generalInfoStorage.setProxyConfig(HttpProxySerializer.serialize(config.proxy())); + } catch (Exception ex) { + Logger.w("Failed to store proxy config for background sync. Disabling background sync", ex); + workManagerWrapper.removeWork(); + } + } + }; + } + SyncManager buildSyncManager(SplitClientConfig config, SplitTaskExecutor splitTaskExecutor, Synchronizer synchronizer, diff --git a/src/main/java/io/split/android/client/SplitFactoryImpl.java b/src/main/java/io/split/android/client/SplitFactoryImpl.java index 4aea0bb13..2ac7d457d 100644 --- a/src/main/java/io/split/android/client/SplitFactoryImpl.java +++ b/src/main/java/io/split/android/client/SplitFactoryImpl.java @@ -155,13 +155,17 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp mConfig = config; SplitCipher splitCipher = factoryHelper.getCipher(apiToken, config.encryptionEnabled()); - SplitsStorage splitsStorage = getSplitsStorage(splitDatabase, splitCipher); + // At the moment this cipher is only used for proxy config + SplitCipher alwaysEncryptedSplitCipher = (config.synchronizeInBackground() && config.proxy() != null && !config.proxy().isLegacy()) ? + factoryHelper.getCipher(apiToken, true) : null; + + SplitsStorage splitsStorage = StorageFactory.getSplitsStorage(splitDatabase, splitCipher); ScheduledThreadPoolExecutor impressionsObserverExecutor = new ScheduledThreadPoolExecutor(1, new ThreadPoolExecutor.CallerRunsPolicy()); mStorageContainer = factoryHelper.buildStorageContainer(config.userConsent(), - splitDatabase, config.shouldRecordTelemetry(), splitCipher, telemetryStorage, config.observerCacheExpirationPeriod(), impressionsObserverExecutor, splitsStorage); + splitDatabase, config.shouldRecordTelemetry(), splitCipher, telemetryStorage, config.observerCacheExpirationPeriod(), impressionsObserverExecutor, splitsStorage, alwaysEncryptedSplitCipher); mSplitTaskExecutor = new SplitTaskExecutorImpl(); mSplitTaskExecutor.pause(); @@ -173,20 +177,27 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp String splitsFilterQueryStringFromConfig = filtersConfig.second; String flagsSpec = getFlagsSpec(testingConfig); + FlagSetsFilter flagSetsFilter = factoryHelper.getFlagSetsFilter(filters); + WorkManagerWrapper workManagerWrapper = factoryHelper.buildWorkManagerWrapper(context, config, apiToken, databaseName, filters); + HttpClient defaultHttpClient; if (httpClient == null) { HttpClientImpl.Builder builder = new HttpClientImpl.Builder() .setConnectionTimeout(config.connectionTimeout()) .setReadTimeout(config.readTimeout()) - .setProxy(config.proxy()) .setDevelopmentSslConfig(config.developmentSslConfig()) .setContext(context) .setProxyAuthenticator(config.authenticator()); + if (config.proxy() != null) { + builder.setProxy(config.proxy()); + } if (config.certificatePinningConfiguration() != null) { builder.setCertificatePinningConfiguration(config.certificatePinningConfiguration()); } defaultHttpClient = builder.build(); + + SplitFactoryHelper.setupProxyForBackgroundSync(config, SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, mStorageContainer.getGeneralInfoStorage())); } else { defaultHttpClient = httpClient; } @@ -195,26 +206,23 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp SplitApiFacade splitApiFacade = factoryHelper.buildApiFacade( config, defaultHttpClient, splitsFilterQueryStringFromConfig); - FlagSetsFilter flagSetsFilter = factoryHelper.getFlagSetsFilter(filters); - SplitTaskFactory splitTaskFactory = new SplitTaskFactoryImpl( config, splitApiFacade, mStorageContainer, splitsFilterQueryStringFromConfig, getFlagsSpec(testingConfig), mEventsManagerCoordinator, filters, flagSetsFilter, testingConfig); - WorkManagerWrapper workManagerWrapper = factoryHelper.buildWorkManagerWrapper(context, config, apiToken, databaseName, filters); - + SplitSingleThreadTaskExecutor splitSingleThreadTaskExecutor = new SplitSingleThreadTaskExecutor(); splitSingleThreadTaskExecutor.pause(); ImpressionStrategyProvider impressionStrategyProvider = factoryHelper.getImpressionStrategyProvider(mSplitTaskExecutor, splitTaskFactory, mStorageContainer, config); Pair noneComponents = impressionStrategyProvider.getNoneComponents(); - + mImpressionManager = new StrategyImpressionManager(noneComponents, impressionStrategyProvider.getStrategy(config.impressionsMode())); final RetryBackoffCounterTimerFactory retryBackoffCounterTimerFactory = new RetryBackoffCounterTimerFactory(); StreamingComponents streamingComponents = factoryHelper.buildStreamingComponents(mSplitTaskExecutor, splitTaskFactory, config, defaultHttpClient, splitApiFacade, mStorageContainer, flagsSpec); - + Synchronizer mSynchronizer = new SynchronizerImpl( config, mSplitTaskExecutor, @@ -379,11 +387,6 @@ public void run() { new SplitValidatorImpl(), splitParser); } - @NonNull - private static SplitsStorage getSplitsStorage(SplitRoomDatabase splitDatabase, SplitCipher splitCipher) { - return StorageFactory.getSplitsStorage(splitDatabase, splitCipher); - } - private static String getFlagsSpec(@Nullable TestingConfig testingConfig) { if (testingConfig == null) { return BuildConfig.FLAGS_SPEC; diff --git a/src/main/java/io/split/android/client/dtos/HttpProxyDto.java b/src/main/java/io/split/android/client/dtos/HttpProxyDto.java new file mode 100644 index 000000000..323b32191 --- /dev/null +++ b/src/main/java/io/split/android/client/dtos/HttpProxyDto.java @@ -0,0 +1,117 @@ +package io.split.android.client.dtos; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import com.google.gson.annotations.SerializedName; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.BearerCredentialsProvider; + +/** + * DTO for HttpProxy serialization to JSON for storage in GeneralInfoStorage. + */ +public class HttpProxyDto { + + @SerializedName("host") + public String host; + + @SerializedName("port") + public int port; + + @SerializedName("username") + public String username; + + @SerializedName("password") + public String password; + + @SerializedName("client_cert") + public String clientCert; + + @SerializedName("client_key") + public String clientKey; + + @SerializedName("ca_cert") + public String caCert; + + @SerializedName("bearer_token") + public String bearerToken; + + public HttpProxyDto() { + // Default constructor for deserialization + } + + /** + * Constructor that creates a DTO from an HttpProxy instance. + * Note that we don't store the actual stream data, only whether they exist. + * + * @param httpProxy The HttpProxy instance to convert + */ + public HttpProxyDto(@NonNull io.split.android.client.network.HttpProxy httpProxy) { + this.host = httpProxy.getHost(); + this.port = httpProxy.getPort(); + if (httpProxy.getCredentialsProvider() instanceof BasicCredentialsProvider) { + BasicCredentialsProvider provider = (BasicCredentialsProvider) httpProxy.getCredentialsProvider(); + this.username = provider.getUsername(); + this.password = provider.getPassword(); + } else if (httpProxy.getCredentialsProvider() instanceof BearerCredentialsProvider) { + BearerCredentialsProvider provider = (BearerCredentialsProvider) httpProxy.getCredentialsProvider(); + this.bearerToken = provider.getToken(); + } + + this.clientCert = streamToString(httpProxy.getClientCertStream()); + this.clientKey = streamToString(httpProxy.getClientKeyStream()); + this.caCert = streamToString(httpProxy.getCaCertStream()); + } + + /** + * Converts an InputStream to a String. + * + * @param inputStream The InputStream to convert + * @return String representation of the InputStream contents, or null if the stream is null + */ + @Nullable + private String streamToString(@Nullable InputStream inputStream) { + if (inputStream == null) { + return null; + } + + try { + StringBuilder content = getStringBuilder(inputStream); + + // Reset the stream if possible to allow reuse + try { + inputStream.reset(); + } catch (IOException ignored) { + + } + return content.toString(); + } catch (Exception e) { + return null; + } + } + + @NonNull + private static StringBuilder getStringBuilder(@NonNull InputStream inputStream) throws IOException { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); + StringBuilder content = new StringBuilder(); + String line; + boolean firstLine = true; + + while ((line = reader.readLine()) != null) { + if (!firstLine) { + content.append("\n"); + } else { + firstLine = false; + } + content.append(line); + } + return content; + } +} diff --git a/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java b/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java new file mode 100644 index 000000000..b68a6659b --- /dev/null +++ b/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java @@ -0,0 +1,13 @@ +package io.split.android.client.network; + +/** + * Interface for providing basic credentials. + *

+ * The username and password will be used to create a Proxy-Authorization header using Basic authentication + */ +public interface BasicCredentialsProvider extends ProxyCredentialsProvider { + + String getUsername(); + + String getPassword(); +} diff --git a/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java b/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java new file mode 100644 index 000000000..d372ce5e7 --- /dev/null +++ b/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java @@ -0,0 +1,11 @@ +package io.split.android.client.network; + +/** + * Interface for providing proxy credentials. + *

+ * The token will be sent in the header "Proxy-Authorization: Bearer " + */ +public interface BearerCredentialsProvider extends ProxyCredentialsProvider { + + String getToken(); +} diff --git a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java index f733e28f9..7871f6949 100644 --- a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java +++ b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java @@ -17,7 +17,6 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.X509TrustManager; -import io.split.android.client.utils.Base64Util; import io.split.android.client.utils.logger.Logger; class CertificateCheckerImpl implements CertificateChecker { @@ -100,17 +99,4 @@ private String certificateChainInfo(List cleanCertificates) { return builder.toString(); } - - private static class DefaultBase64Encoder implements Base64Encoder { - - @Override - public String encode(String value) { - return Base64Util.encode(value); - } - - @Override - public String encode(byte[] bytes) { - return Base64Util.encode(bytes); - } - } } diff --git a/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java b/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java index d7549bb6f..23ec94d5c 100644 --- a/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java +++ b/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java @@ -209,12 +209,5 @@ private Set getInitializedPins(String host) { } return pins; } - - private static class DefaultBase64Decoder implements Base64Decoder { - @Override - public byte[] decode(String base64) { - return Base64Util.bytesDecode(base64); - } - } } } diff --git a/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java b/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java new file mode 100644 index 000000000..c84903fb6 --- /dev/null +++ b/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java @@ -0,0 +1,11 @@ +package io.split.android.client.network; + +import io.split.android.client.utils.Base64Util; + +class DefaultBase64Decoder implements Base64Decoder { + + @Override + public byte[] decode(String base64) { + return Base64Util.bytesDecode(base64); + } +} diff --git a/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java new file mode 100644 index 000000000..e1333ca80 --- /dev/null +++ b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java @@ -0,0 +1,16 @@ +package io.split.android.client.network; + +import io.split.android.client.utils.Base64Util; + +class DefaultBase64Encoder implements Base64Encoder { + + @Override + public String encode(String value) { + return Base64Util.encode(value); + } + + @Override + public String encode(byte[] bytes) { + return Base64Util.encode(bytes); + } +} diff --git a/src/main/java/io/split/android/client/network/HttpClientImpl.java b/src/main/java/io/split/android/client/network/HttpClientImpl.java index ca0a1d46d..0d955e19d 100644 --- a/src/main/java/io/split/android/client/network/HttpClientImpl.java +++ b/src/main/java/io/split/android/client/network/HttpClientImpl.java @@ -6,6 +6,10 @@ import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.URI; @@ -28,7 +32,11 @@ public class HttpClientImpl implements HttpClient { @Nullable private final Proxy mProxy; @Nullable + private final HttpProxy mHttpProxy; + @Nullable private final SplitUrlConnectionAuthenticator mProxyAuthenticator; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; private final long mReadTimeout; private final long mConnectionTimeout; @Nullable @@ -39,17 +47,22 @@ public class HttpClientImpl implements HttpClient { private final UrlSanitizer mUrlSanitizer; @Nullable private final CertificateChecker mCertificateChecker; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpClientImpl(@Nullable HttpProxy proxy, @Nullable SplitAuthenticator proxyAuthenticator, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, long readTimeout, long connectionTimeout, @Nullable DevelopmentSslConfig developmentSslConfig, @Nullable SSLSocketFactory sslSocketFactory, @NonNull UrlSanitizer urlSanitizer, @Nullable CertificateChecker certificateChecker) { + mHttpProxy = proxy; mProxy = initializeProxy(proxy); mProxyAuthenticator = initializeProxyAuthenticator(proxy, proxyAuthenticator); + mProxyCredentialsProvider = proxyCredentialsProvider; mReadTimeout = readTimeout; mConnectionTimeout = connectionTimeout; mDevelopmentSslConfig = developmentSslConfig; @@ -58,6 +71,9 @@ public class HttpClientImpl implements HttpClient { mSslSocketFactory = sslSocketFactory; mUrlSanitizer = urlSanitizer; mCertificateChecker = certificateChecker; + mConnectionHandler = mHttpProxy != null && mSslSocketFactory != null && + (mHttpProxy.getCaCertStream() != null || mHttpProxy.getClientCertStream() != null) ? + new ProxyCacertConnectionHandler() : null; } @Override @@ -73,7 +89,9 @@ public HttpRequest request(URI uri, HttpMethod requestMethod, String body, Map + * This class is responsible for executing HTTP requests through tunnel sockets that have been + * created by the SSL tunnel establisher using the custom tunneling approach. + */ +class HttpOverTunnelExecutor { + + public static final int HTTP_PORT = 80; + public static final int HTTPS_PORT = 443; + public static final int UNSET_PORT = -1; + private static final String CRLF = "\r\n"; + + private final RawHttpResponseParser mResponseParser; + + public HttpOverTunnelExecutor() { + mResponseParser = new RawHttpResponseParser(); + } + + @NonNull + HttpResponse executeRequest( + @NonNull Socket tunnelSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body, + @Nullable Certificate[] serverCertificates) throws IOException { + + Logger.v("Executing request through tunnel to: " + targetUrl); + + try { + sendHttpRequest(tunnelSocket, targetUrl, method, headers, body); + + return readHttpResponse(tunnelSocket, serverCertificates); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; + } catch (Exception e) { + // Wrap other exceptions in IOException + Logger.e("Failed to execute request through tunnel: " + e.getMessage()); + throw new IOException("Failed to execute HTTP request through tunnel to " + targetUrl, e); + } + } + + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull Socket finalSocket, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable Certificate[] serverCertificates) throws IOException { + Logger.v("Executing stream request through tunnel to: " + targetUrl); + + try { + sendHttpRequest(finalSocket, targetUrl, method, headers, null); + return readHttpStreamResponse(finalSocket, originSocket); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; + } catch (Exception e) { + // Wrap other exceptions in IOException + Logger.e("Failed to execute stream request through tunnel: " + e.getMessage()); + throw new IOException("Failed to execute HTTP stream request through tunnel to " + targetUrl, e); + } + } + + /** + * Sends the HTTP request through the tunnel socket. + */ + private void sendHttpRequest( + @NonNull Socket tunnelSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body) throws IOException { + + PrintWriter writer = new PrintWriter(tunnelSocket.getOutputStream(), true); + + // 1. Send request line + String path = targetUrl.getPath(); + if (path.isEmpty()) { + path = "/"; + } + if (targetUrl.getQuery() != null) { + path += "?" + targetUrl.getQuery(); + } + + String requestLine = method.name() + " " + path + " HTTP/1.1"; + writer.write(requestLine + CRLF); + + // 2. Send Host header (required for HTTP/1.1) + String host = targetUrl.getHost(); + int port = getTargetPort(targetUrl); + + // Add port to Host header if it's not the default port for the protocol + if (!isIsDefaultPort(targetUrl, port)) { + host += ":" + port; + } + + writer.write("Host: " + host + CRLF); + + // 3. Send custom headers (excluding Host and Content-Length) + for (Map.Entry header : headers.entrySet()) { + if (header.getKey() != null && header.getValue() != null && + !"content-length".equalsIgnoreCase(header.getKey()) && + !"host".equalsIgnoreCase(header.getKey())) { + String headerLine = header.getKey() + ": " + header.getValue(); + writer.write(headerLine + CRLF); + } + } + + // 4. Send Content-Length header if body is present + if (body != null) { + String contentLengthHeader = "Content-Length: " + body.getBytes("UTF-8").length; + writer.write(contentLengthHeader + CRLF); + } + + // 5. Send Connection: close to ensure response completion + writer.write("Connection: close" + CRLF); + + // 6. End headers with empty line + writer.write(CRLF); + + // 7. Send body if present + if (body != null) { + Logger.v("Sending request body: '" + body + "'"); + writer.write(body); + } + + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Failed to send HTTP request through tunnel"); + } + } + + private static boolean isIsDefaultPort(@NonNull URL targetUrl, int port) { + return ("http".equalsIgnoreCase(targetUrl.getProtocol()) && port == HTTP_PORT) || + ("https".equalsIgnoreCase(targetUrl.getProtocol()) && port == HTTPS_PORT); + } + + /** + * Reads HTTP response from the tunnel socket. + * + * @param tunnelSocket The socket to read from + * @param serverCertificates The server certificates to include in the response + * @return HttpResponse with server certificates + */ + private HttpResponse readHttpResponse(@NonNull Socket tunnelSocket, @Nullable Certificate[] serverCertificates) throws IOException { + return mResponseParser.parseHttpResponse(tunnelSocket.getInputStream(), serverCertificates); + } + + private HttpStreamResponse readHttpStreamResponse(@NonNull Socket tunnelSocket, @Nullable Socket originSocket) throws IOException { + return mResponseParser.parseHttpStreamResponse(tunnelSocket.getInputStream(), tunnelSocket, originSocket); + } + + /** + * Gets the target port from URL, defaulting based on protocol. + */ + private int getTargetPort(@NonNull URL targetUrl) { + int port = targetUrl.getPort(); + if (port == UNSET_PORT) { + if ("https".equalsIgnoreCase(targetUrl.getProtocol())) { + return HTTPS_PORT; + } else if ("http".equalsIgnoreCase(targetUrl.getProtocol())) { + return HTTP_PORT; + } + } + return port; + } +} diff --git a/src/main/java/io/split/android/client/network/HttpProxy.java b/src/main/java/io/split/android/client/network/HttpProxy.java index 8f9734fca..a6dc011fa 100644 --- a/src/main/java/io/split/android/client/network/HttpProxy.java +++ b/src/main/java/io/split/android/client/network/HttpProxy.java @@ -1,47 +1,118 @@ package io.split.android.client.network; -import static io.split.android.client.utils.Utils.checkNotNull; - import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import java.io.InputStream; + public class HttpProxy { - final private String host; - final private int port; - final private String username; - final private String password; + private final @NonNull String mHost; + private final int mPort; + private final @Nullable String mUsername; + private final @Nullable String mPassword; + private final @Nullable InputStream mClientCertStream; + private final @Nullable InputStream mClientKeyStream; + private final @Nullable InputStream mCaCertStream; + private final @Nullable ProxyCredentialsProvider mCredentialsProvider; + private final boolean mIsLegacy; - public HttpProxy(@NonNull String host, int port) { - this(host, port, null, null); + private HttpProxy(Builder builder, boolean isLegacy) { + mHost = builder.mHost; + mPort = builder.mPort; + mUsername = builder.mUsername; + mPassword = builder.mPassword; + mClientCertStream = builder.mClientCertStream; + mClientKeyStream = builder.mClientKeyStream; + mCaCertStream = builder.mCaCertStream; + mCredentialsProvider = builder.mCredentialsProvider; + mIsLegacy = isLegacy; } - public HttpProxy(@NonNull String host, int port, @Nullable String username, @Nullable String password) { - checkNotNull(host); + public @Nullable String getHost() { + return mHost; + } - this.host = host; - this.port = port; - this.username = username; - this.password = password; + public int getPort() { + return mPort; } - public String getHost() { - return host; + public @Nullable String getUsername() { + return mUsername; } - public int getPort() { - return port; + public @Nullable String getPassword() { + return mPassword; + } + + public @Nullable InputStream getClientCertStream() { + return mClientCertStream; + } + + public @Nullable InputStream getClientKeyStream() { + return mClientKeyStream; + } + + public @Nullable InputStream getCaCertStream() { + return mCaCertStream; + } + + public @Nullable ProxyCredentialsProvider getCredentialsProvider() { + return mCredentialsProvider; } - public String getUsername() { - return username; + public static Builder newBuilder(@Nullable String host, int port) { + return new Builder(host, port); } - public String getPassword() { - return password; + public boolean isLegacy() { + return mIsLegacy; } - public boolean usesCredentials() { - return username == null; + public static class Builder { + private final @Nullable String mHost; + private final int mPort; + private @Nullable String mUsername; + private @Nullable String mPassword; + private @Nullable InputStream mClientCertStream; + private @Nullable InputStream mClientKeyStream; + private @Nullable InputStream mCaCertStream; + @Nullable + private ProxyCredentialsProvider mCredentialsProvider; + + private Builder(@Nullable String host, int port) { + mHost = host; + mPort = port; + } + + public Builder basicAuth(@NonNull String username, @NonNull String password) { + mUsername = username; + mPassword = password; + return this; + } + + public Builder proxyCacert(@NonNull InputStream caCertStream) { + mCaCertStream = caCertStream; + return this; + } + + public Builder mtls(@NonNull InputStream clientCertStream, @NonNull InputStream keyStream) { + mClientCertStream = clientCertStream; + mClientKeyStream = keyStream; + return this; + } + + public Builder credentialsProvider(@NonNull ProxyCredentialsProvider credentialsProvider) { + mCredentialsProvider = credentialsProvider; + return this; + } + + public HttpProxy build() { + return new HttpProxy(this, false); + } + + public HttpProxy buildLegacy() { + return new HttpProxy(this, true); + } } } diff --git a/src/main/java/io/split/android/client/network/HttpRequestHelper.java b/src/main/java/io/split/android/client/network/HttpRequestHelper.java index 8deca2548..0fe702bd4 100644 --- a/src/main/java/io/split/android/client/network/HttpRequestHelper.java +++ b/src/main/java/io/split/android/client/network/HttpRequestHelper.java @@ -19,12 +19,54 @@ class HttpRequestHelper { - static HttpURLConnection openConnection(@Nullable Proxy proxy, - @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, - @NonNull URL url, - @NonNull HttpMethod method, - @NonNull Map headers, - boolean useProxyAuthentication) throws IOException { + private static final ProxyCacertConnectionHandler mConnectionHandler = new ProxyCacertConnectionHandler(); + + static HttpURLConnection createConnection(@NonNull URL url, + @Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, + @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @NonNull HttpMethod method, + @NonNull Map headers, + boolean useProxyAuthentication, + @Nullable SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + @Nullable String body) throws IOException { + + if (httpProxy != null && sslSocketFactory != null && (httpProxy.getCaCertStream() != null || httpProxy.getClientCertStream() != null)) { + try { + HttpResponse response = mConnectionHandler.executeRequest( + httpProxy, + url, + method, + headers, + body, + sslSocketFactory, + proxyCredentialsProvider + ); + + return new HttpResponseConnectionAdapter(url, response, response.getServerCertificates()); + } catch (UnsupportedOperationException e) { + // Fall through to standard handling + } + } + + return openConnection(proxy, httpProxy, proxyAuthenticator, url, method, headers, useProxyAuthentication); + } + + private static HttpURLConnection openConnection(@Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, + @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @NonNull URL url, + @NonNull HttpMethod method, + @NonNull Map headers, + boolean useProxyAuthentication) throws IOException { + + // Check if we need custom SSL proxy handling + if (httpProxy != null && (httpProxy.getCaCertStream() != null || httpProxy.getClientCertStream() != null)) { + throw new IOException("SSL proxy scenarios require custom handling - use executeRequest method instead"); + } + + // Standard HttpURLConnection proxy handling HttpURLConnection connection; if (proxy != null) { connection = (HttpURLConnection) url.openConnection(proxy); @@ -84,7 +126,7 @@ static void checkPins(HttpURLConnection connection, @Nullable CertificateChecker private static void addHeaders(HttpURLConnection request, Map headers) { for (Map.Entry entry : headers.entrySet()) { - if (entry == null) { + if (entry == null || entry.getKey() == null) { continue; } diff --git a/src/main/java/io/split/android/client/network/HttpRequestImpl.java b/src/main/java/io/split/android/client/network/HttpRequestImpl.java index 6db18413d..1f2a0c402 100644 --- a/src/main/java/io/split/android/client/network/HttpRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpRequestImpl.java @@ -4,7 +4,8 @@ import static io.split.android.client.network.HttpRequestHelper.applySslConfig; import static io.split.android.client.network.HttpRequestHelper.applyTimeouts; -import static io.split.android.client.network.HttpRequestHelper.openConnection; +import static io.split.android.client.network.HttpRequestHelper.createConnection; + import androidx.annotation.NonNull; import androidx.annotation.Nullable; @@ -14,6 +15,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; +import java.net.HttpRetryException; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.ProtocolException; @@ -43,7 +45,11 @@ public class HttpRequestImpl implements HttpRequest { @Nullable private final Proxy mProxy; @Nullable + private final HttpProxy mHttpProxy; + @Nullable private final SplitUrlConnectionAuthenticator mProxyAuthenticator; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; private final long mReadTimeout; private final long mConnectionTimeout; @Nullable @@ -58,7 +64,9 @@ public class HttpRequestImpl implements HttpRequest { @Nullable String body, @NonNull Map headers, @Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, long readTimeout, long connectionTimeout, @Nullable DevelopmentSslConfig developmentSslConfig, @@ -71,7 +79,9 @@ public class HttpRequestImpl implements HttpRequest { mUrlSanitizer = checkNotNull(urlSanitizer); mHeaders = new HashMap<>(checkNotNull(headers)); mProxy = proxy; + mHttpProxy = httpProxy; mProxyAuthenticator = proxyAuthenticator; + mProxyCredentialsProvider = proxyCredentialsProvider; mReadTimeout = readTimeout; mConnectionTimeout = connectionTimeout; mDevelopmentSslConfig = developmentSslConfig; @@ -178,7 +188,15 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti throw new IOException("Error parsing URL"); } - HttpURLConnection connection = openConnection(mProxy, mProxyAuthenticator, url, mHttpMethod, mHeaders, authenticate); + HttpURLConnection connection; + try { + connection = getConnection(authenticate, url); + } catch (HttpRetryException e) { + if (mProxyAuthenticator == null) { + throw e; + } + connection = getConnection(authenticate, url); + } applyTimeouts(mReadTimeout, mConnectionTimeout, connection); applySslConfig(mSslSocketFactory, mDevelopmentSslConfig, connection); @@ -197,6 +215,21 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti return connection; } + @NonNull + private HttpURLConnection getConnection(boolean authenticate, URL url) throws IOException { + return createConnection( + url, + mProxy, + mHttpProxy, + mProxyAuthenticator, + mHttpMethod, + mHeaders, + authenticate, + mSslSocketFactory, + mProxyCredentialsProvider, + mBody); + } + private static HttpResponse buildResponse(HttpURLConnection connection) throws IOException { int responseCode = connection.getResponseCode(); diff --git a/src/main/java/io/split/android/client/network/HttpResponse.java b/src/main/java/io/split/android/client/network/HttpResponse.java index 417dfb289..42a3d0a93 100644 --- a/src/main/java/io/split/android/client/network/HttpResponse.java +++ b/src/main/java/io/split/android/client/network/HttpResponse.java @@ -1,5 +1,9 @@ package io.split.android.client.network; +import java.security.cert.Certificate; + public interface HttpResponse extends BaseHttpResponse { String getData(); + + Certificate[] getServerCertificates(); } diff --git a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java new file mode 100644 index 000000000..fb269cbdf --- /dev/null +++ b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java @@ -0,0 +1,455 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.Permission; +import java.security.cert.Certificate; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLSocketFactory; + +/** + * Adapter that wraps an HttpResponse as an HttpURLConnection. + *

+ * This is only used to adapt the response from request through the TLS tunnel. + */ +class HttpResponseConnectionAdapter extends HttpsURLConnection { + + private final HttpResponse mResponse; + private final URL mUrl; + private final Certificate[] mServerCertificates; + private final OutputStream mOutputStream; + private InputStream mInputStream; + private InputStream mErrorStream; + private boolean mDoOutput = false; + + /** + * Creates an adapter that wraps an HttpResponse as an HttpURLConnection. + * + * @param url The URL of the request + * @param response The HTTP response from the SSL proxy + * @param serverCertificates The server certificates from the SSL connection + */ + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates) { + this(url, response, serverCertificates, new ByteArrayOutputStream()); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream) { + this(url, response, serverCertificates, outputStream, null, null); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream, + @Nullable InputStream inputStream, + @Nullable InputStream errorStream) { + super(url); + mUrl = url; + mResponse = response; + mServerCertificates = serverCertificates; + mOutputStream = outputStream; + mInputStream = inputStream; + mErrorStream = errorStream; + } + + @Override + public int getResponseCode() throws IOException { + return mResponse.getHttpStatus(); + } + + @Override + public String getResponseMessage() throws IOException { + // Map common HTTP status codes to messages + switch (mResponse.getHttpStatus()) { + case 200: + return "OK"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 500: + return "Internal Server Error"; + default: + return "HTTP " + mResponse.getHttpStatus(); + } + } + + @Override + public InputStream getInputStream() throws IOException { + if (mResponse.getHttpStatus() >= 400) { + throw new IOException("HTTP " + mResponse.getHttpStatus()); + } + if (mInputStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mInputStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + } + return mInputStream; + } + + @Override + public InputStream getErrorStream() { + if (mResponse.getHttpStatus() >= 400) { + if (mErrorStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mErrorStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + } + return mErrorStream; + } + return null; + } + + @Override + public void connect() throws IOException { + // Already connected + } + + @Override + public boolean usingProxy() { + return true; + } + + @Override + public void disconnect() { + // Close output stream if it exists + try { + if (mOutputStream != null) { + mOutputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close input stream if it exists + try { + if (mInputStream != null) { + mInputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close error stream if it exists + try { + if (mErrorStream != null) { + mErrorStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + } + + // Required abstract method implementations for HTTPS connection + @Override + public String getCipherSuite() { + return null; + } + + @Override + public Certificate[] getLocalCertificates() { + return null; + } + + @Override + public Certificate[] getServerCertificates() { + // Return the server certificates from the SSL connection + return mServerCertificates; + } + + // Minimal implementations for other required methods + @Override + public void setRequestMethod(String method) { + } + + @Override + public String getRequestMethod() { + return "GET"; + } + + @Override + public void setInstanceFollowRedirects(boolean followRedirects) { + } + + @Override + public boolean getInstanceFollowRedirects() { + return true; + } + + @Override + public void setDoOutput(boolean doOutput) { + mDoOutput = doOutput; + } + + @Override + public boolean getDoOutput() { + return mDoOutput; + } + + @Override + public void setDoInput(boolean doInput) { + } + + @Override + public boolean getDoInput() { + return true; + } + + @Override + public void setUseCaches(boolean useCaches) { + } + + @Override + public boolean getUseCaches() { + return false; + } + + @Override + public void setIfModifiedSince(long ifModifiedSince) { + } + + @Override + public long getIfModifiedSince() { + return 0; + } + + @Override + public void setDefaultUseCaches(boolean defaultUseCaches) { + } + + @Override + public boolean getDefaultUseCaches() { + return false; + } + + @Override + public void setRequestProperty(String key, String value) { + } + + @Override + public void addRequestProperty(String key, String value) { + } + + @Override + public String getRequestProperty(String key) { + return null; + } + + @Override + public Map> getRequestProperties() { + return null; + } + + @Override + public String getHeaderField(String name) { + if (name == null) { + return null; + } + Map> headers = getHeaderFields(); + List values = headers.get(name.toLowerCase()); + + return (values != null && !values.isEmpty()) ? values.get(0) : null; + } + + @Override + public Map> getHeaderFields() { + Map> headers = new HashMap<>(); + + // Add synthetic headers based on response data + String contentType = getContentType(); + if (contentType != null) { + headers.put("content-type", Collections.singletonList(contentType)); + } + + long contentLength = getContentLengthLong(); + if (contentLength >= 0) { + headers.put("content-length", Collections.singletonList(String.valueOf(contentLength))); + } + + String contentEncoding = getContentEncoding(); + if (contentEncoding != null) { + headers.put("content-encoding", Collections.singletonList(contentEncoding)); + } + + try { + headers.put("status", Collections.singletonList(getResponseCode() + " " + getResponseMessage())); + } catch (IOException e) { + // Ignore if we can't get response code + } + + return headers; + } + + @Override + public int getHeaderFieldInt(String name, int defaultValue) { + String value = getHeaderField(name); + if (value != null) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + // Fall through to default + } + } + return defaultValue; + } + + @Override + public long getHeaderFieldDate(String name, long defaultValue) { + // We don't have actual date headers + if ("date".equalsIgnoreCase(name)) { + return System.currentTimeMillis(); + } + return defaultValue; + } + + @Override + public String getHeaderFieldKey(int n) { + Map> headers = getHeaderFields(); + if (n >= 0 && n < headers.size()) { + return (String) headers.keySet().toArray()[n]; + } + return null; + } + + @Override + public String getHeaderField(int n) { + String key = getHeaderFieldKey(n); + return key != null ? getHeaderField(key) : null; + } + + @Override + public long getContentLengthLong() { + String data = mResponse.getData(); + if (data == null) { + return 0; + } + return data.getBytes(StandardCharsets.UTF_8).length; + } + + @Override + public String getContentType() { + // Try to detect content type from response data, default to JSON for API responses + String data = mResponse.getData(); + if (data == null || data.trim().isEmpty()) { + return null; + } + String trimmed = data.trim(); + if (trimmed.startsWith("{") || trimmed.startsWith("[")) { + return "application/json; charset=utf-8"; + } + if (trimmed.startsWith("<")) { + return "text/html; charset=utf-8"; + } + return "text/plain; charset=utf-8"; + } + + @Override + public String getContentEncoding() { + return "utf-8"; + } + + @Override + public long getExpiration() { + return 0; + } + + @Override + public long getDate() { + return System.currentTimeMillis(); + } + + @Override + public long getLastModified() { + return 0; + } + + @Override + public URL getURL() { + return mUrl; + } + + @Override + public int getContentLength() { + long length = getContentLengthLong(); + return length > Integer.MAX_VALUE ? -1 : (int) length; + } + + @Override + public Permission getPermission() throws IOException { + return null; + } + + @Override + public OutputStream getOutputStream() throws IOException { + if (!mDoOutput) { + throw new IOException("Output not enabled for this connection. Call setDoOutput(true) first."); + } + return mOutputStream; + } + + @Override + public void setConnectTimeout(int timeout) { + } + + @Override + public int getConnectTimeout() { + return 0; + } + + @Override + public void setReadTimeout(int timeout) { + } + + @Override + public int getReadTimeout() { + return 0; + } + + @Override + public void setHostnameVerifier(HostnameVerifier v) { + } + + @Override + public HostnameVerifier getHostnameVerifier() { + return null; + } + + @Override + public void setSSLSocketFactory(SSLSocketFactory sf) { + } + + @Override + public SSLSocketFactory getSSLSocketFactory() { + return null; + } +} diff --git a/src/main/java/io/split/android/client/network/HttpResponseImpl.java b/src/main/java/io/split/android/client/network/HttpResponseImpl.java index 1e1f05a0d..07c970d46 100644 --- a/src/main/java/io/split/android/client/network/HttpResponseImpl.java +++ b/src/main/java/io/split/android/client/network/HttpResponseImpl.java @@ -1,20 +1,37 @@ package io.split.android.client.network; -public class HttpResponseImpl extends BaseHttpResponseImpl implements HttpResponse { +import java.security.cert.Certificate; + +public class HttpResponseImpl extends BaseHttpResponseImpl implements HttpResponse { private final String mData; + private final Certificate[] mServerCertificates; HttpResponseImpl(int httpStatus) { - this(httpStatus, null); + this(httpStatus, (String) null); + } + + HttpResponseImpl(int httpStatus, Certificate[] serverCertificates) { + this(httpStatus, null, serverCertificates); } public HttpResponseImpl(int httpStatus, String data) { + this(httpStatus, data, null); + } + + public HttpResponseImpl(int httpStatus, String data, Certificate[] serverCertificates) { super(httpStatus); mData = data; + mServerCertificates = serverCertificates; } @Override public String getData() { return mData; } + + @Override + public Certificate[] getServerCertificates() { + return mServerCertificates; + } } diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequest.java b/src/main/java/io/split/android/client/network/HttpStreamRequest.java index 0b11d6844..f0bb28c67 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequest.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequest.java @@ -1,7 +1,9 @@ package io.split.android.client.network; +import java.io.IOException; + public interface HttpStreamRequest { void addHeader(String name, String value); - HttpStreamResponse execute() throws HttpException; + HttpStreamResponse execute() throws HttpException, IOException; void close(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java index 60453013b..3a010c04f 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java @@ -1,11 +1,11 @@ package io.split.android.client.network; import static io.split.android.client.network.HttpRequestHelper.checkPins; +import static io.split.android.client.network.HttpRequestHelper.createConnection; import static io.split.android.client.utils.Utils.checkNotNull; import static io.split.android.client.network.HttpRequestHelper.applySslConfig; import static io.split.android.client.network.HttpRequestHelper.applyTimeouts; -import static io.split.android.client.network.HttpRequestHelper.openConnection; import androidx.annotation.NonNull; import androidx.annotation.Nullable; @@ -18,6 +18,7 @@ import java.net.MalformedURLException; import java.net.ProtocolException; import java.net.Proxy; +import java.net.SocketException; import java.net.URI; import java.net.URL; import java.util.HashMap; @@ -52,6 +53,12 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { @Nullable private final CertificateChecker mCertificateChecker; private final AtomicBoolean mWasRetried = new AtomicBoolean(false); + @Nullable + private final HttpProxy mHttpProxy; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpStreamRequestImpl(@NonNull URI uri, @NonNull Map headers, @@ -61,7 +68,10 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { @Nullable DevelopmentSslConfig developmentSslConfig, @Nullable SSLSocketFactory sslSocketFactory, @NonNull UrlSanitizer urlSanitizer, - @Nullable CertificateChecker certificateChecker) { + @Nullable CertificateChecker certificateChecker, + @Nullable HttpProxy httpProxy, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + @Nullable ProxyCacertConnectionHandler proxyCacertConnectionHandler) { mUri = checkNotNull(uri); mHttpMethod = HttpMethod.GET; mProxy = proxy; @@ -72,10 +82,13 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { mDevelopmentSslConfig = developmentSslConfig; mSslSocketFactory = sslSocketFactory; mCertificateChecker = certificateChecker; + mHttpProxy = httpProxy; + mProxyCredentialsProvider = proxyCredentialsProvider; + mConnectionHandler = proxyCacertConnectionHandler; } @Override - public HttpStreamResponse execute() throws HttpException { + public HttpStreamResponse execute() throws HttpException, IOException { return getRequest(); } @@ -107,14 +120,18 @@ private void closeBufferedReader() { } } - private HttpStreamResponse getRequest() throws HttpException { + private HttpStreamResponse getRequest() throws HttpException, IOException { HttpStreamResponse response; try { - mConnection = setUpConnection(false); - response = buildResponse(mConnection); + if (mConnectionHandler != null && mHttpProxy != null && mSslSocketFactory != null && (mHttpProxy.getCaCertStream() != null || mHttpProxy.getClientCertStream() != null)) { + response = mConnectionHandler.executeStreamRequest(mHttpProxy, getUrl(), mHttpMethod, mHeaders, mSslSocketFactory, mProxyCredentialsProvider); + } else { + mConnection = setUpConnection(false); + response = buildResponse(mConnection); - if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { - response = handleAuthentication(response); + if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { + response = handleAuthentication(response); + } } } catch (MalformedURLException e) { disconnect(); @@ -125,6 +142,11 @@ private HttpStreamResponse getRequest() throws HttpException { } catch (SSLPeerUnverifiedException e) { disconnect(); throw new HttpException("SSL peer not verified: " + e.getLocalizedMessage(), HttpStatus.INTERNAL_NON_RETRYABLE.getCode()); + } catch (SocketException e) { + disconnect(); + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + // This ensures socket closures are treated the same in both direct and proxy flows + throw e; } catch (IOException e) { disconnect(); throw new HttpException("Something happened while retrieving data: " + e.getLocalizedMessage()); @@ -134,12 +156,20 @@ private HttpStreamResponse getRequest() throws HttpException { } private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws IOException { - URL url = mUrlSanitizer.getUrl(mUri); - if (url == null) { - throw new IOException("Error parsing URL"); - } - - HttpURLConnection connection = openConnection(mProxy, mProxyAuthenticator, url, mHttpMethod, mHeaders, useProxyAuthenticator); + URL url = getUrl(); + + HttpURLConnection connection = createConnection( + url, + mProxy, + mHttpProxy, + mProxyAuthenticator, + mHttpMethod, + mHeaders, + useProxyAuthenticator, + mSslSocketFactory, + mProxyCredentialsProvider, + null + ); applyTimeouts(HttpStreamRequestImpl.STREAMING_READ_TIMEOUT_IN_MILLISECONDS, mConnectionTimeout, connection); applySslConfig(mSslSocketFactory, mDevelopmentSslConfig, connection); connection.connect(); @@ -148,6 +178,15 @@ private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws return connection; } + @NonNull + private URL getUrl() throws IOException { + URL url = mUrlSanitizer.getUrl(mUri); + if (url == null) { + throw new IOException("Error parsing URL"); + } + return url; + } + private HttpStreamResponse handleAuthentication(HttpStreamResponse response) throws HttpException { if (!mWasRetried.getAndSet(true)) { try { @@ -171,11 +210,11 @@ private HttpStreamResponse buildResponse(HttpURLConnection connection) throws IO } mBufferedReader = new BufferedReader(new InputStreamReader(inputStream)); - return new HttpStreamResponseImpl(responseCode, mBufferedReader); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, mBufferedReader); } } - return new HttpStreamResponseImpl(responseCode); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, null); } private void disconnect() { diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponse.java b/src/main/java/io/split/android/client/network/HttpStreamResponse.java index e20096041..f733bd3bc 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponse.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponse.java @@ -3,8 +3,9 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.Closeable; -public interface HttpStreamResponse extends BaseHttpResponse { +public interface HttpStreamResponse extends BaseHttpResponse, Closeable { @Nullable BufferedReader getBufferedReader(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java index 175433c44..bf24d0e74 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java @@ -3,18 +3,39 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +import io.split.android.client.utils.logger.Logger; public class HttpStreamResponseImpl extends BaseHttpResponseImpl implements HttpStreamResponse { private final BufferedReader mData; - HttpStreamResponseImpl(int httpStatus) { - this(httpStatus, null); - } + // Sockets are referenced when using Proxy tunneling, in order to close them + @Nullable + private final Socket mTunnelSocket; + @Nullable + private final Socket mOriginSocket; - public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { + private HttpStreamResponseImpl(int httpStatus, BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { super(httpStatus); mData = data; + mTunnelSocket = tunnelSocket; + mOriginSocket = originSocket; + } + + static HttpStreamResponseImpl createFromTunnelSocket(int httpStatus, + BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { + return new HttpStreamResponseImpl(httpStatus, data, tunnelSocket, originSocket); + } + + static HttpStreamResponseImpl createFromHttpUrlConnection(int httpStatus, BufferedReader data) { + return new HttpStreamResponseImpl(httpStatus, data, null, null); } @Override @@ -22,4 +43,37 @@ public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { public BufferedReader getBufferedReader() { return mData; } + + @Override + public void close() throws IOException { + + // Close the BufferedReader first + if (mData != null) { + try { + mData.close(); + } catch (IOException e) { + Logger.w("Failed to close BufferedReader: " + e.getMessage()); + } + } + + // Close origin socket if it exists and is different from tunnel socket + if (mOriginSocket != null && mOriginSocket != mTunnelSocket) { + try { + mOriginSocket.close(); + Logger.v("Origin socket closed"); + } catch (IOException e) { + Logger.w("Failed to close origin socket: " + e.getMessage()); + } + } + + // Close tunnel socket + if (mTunnelSocket != null) { + try { + mTunnelSocket.close(); + Logger.v("Tunnel socket closed"); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } } diff --git a/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java new file mode 100644 index 000000000..65cdcb6e4 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java @@ -0,0 +1,242 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; +import java.net.URL; +import java.security.cert.Certificate; +import java.util.Map; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import io.split.android.client.utils.logger.Logger; + +/** + * Handles PROXY_CACERT SSL proxy connections. + *

+ * This handler establishes SSL tunnels through SSL proxies using custom CA certificates + * for proxy authentication, then executes HTTP requests through the SSL tunnel. + */ +class ProxyCacertConnectionHandler { + + public static final String HTTPS = "https"; + public static final String HTTP = "http"; + public static final int PORT_HTTPS = 443; + public static final int PORT_HTTP = 80; + private final HttpOverTunnelExecutor mTunnelExecutor; + + public ProxyCacertConnectionHandler() { + mTunnelExecutor = new HttpOverTunnelExecutor(); + } + + /** + * Executes an HTTP request through an SSL proxy tunnel. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param method The HTTP method to use + * @param headers The HTTP headers to include + * @param body The request body (if any) + * @param sslSocketFactory The SSL socket factory for proxy and origin connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @return The HTTP response + * @throws IOException if the request fails + */ + @NonNull + HttpResponse executeRequest(@NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + try { + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, false); + + try { + return mTunnelExecutor.executeRequest( + connection.finalSocket, + targetUrl, + method, + headers, + body, + connection.serverCertificates); + } finally { + // Close all sockets for non-streaming requests + closeConnection(connection); + } + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; + } catch (Exception e) { + Logger.e("Failed to execute request through custom tunnel: " + e.getMessage()); + throw new IOException("Failed to execute request through custom tunnel", e); + } + } + + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + try { + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, true); + + // For streaming requests, pass socket references to the response for later cleanup + Socket originSocket = (connection.finalSocket != connection.tunnelSocket) ? connection.finalSocket : null; + return mTunnelExecutor.executeStreamRequest( + connection.finalSocket, + connection.tunnelSocket, + originSocket, + targetUrl, + method, + headers, + connection.serverCertificates); + // For streaming requests, sockets are NOT closed here + // They will be closed when the HttpStreamResponse.close() is called + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; + } catch (Exception e) { + throw new IOException("Failed to execute request through custom tunnel", e); + } + } + + private static int getTargetPort(@NonNull URL targetUrl) { + int port = targetUrl.getPort(); + if (port == -1) { + if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { + return PORT_HTTPS; + } else if (HTTP.equalsIgnoreCase(targetUrl.getProtocol())) { + return PORT_HTTP; + } + } + return port; + } + + /** + * Represents a connection through an SSL tunnel. + */ + private static class TunnelConnection { + final Socket tunnelSocket; + final Socket finalSocket; + final Certificate[] serverCertificates; + + TunnelConnection(Socket tunnelSocket, Socket finalSocket, Certificate[] serverCertificates) { + this.tunnelSocket = tunnelSocket; + this.finalSocket = finalSocket; + this.serverCertificates = serverCertificates; + } + } + + /** + * Establishes a tunnel connection to the target through the proxy. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param sslSocketFactory SSL socket factory for connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this is a streaming connection + * @return A TunnelConnection object containing the established sockets + * @throws IOException if connection establishment fails + */ + private TunnelConnection establishTunnelConnection( + @NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { + + SslProxyTunnelEstablisher tunnelEstablisher = new SslProxyTunnelEstablisher(); + Socket tunnelSocket = null; + Socket finalSocket = null; + Certificate[] serverCertificates = null; + + try { + tunnelSocket = tunnelEstablisher.establishTunnel( + httpProxy.getHost(), + httpProxy.getPort(), + targetUrl.getHost(), + getTargetPort(targetUrl), + sslSocketFactory, + proxyCredentialsProvider, + isStreaming + ); + + finalSocket = tunnelSocket; + + // If the origin is HTTPS, wrap the tunnel socket with a new SSLSocket (system CA) + if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { + try { + // Use the provided SSLSocketFactory, which is configured to trust the origin's CA + finalSocket = sslSocketFactory.createSocket( + tunnelSocket, + targetUrl.getHost(), + getTargetPort(targetUrl), + true // autoClose + ); + if (finalSocket instanceof SSLSocket) { + SSLSocket originSslSocket = (SSLSocket) finalSocket; + originSslSocket.setUseClientMode(true); + originSslSocket.startHandshake(); + + // Capture server certificates after successful handshake + try { + serverCertificates = originSslSocket.getSession().getPeerCertificates(); + } catch (Exception certEx) { + Logger.w("Could not capture origin server certificates: " + certEx.getMessage()); + } + } else { + throw new IOException("Failed to create SSLSocket to origin"); + } + } catch (Exception sslEx) { + Logger.e("Failed to establish SSL connection to origin: " + sslEx.getMessage()); + throw new IOException("Failed to establish SSL connection to origin server", sslEx); + } + } + + return new TunnelConnection(tunnelSocket, finalSocket, serverCertificates); + } catch (Exception e) { + // Clean up resources on error + closeSockets(finalSocket, tunnelSocket); + throw e; + } + } + + private void closeConnection(TunnelConnection connection) { + if (connection == null) { + return; + } + + closeSockets(connection.finalSocket, connection.tunnelSocket); + } + + private void closeSockets(Socket finalSocket, Socket tunnelSocket) { + // If we are tunnelling, finalSocket is the tunnel socket + if (finalSocket != null && finalSocket != tunnelSocket) { + try { + finalSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close origin SSL socket: " + e.getMessage()); + } + } + + if (tunnelSocket != null) { + try { + tunnelSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } +} diff --git a/src/main/java/io/split/android/client/network/ProxyConfiguration.java b/src/main/java/io/split/android/client/network/ProxyConfiguration.java new file mode 100644 index 000000000..e25314992 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyConfiguration.java @@ -0,0 +1,137 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; + +import io.split.android.client.utils.logger.Logger; + +/** + * Proxy configuration + */ +public class ProxyConfiguration { + + private final URI mUrl; + private final ProxyCredentialsProvider mCredentialsProvider; + private final InputStream mClientCert; + private final InputStream mClientPk; + private final InputStream mCaCert; + + ProxyConfiguration(@NonNull URI url, + @Nullable ProxyCredentialsProvider credentialsProvider, + @Nullable InputStream clientCert, + @Nullable InputStream clientPk, + @Nullable InputStream caCert) { + mUrl = url; + mCredentialsProvider = credentialsProvider; + mClientCert = clientCert; + mClientPk = clientPk; + mCaCert = caCert; + } + + public URI getUrl() { + return mUrl; + } + + public ProxyCredentialsProvider getCredentialsProvider() { + return mCredentialsProvider; + } + + public InputStream getClientCert() { + return mClientCert; + } + + public InputStream getClientPk() { + return mClientPk; + } + + public InputStream getCaCert() { + return mCaCert; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private URI mUrl; + private ProxyCredentialsProvider mCredentialsProvider; + private InputStream mClientCert; + private InputStream mClientPk; + private InputStream mCaCert; + + private Builder() { + + } + + /** + * Set the proxy URL + * + * @param url MUST NOT be null + * @return this builder + */ + public Builder url(@NonNull String url) { + try { + mUrl = new URI(url); + } catch (NullPointerException | URISyntaxException e) { + Logger.e("Proxy url was not a valid URL."); + } + return this; + } + + /** + * Set the credentials provider. + *

+ * Can be an implementation of {@link BearerCredentialsProvider} or {@link BasicCredentialsProvider} + * + * @param credentialsProvider A non null credentials provider + * @return this builder + */ + public Builder credentialsProvider(@NonNull ProxyCredentialsProvider credentialsProvider) { + mCredentialsProvider = credentialsProvider; + return this; + } + + /** + * Set the client certificate and private key in PKCS#8 format + * + * @param clientCert The client certificate + * @param clientPk The client private key + * @return this builder + */ + public Builder mtls(@NonNull InputStream clientCert, @NonNull InputStream clientPk) { + mClientCert = clientCert; + mClientPk = clientPk; + return this; + } + + /** + * Set the Proxy CA certificate + * + * @param caCert The CA certificate in PEM or DER format + * @return this builder + */ + public Builder caCert(@NonNull InputStream caCert) { + mCaCert = caCert; + return this; + } + + /** + * Build the proxy configuration. + * This method will return null if the proxy URL is not set. + * + * @return The proxy configuration + */ + @Nullable + public ProxyConfiguration build() { + if (mUrl == null) { + Logger.w("Proxy configuration with no URL. This will prevent SplitFactory from working."); + } + + return new ProxyConfiguration(mUrl, mCredentialsProvider, mClientCert, mClientPk, mCaCert); + } + } +} diff --git a/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java b/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java new file mode 100644 index 000000000..6533883a8 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java @@ -0,0 +1,10 @@ +package io.split.android.client.network; + +/** + * Interface for providing proxy credentials. + *

+ * Implementations can be {@link BasicCredentialsProvider} or {@link BearerCredentialsProvider} + */ +public interface ProxyCredentialsProvider { + +} diff --git a/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java new file mode 100644 index 000000000..10c4d8862 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java @@ -0,0 +1,31 @@ +package io.split.android.client.network; + +import androidx.annotation.Nullable; + +import java.io.InputStream; + +import javax.net.ssl.SSLSocketFactory; + +interface ProxySslSocketFactoryProvider { + + /** + * Create an SSLSocketFactory for proxy connections using a CA certificate from an InputStream. + * The InputStream will be closed after use. + * + * @param caCertInputStream InputStream containing CA certificate (PEM or DER). + * @return SSLSocketFactory configured for the requested scenario + */ + SSLSocketFactory create(@Nullable InputStream caCertInputStream) throws Exception; + + /** + * Create an SSLSocketFactory for proxy connections using CA cert and separate client certificate and key files. + * All InputStreams will be closed after use. + * + * @param caCertInputStream InputStream containing one or more CA certificates (PEM or DER). + * @param clientCertInputStream InputStream containing client certificate (PEM or DER). + * @param clientKeyInputStream InputStream containing client private key (PEM format, PKCS#8). + + * @return SSLSocketFactory configured for mTLS proxy authentication + */ + SSLSocketFactory create(@Nullable InputStream caCertInputStream, @Nullable InputStream clientCertInputStream, @Nullable InputStream clientKeyInputStream) throws Exception; +} diff --git a/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java new file mode 100644 index 000000000..49a84c134 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java @@ -0,0 +1,270 @@ +package io.split.android.client.network; + +import static io.split.android.client.utils.Utils.checkNotNull; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Collection; +import java.util.Enumeration; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +import io.split.android.client.utils.logger.Logger; + +class ProxySslSocketFactoryProviderImpl implements ProxySslSocketFactoryProvider { + + private final Base64Decoder mBase64Decoder; + + ProxySslSocketFactoryProviderImpl() { + this(new DefaultBase64Decoder()); + } + + ProxySslSocketFactoryProviderImpl(@NonNull Base64Decoder base64Decoder) { + mBase64Decoder = checkNotNull(base64Decoder); + } + + @Override + public SSLSocketFactory create(@Nullable InputStream caCertInputStream) throws Exception { + // The TrustManagerFactory is necessary because of the CA cert + return createSslSocketFactory(null, createTrustManagerFactory(caCertInputStream)); + } + + @Override + public SSLSocketFactory create(@Nullable InputStream caCertInputStream, @Nullable InputStream clientCertInputStream, @Nullable InputStream clientKeyInputStream) throws Exception { + // The KeyManagerFactory is necessary because of the client certificate and key files + KeyManagerFactory keyManagerFactory = createKeyManagerFactory(clientCertInputStream, clientKeyInputStream); + + // The TrustManagerFactory is necessary because of the CA cert + TrustManagerFactory trustManagerFactory = createTrustManagerFactory(caCertInputStream); + + return createSslSocketFactory(keyManagerFactory, trustManagerFactory); + } + + @Nullable + private TrustManagerFactory createTrustManagerFactory(@Nullable InputStream caCertInputStream) throws Exception { + if (caCertInputStream == null) { + return null; + } + + try { + // Generate Certificate objects from the InputStream + CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + Collection caCertificates = certificateFactory.generateCertificates(caCertInputStream); + + KeyStore combinedTrustStore = getCombinedStore(caCertificates); + + // Initialize the TrustManagerFactory with the combined trust store + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(combinedTrustStore); + + return trustManagerFactory; + } finally { + caCertInputStream.close(); + } + } + + /** + * Create a KeyStore with both system CAs and user provided CAs + * @param caCertificates User provided CAs + * @return KeyStore + */ + @NonNull + private static KeyStore getCombinedStore(Collection caCertificates) throws NoSuchAlgorithmException, KeyStoreException, CertificateException, IOException { + // Start with the system's default trust store to include standard CA certificates + TrustManagerFactory defaultTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + defaultTrustManagerFactory.init((KeyStore) null); // Initialize with system default keystore + + // Get the default trust store + KeyStore defaultTrustStore = null; + for (TrustManager tm : defaultTrustManagerFactory.getTrustManagers()) { + if (tm instanceof X509TrustManager) { + // Create a new keystore and populate it with system CAs + defaultTrustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + defaultTrustStore.load(null, null); + + X509Certificate[] acceptedIssuers = ((X509TrustManager) tm).getAcceptedIssuers(); + for (int j = 0; j < acceptedIssuers.length; j++) { + defaultTrustStore.setCertificateEntry("systemCA" + j, acceptedIssuers[j]); + } + break; + } + } + + // Create combined trust store with both system CAs and custom proxy CAs + KeyStore combinedTrustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + combinedTrustStore.load(null, null); + + // Add system CA certificates if we found them + if (defaultTrustStore != null) { + Enumeration aliases = defaultTrustStore.aliases(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + Certificate cert = defaultTrustStore.getCertificate(alias); + if (cert != null) { + combinedTrustStore.setCertificateEntry(alias, cert); + } + } + } + + // Add custom proxy CA certificates + int i = 0; + for (Certificate ca : caCertificates) { + combinedTrustStore.setCertificateEntry("proxyCA" + (i++), ca); + } + return combinedTrustStore; + } + + /** + * Creates a KeyManagerFactory from separate client certificate and key files. + * + * @param clientCertInputStream InputStream containing client certificate (PEM or DER) + * @param clientKeyInputStream InputStream containing client private key (PEM format) + * @return KeyManagerFactory initialized with the client certificate and key + * @throws Exception if there is an error loading the certificate or key + */ + private KeyManagerFactory createKeyManagerFactory(@Nullable InputStream clientCertInputStream, + @Nullable InputStream clientKeyInputStream) throws Exception { + if (clientCertInputStream == null || clientKeyInputStream == null) { + return null; + } + + try { + // Get cert and key + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + Certificate cert = cf.generateCertificate(clientCertInputStream); + + PrivateKey privateKey = loadPrivateKeyFromPem(clientKeyInputStream); + + // Initialize a KeyStore and add the cert and key + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); // Initialize empty keystore + keyStore.setKeyEntry("client", privateKey, new char[0], new Certificate[] { cert }); + + // Initialize the KeyManagerFactory with the created KeyStore + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + + return keyManagerFactory; + } finally { + clientCertInputStream.close(); + clientKeyInputStream.close(); + } + } + + /** + * Loads a private key from a PEM-encoded input stream. + * Only supports PKCS#8 format (BEGIN PRIVATE KEY). + * + * @param keyInputStream InputStream containing the PEM-encoded private key + * @return PrivateKey object + * @throws Exception if there is an error loading the key + */ + private PrivateKey loadPrivateKeyFromPem(InputStream keyInputStream) throws Exception { + try { + // Read the key file + String keyContent = readInputStream(keyInputStream); + if (keyContent.contains("BEGIN PRIVATE KEY")) { + // PKCS#8 format - can be loaded directly + return loadPkcs8PrivateKey(keyContent); + } else { + throw new IllegalArgumentException("Unsupported private key format. Must be PEM encoded PKCS#8 format (BEGIN PRIVATE KEY)"); + } + } catch (Exception e) { + Logger.e("Error loading private key: " + e.getMessage()); + throw e; + } + } + + /** + * Loads a PKCS#8 format private key. + */ + private PrivateKey loadPkcs8PrivateKey(String keyContent) throws Exception { + // Extract the base64 encoded private key using proper PEM parsing + String privateKeyPEM = extractPemContent(keyContent); + + // Decode the Base64 encoded private key + byte[] encoded = mBase64Decoder.decode(privateKeyPEM); + + // Create a PKCS8 key spec and generate the private key + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(encoded); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + try { + return keyFactory.generatePrivate(keySpec); + } catch (InvalidKeySpecException e) { + // Try with EC algorithm if RSA fails + try { + keyFactory = KeyFactory.getInstance("EC"); + return keyFactory.generatePrivate(keySpec); + } catch (InvalidKeySpecException ecException) { + Logger.e("Error loading private key: Neither RSA nor EC algorithms could load the key"); + throw new IllegalArgumentException("Invalid PKCS#8 private key format. Key could not be loaded with RSA or EC algorithms.", e); + } + } + } + + private String readInputStream(InputStream inputStream) throws IOException { + StringBuilder stringBuilder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + stringBuilder.append(line).append("\n"); + } + } + return stringBuilder.toString(); + } + + /** + * Extracts the base64 content from a PKCS#8 String. + * + * @param pemContent The full PEM content + * @return The base64 encoded content without PEM boundaries or whitespace + * @throws IllegalArgumentException if PEM boundaries are not found or malformed + */ + private String extractPemContent(String pemContent) throws IllegalArgumentException { + String beginMarker = "-----BEGIN PRIVATE KEY-----"; + int beginIndex = pemContent.indexOf(beginMarker); + if (beginIndex == -1) { + throw new IllegalArgumentException("PEM begin marker not found: " + beginMarker); + } + String endMarker = "-----END PRIVATE KEY-----"; + int endIndex = pemContent.indexOf(endMarker, beginIndex + beginMarker.length()); + if (endIndex == -1) { + throw new IllegalArgumentException("PEM end marker not found: " + endMarker); + } + String base64Content = pemContent.substring(beginIndex + beginMarker.length(), endIndex); + return base64Content.replaceAll("\\s+", ""); + } + + private SSLSocketFactory createSslSocketFactory(@Nullable KeyManagerFactory keyManagerFactory, @Nullable TrustManagerFactory trustManagerFactory) throws Exception { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyManager[] keyManagers = keyManagerFactory != null ? keyManagerFactory.getKeyManagers() : null; + TrustManager[] trustManagers = trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null; + + sslContext.init(keyManagers, trustManagers, null); + + return sslContext.getSocketFactory(); + } +} diff --git a/src/main/java/io/split/android/client/network/RawHttpResponseParser.java b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java new file mode 100644 index 000000000..5968419cc --- /dev/null +++ b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java @@ -0,0 +1,299 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.Socket; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.split.android.client.utils.logger.Logger; + +/** + * Parses raw HTTP protocol responses from socket input streams. + * Handles the HTTP protocol parsing (status line, headers, body) from socket streams. + */ +class RawHttpResponseParser { + + /** + * Parses a raw HTTP response from an input stream. + * + * @param inputStream The input stream containing the raw HTTP response + * @param serverCertificates The server certificates to include in the response + * @return HttpResponse containing the parsed status code, headers, and response data + * @throws IOException if parsing fails or the response is malformed + */ + @NonNull + HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certificate[] serverCertificates) throws IOException { + // 1. Read and parse status line + String statusLine = readLineFromStream(inputStream); + if (statusLine == null) { + throw new IOException("No HTTP response received from server"); + } + + int statusCode = parseStatusCode(statusLine); + + // 2. Read and parse response headers directly + ParsedResponseHeaders responseHeaders = parseHeadersDirectly(inputStream); + + // 3. Determine charset from Content-Type header + Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); + + // 4. Read response body using the same InputStream + String responseBody = readResponseBody(inputStream, responseHeaders.mIsChunked, bodyCharset, responseHeaders.mContentLength, responseHeaders.mConnectionClose); + + // 5. Create and return HttpResponse + if (responseBody != null && !responseBody.trim().isEmpty()) { + return new HttpResponseImpl(statusCode, responseBody, serverCertificates); + } else { + return new HttpResponseImpl(statusCode, serverCertificates); + } + } + + @NonNull + HttpStreamResponse parseHttpStreamResponse(@NonNull InputStream inputStream, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) throws IOException { + // 1. Read and parse status line + String statusLine = readLineFromStream(inputStream); + if (statusLine == null) { + throw new IOException("No HTTP response received from server"); + } + + int statusCode = parseStatusCode(statusLine); + + // 2. Read and parse response headers directly + ParsedResponseHeaders responseHeaders = parseHeadersDirectly(inputStream); + + // 3. Determine charset from Content-Type header + Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); + + return HttpStreamResponseImpl.createFromTunnelSocket(statusCode, + new BufferedReader(new InputStreamReader(inputStream, bodyCharset)), + tunnelSocket, + originSocket); + } + + @NonNull + private ParsedResponseHeaders parseHeadersDirectly(@NonNull InputStream inputStream) throws IOException { + int contentLength = -1; + boolean isChunked = false; + boolean connectionClose = false; + String contentType = null; + String headerLine; + + while ((headerLine = readLineFromStream(inputStream)) != null && !headerLine.trim().isEmpty()) { + int colonIndex = headerLine.indexOf(':'); + if (colonIndex > 0) { + String headerName = headerLine.substring(0, colonIndex).trim(); + String headerValue = headerLine.substring(colonIndex + 1).trim(); + + String lowerHeaderName = headerName.toLowerCase(Locale.US); + if ("content-length".equals(lowerHeaderName)) { + try { + contentLength = Integer.parseInt(headerValue); + } catch (NumberFormatException e) { + Logger.w("Invalid Content-Length header: " + headerLine); + } + } else if ("transfer-encoding".equals(lowerHeaderName) && headerValue.toLowerCase(Locale.US).contains("chunked")) { + isChunked = true; + } else if ("connection".equals(lowerHeaderName) && headerValue.toLowerCase(Locale.US).contains("close")) { + connectionClose = true; + } else if ("content-type".equals(lowerHeaderName)) { + contentType = headerValue; + } + } + } + return new ParsedResponseHeaders(contentLength, isChunked, connectionClose, contentType); + } + + @Nullable + private String readResponseBody(@NonNull InputStream inputStream, boolean isChunked, Charset bodyCharset, int contentLength, boolean connectionClose) throws IOException { + String responseBody = null; + if (isChunked) { + responseBody = readChunkedBodyWithCharset(inputStream, bodyCharset); + } else if (contentLength > 0) { + responseBody = readFixedLengthBodyWithCharset(inputStream, contentLength, bodyCharset); + } else if (connectionClose) { + responseBody = readUntilCloseWithCharset(inputStream, bodyCharset); + } + return responseBody; + } + + /** + * Parses the HTTP status code from the status line. + */ + private int parseStatusCode(@NonNull String statusLine) throws IOException { + // Status line format: "HTTP/1.1 200 OK" or "HTTP/1.0 404 Not Found" + String[] parts = statusLine.split(" "); + if (parts.length < 2) { + throw new IOException("Invalid HTTP status line: " + statusLine); + } + + try { + return Integer.parseInt(parts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid HTTP status code in line: " + statusLine, e); + } + } + + /** + * Extracts charset from Content-Type header, defaulting to UTF-8. + */ + private Charset extractCharsetFromContentType(String contentType) { + if (contentType == null) { + return StandardCharsets.UTF_8; + } + + // Pattern to match charset=value in Content-Type header + Pattern charsetPattern = Pattern.compile("charset\\s*=\\s*([^\\s;]+)", Pattern.CASE_INSENSITIVE); + Matcher matcher = charsetPattern.matcher(contentType); + + if (matcher.find()) { + String charsetName = matcher.group(1).replaceAll("[\"']", ""); // Remove quotes + try { + return Charset.forName(charsetName); + } catch (Exception e) { + Logger.w("Unsupported charset: " + charsetName + ", using UTF-8"); + } + } + + return StandardCharsets.UTF_8; + } + + private String readChunkedBodyWithCharset(InputStream inputStream, Charset charset) throws IOException { + ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream(); + + while (true) { + // Read chunk size line + String chunkSizeLine = readLineFromStream(inputStream); + if (chunkSizeLine == null) { + throw new IOException("Unexpected EOF while reading chunk size"); + } + + // Parse chunk size (ignore extensions after semicolon) + int semicolonIndex = chunkSizeLine.indexOf(';'); + String sizeStr = semicolonIndex >= 0 ? chunkSizeLine.substring(0, semicolonIndex).trim() : chunkSizeLine.trim(); + + int chunkSize; + try { + chunkSize = Integer.parseInt(sizeStr, 16); + } catch (NumberFormatException e) { + throw new IOException("Invalid chunk size: " + chunkSizeLine, e); + } + + if (chunkSize < 0) { + throw new IOException("Negative chunk size: " + chunkSize); + } + + // If chunk size is 0, we've reached the end + if (chunkSize == 0) { + // Read trailing headers until empty line + String trailerLine; + while ((trailerLine = readLineFromStream(inputStream)) != null && !trailerLine.trim().isEmpty()) { + // no-op + } + break; + } + + // Read chunk data (exact byte count) + byte[] chunkData = new byte[chunkSize]; + int totalRead = 0; + while (totalRead < chunkSize) { + int read = inputStream.read(chunkData, totalRead, chunkSize - totalRead); + if (read == -1) { + throw new IOException("Unexpected EOF while reading chunk data"); + } + totalRead += read; + } + + bodyBytes.write(chunkData); + + // Read trailing CRLF after chunk data + int c1 = inputStream.read(); + int c2 = inputStream.read(); + if (c1 != '\r' || c2 != '\n') { + throw new IOException("Expected CRLF after chunk data, got: " + (char) c1 + (char) c2); + } + } + + return new String(bodyBytes.toByteArray(), charset); + } + + private String readFixedLengthBodyWithCharset(InputStream inputStream, int contentLength, Charset charset) throws IOException { + byte[] bodyBytes = new byte[contentLength]; + int totalRead = 0; + + while (totalRead < contentLength) { + int read = inputStream.read(bodyBytes, totalRead, contentLength - totalRead); + if (read == -1) { + throw new IOException("Unexpected EOF while reading fixed-length body"); + } + totalRead += read; + } + + return new String(bodyBytes, charset); + } + + private String readUntilCloseWithCharset(InputStream inputStream, Charset charset) throws IOException { + ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream(); + byte[] buffer = new byte[8192]; + int bytesRead; + + while ((bytesRead = inputStream.read(buffer)) != -1) { + bodyBytes.write(buffer, 0, bytesRead); + } + + return new String(bodyBytes.toByteArray(), charset); + } + + private String readLineFromStream(InputStream inputStream) throws IOException { + ByteArrayOutputStream lineBytes = new ByteArrayOutputStream(); + int b; + boolean foundCR = false; + + while ((b = inputStream.read()) != -1) { + if (b == '\r') { + foundCR = true; + } else if (b == '\n' && foundCR) { + break; + } else if (foundCR) { + // CR not followed by LF, add the CR to output + lineBytes.write('\r'); + lineBytes.write(b); + foundCR = false; + } else { + lineBytes.write(b); + } + } + + if (b == -1 && lineBytes.size() == 0) { + return null; // EOF + } + + return new String(lineBytes.toByteArray(), StandardCharsets.UTF_8); + } + + private static class ParsedResponseHeaders { + final int mContentLength; + final boolean mIsChunked; + final boolean mConnectionClose; + final String mContentType; + + ParsedResponseHeaders(int contentLength, boolean isChunked, boolean connectionClose, String contentType) { + mContentLength = contentLength; + mIsChunked = isChunked; + mConnectionClose = connectionClose; + mContentType = contentType; + } + } +} diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java new file mode 100644 index 000000000..d7940e1b5 --- /dev/null +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -0,0 +1,216 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.HttpURLConnection; +import java.net.Socket; +import java.nio.charset.StandardCharsets; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * Establishes SSL tunnels to SSL proxies using CONNECT protocol. + */ +class SslProxyTunnelEstablisher { + + private static final String CRLF = "\r\n"; + private static final String PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization"; + private final Base64Encoder mBase64Encoder; + + // Default timeout for regular connections (10 seconds) + private static final int DEFAULT_SOCKET_TIMEOUT = 20000; + + SslProxyTunnelEstablisher() { + this(new DefaultBase64Encoder()); + } + + @VisibleForTesting + SslProxyTunnelEstablisher(Base64Encoder base64Encoder) { + mBase64Encoder = base64Encoder; + } + + /** + * Establishes an SSL tunnel through the proxy using the CONNECT method. + * After successful tunnel establishment, extracts the underlying socket + * for use with origin server SSL connections. + * + * @param proxyHost The proxy server hostname + * @param proxyPort The proxy server port + * @param targetHost The target server hostname + * @param targetPort The target server port + * @param sslSocketFactory SSL socket factory for proxy authentication + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this connection is for streaming (uses longer timeout) + * @return Raw socket with tunnel established (connection maintained) + * @throws IOException if tunnel establishment fails + */ + @NonNull + Socket establishTunnel(@NonNull String proxyHost, + int proxyPort, + @NonNull String targetHost, + int targetPort, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { + + Socket rawSocket = null; + SSLSocket sslSocket = null; + + try { + int timeout = DEFAULT_SOCKET_TIMEOUT; + // Step 1: Create raw TCP connection to proxy + rawSocket = new Socket(proxyHost, proxyPort); + rawSocket.setSoTimeout(timeout); + + // Create a temporary SSL socket to establish the SSL session with proper trust validation + sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, true); + sslSocket.setUseClientMode(true); + if (isStreaming) { + sslSocket.setSoTimeout(0); // no timeout for streaming + } else { + sslSocket.setSoTimeout(timeout); + } + + // Perform SSL handshake using the SSL socket with custom CA certificates + sslSocket.startHandshake(); + + // Validate the proxy hostname + HostnameVerifier verifier = HttpsURLConnection.getDefaultHostnameVerifier(); + if (!verifier.verify(proxyHost, sslSocket.getSession())) { + throw new SSLHandshakeException("Proxy hostname verification failed"); + } + + // Step 3: Send CONNECT request through SSL connection + sendConnectRequest(sslSocket, targetHost, targetPort, proxyCredentialsProvider); + + // Step 4: Validate CONNECT response through SSL connection + validateConnectResponse(sslSocket); + + // Step 5: Return SSL socket for tunnel communication + return sslSocket; + + } catch (Exception e) { + // Clean up resources on error + if (sslSocket != null) { + try { + sslSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } else if (rawSocket != null) { + try { + rawSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } + + if (e instanceof HttpRetryException) { + throw (HttpRetryException) e; + } else if (e instanceof IOException) { + throw (IOException) e; + } else { + throw new IOException("Failed to establish SSL tunnel", e); + } + } + } + + /** + * Sends CONNECT request through SSL connection to proxy. + */ + private void sendConnectRequest(@NonNull SSLSocket sslSocket, + @NonNull String targetHost, + int targetPort, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + PrintWriter writer = new PrintWriter(new OutputStreamWriter(sslSocket.getOutputStream(), StandardCharsets.UTF_8), false); + writer.write("CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1" + CRLF); + writer.write("Host: " + targetHost + ":" + targetPort + CRLF); + + if (proxyCredentialsProvider != null) { + addProxyAuthHeader(proxyCredentialsProvider, writer); + } + + // Send empty line to end headers + writer.write(CRLF); + writer.flush(); + } + + private void addProxyAuthHeader(@NonNull ProxyCredentialsProvider proxyCredentialsProvider, PrintWriter writer) { + if (proxyCredentialsProvider instanceof BearerCredentialsProvider) { + // Send Proxy-Authorization header if credentials are set + String bearerToken = ((BearerCredentialsProvider) proxyCredentialsProvider).getToken(); + if (bearerToken != null && !bearerToken.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); + } + } else if (proxyCredentialsProvider instanceof BasicCredentialsProvider) { + BasicCredentialsProvider basicCredentialsProvider = (BasicCredentialsProvider) proxyCredentialsProvider; + String userName = basicCredentialsProvider.getUsername(); + String password = basicCredentialsProvider.getPassword(); + if (userName != null && !userName.trim().isEmpty() && password != null && !password.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Basic " + mBase64Encoder.encode(userName + ":" + password) + CRLF); + } + } + } + + /** + * Validates CONNECT response through SSL connection. + * Only reads status line and headers, leaving the stream open for tunneling. + */ + private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOException { + + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(sslSocket.getInputStream(), StandardCharsets.UTF_8)); + + String statusLine = reader.readLine(); + if (statusLine == null) { + throw new IOException("No CONNECT response received from proxy"); + } + + // Parse status code + String[] statusParts = statusLine.split(" "); + if (statusParts.length < 2) { + throw new IOException("Invalid CONNECT response status line: " + statusLine); + } + + int statusCode; + try { + statusCode = Integer.parseInt(statusParts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid CONNECT response status code: " + statusLine, e); + } + + // Read headers until empty line (but don't process them for CONNECT) + String headerLine; + while ((headerLine = reader.readLine()) != null && !headerLine.trim().isEmpty()) { + // no-op + } + + // Check status code + if (statusCode != 200) { + if (statusCode == HttpURLConnection.HTTP_PROXY_AUTH) { + throw new HttpRetryException("CONNECT request failed with status " + statusCode + ": " + statusLine, HttpURLConnection.HTTP_PROXY_AUTH); + } + throw new IOException("CONNECT request failed with status " + statusCode + ": " + statusLine); + } + } catch (IOException e) { + if (e instanceof HttpRetryException) { + throw e; + } + + throw new IOException("Failed to validate CONNECT response from proxy: " + e.getMessage(), e); + } + } +} diff --git a/src/main/java/io/split/android/client/service/ServiceConstants.java b/src/main/java/io/split/android/client/service/ServiceConstants.java index ecfa4ad05..95c4886fd 100644 --- a/src/main/java/io/split/android/client/service/ServiceConstants.java +++ b/src/main/java/io/split/android/client/service/ServiceConstants.java @@ -35,6 +35,7 @@ public class ServiceConstants { public static final String WORKER_PARAM_CONFIGURED_FILTER_TYPE = "configuredFilterType"; public static final String WORKER_PARAM_FLAGS_SPEC = "flagsSpec"; public static final String WORKER_PARAM_CERTIFICATE_PINS = "certificatePins"; + public static final String WORKER_PARAM_USES_PROXY = "usesProxy"; public static final int LAST_SEEN_IMPRESSION_CACHE_SIZE = 2000; public static final int MY_SEGMENT_V2_DATA_SIZE = 1024 * 10;// bytes diff --git a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java index 9e8f11f21..78a8f316b 100644 --- a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java +++ b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java @@ -36,6 +36,7 @@ public class SseClientImpl implements SseClient { private final StringHelper mStringHelper; private HttpStreamRequest mHttpStreamRequest = null; + private HttpStreamResponse mHttpStreamResponse = null; private static final String PUSH_NOTIFICATION_CHANNELS_PARAM = "channel"; private static final String PUSH_NOTIFICATION_TOKEN_PARAM = "accessToken"; @@ -71,8 +72,21 @@ public void disconnect() { private void close() { Logger.d("Disconnecting SSE client"); if (mStatus.getAndSet(DISCONNECTED) != DISCONNECTED) { + // Close the HttpStreamResponse first to clean up sockets + if (mHttpStreamResponse != null) { + try { + mHttpStreamResponse.close(); + Logger.v("HttpStreamResponse closed successfully"); + } catch (IOException e) { + Logger.w("Failed to close HttpStreamResponse: " + e.getMessage()); + } + mHttpStreamResponse = null; + } + + // Close the HttpStreamRequest if (mHttpStreamRequest != null) { mHttpStreamRequest.close(); + mHttpStreamRequest = null; } Logger.d("SSE client disconnected"); } @@ -94,9 +108,9 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { .addParameter(PUSH_NOTIFICATION_TOKEN_PARAM, rawToken) .build(); mHttpStreamRequest = mHttpClient.streamRequest(url); - HttpStreamResponse response = mHttpStreamRequest.execute(); - if (response.isSuccess()) { - bufferedReader = response.getBufferedReader(); + mHttpStreamResponse = mHttpStreamRequest.execute(); + if (mHttpStreamResponse.isSuccess()) { + bufferedReader = mHttpStreamResponse.getBufferedReader(); if (bufferedReader != null) { Logger.d("Streaming connection opened"); mStatus.set(CONNECTED); @@ -126,8 +140,8 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { throw (new IOException("Buffer is null")); } } else { - Logger.e("Streaming connection error. Http return code " + response.getHttpStatus()); - isErrorRetryable = !response.isClientRelatedError(); + Logger.e("Streaming connection error. Http return code " + mHttpStreamResponse.getHttpStatus()); + isErrorRetryable = !mHttpStreamResponse.isClientRelatedError(); } } catch (URISyntaxException e) { logError("An error has occurred while creating stream Url ", e); @@ -149,8 +163,7 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { } } - private void logError(String message, Exception e) { + private static void logError(String message, Exception e) { Logger.e(message + " : " + e.getLocalizedMessage()); } - } diff --git a/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java b/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java index c65042116..ca3761bca 100644 --- a/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java +++ b/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java @@ -147,6 +147,7 @@ private Data buildInputData(Data customData) { dataBuilder.putString(ServiceConstants.WORKER_PARAM_DATABASE_NAME, mDatabaseName); dataBuilder.putString(ServiceConstants.WORKER_PARAM_API_KEY, mApiKey); dataBuilder.putBoolean(ServiceConstants.WORKER_PARAM_ENCRYPTION_ENABLED, mSplitClientConfig.encryptionEnabled()); + dataBuilder.putBoolean(ServiceConstants.WORKER_PARAM_USES_PROXY, mSplitClientConfig.proxy() != null); if (mSplitClientConfig.certificatePinningConfiguration() != null) { try { Map> pins = mSplitClientConfig.certificatePinningConfiguration().getPins(); diff --git a/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java b/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java new file mode 100644 index 000000000..ed4a2cdb2 --- /dev/null +++ b/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java @@ -0,0 +1,133 @@ +package io.split.android.client.service.workmanager; + +import androidx.annotation.Nullable; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.android_client.BuildConfig; +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.BearerCredentialsProvider; +import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.CertificatePinningConfigurationProvider; +import io.split.android.client.network.HttpClient; +import io.split.android.client.network.HttpClientImpl; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.SplitHttpHeadersBuilder; +import io.split.android.client.storage.cipher.SplitCipherFactory; +import io.split.android.client.storage.db.SplitRoomDatabase; +import io.split.android.client.storage.db.StorageFactory; +import io.split.android.client.storage.general.GeneralInfoStorage; +import io.split.android.client.utils.HttpProxySerializer; + +class HttpClientProvider { + + public static HttpClient buildHttpClient(String apiKey, String certPinningConfig, boolean usesProxy, SplitRoomDatabase mDatabase) { + return buildHttpClient(apiKey, buildCertPinningConfig(certPinningConfig), buildProxyConfig(usesProxy, mDatabase, apiKey)); + } + + private static HttpClient buildHttpClient(String apiKey, @Nullable CertificatePinningConfiguration certificatePinningConfiguration, HttpProxy proxyConfiguration) { + HttpClientImpl.Builder builder = new HttpClientImpl.Builder(); + + if (certificatePinningConfiguration != null) { + builder.setCertificatePinningConfiguration(certificatePinningConfiguration); + } + + if (proxyConfiguration != null) { + builder.setProxy(proxyConfiguration); + } + + HttpClient httpClient = builder + .build(); + + SplitHttpHeadersBuilder headersBuilder = new SplitHttpHeadersBuilder(); + headersBuilder.setClientVersion(BuildConfig.SPLIT_VERSION_NAME); + headersBuilder.setApiToken(apiKey); + headersBuilder.addJsonTypeHeaders(); + httpClient.addHeaders(headersBuilder.build()); + + return httpClient; + } + + @Nullable + private static CertificatePinningConfiguration buildCertPinningConfig(@Nullable String pinsJson) { + if (pinsJson == null || pinsJson.trim().isEmpty()) { + return null; + } + + return CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(pinsJson); + } + + private static HttpProxy buildProxyConfig(boolean usesProxy, SplitRoomDatabase database, String apiKey) { + if (!usesProxy) { + return null; + } + + GeneralInfoStorage storage = StorageFactory.getGeneralInfoStorage(database, SplitCipherFactory.create(apiKey, true)); + HttpProxyDto proxyConfigDto = HttpProxySerializer.deserialize(storage); + + if (proxyConfigDto == null) { + return null; + } + + if (proxyConfigDto.host == null) { + return null; + } + + HttpProxy.Builder builder = HttpProxy.newBuilder(proxyConfigDto.host, proxyConfigDto.port); + + addCredentialsProvider(proxyConfigDto, builder); + addMtls(proxyConfigDto, builder); + addCaCert(proxyConfigDto, builder); + + return builder.build(); + } + + private static void addCaCert(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.caCert != null) { + InputStream caCertStream = stringToInputStream(proxyConfigDto.caCert); + builder.proxyCacert(caCertStream); + } + } + + private static void addMtls(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.clientCert != null && proxyConfigDto.clientKey != null) { + InputStream clientCertStream = stringToInputStream(proxyConfigDto.clientCert); + InputStream clientKeyStream = stringToInputStream(proxyConfigDto.clientKey); + builder.mtls(clientCertStream, clientKeyStream); + } + } + + private static void addCredentialsProvider(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.username != null && proxyConfigDto.password != null) { + builder.credentialsProvider(new BasicCredentialsProvider() { + @Override + public String getUsername() { + return proxyConfigDto.username; + } + + @Override + public String getPassword() { + return proxyConfigDto.password; + } + }); + } else if (proxyConfigDto.bearerToken != null) { + builder.credentialsProvider(new BearerCredentialsProvider() { + + @Override + public String getToken() { + return proxyConfigDto.bearerToken; + } + }); + } + } + + private static InputStream stringToInputStream(String input) { + if (input == null) { + return null; + } + return new ByteArrayInputStream(input.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java b/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java index b24b231ce..812b8963f 100644 --- a/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java +++ b/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java @@ -3,17 +3,11 @@ import android.content.Context; import androidx.annotation.NonNull; -import androidx.annotation.Nullable; import androidx.work.Data; import androidx.work.Worker; import androidx.work.WorkerParameters; -import io.split.android.android_client.BuildConfig; -import io.split.android.client.network.CertificatePinningConfiguration; -import io.split.android.client.network.CertificatePinningConfigurationProvider; import io.split.android.client.network.HttpClient; -import io.split.android.client.network.HttpClientImpl; -import io.split.android.client.network.SplitHttpHeadersBuilder; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.executor.SplitTask; import io.split.android.client.storage.db.SplitRoomDatabase; @@ -35,7 +29,9 @@ public SplitWorker(@NonNull Context context, String apiKey = inputData.getString(ServiceConstants.WORKER_PARAM_API_KEY); mEndpoint = inputData.getString(ServiceConstants.WORKER_PARAM_ENDPOINT); mDatabase = SplitRoomDatabase.getDatabase(context, databaseName); - mHttpClient = buildHttpClient(apiKey, buildCertPinningConfig(inputData.getString(ServiceConstants.WORKER_PARAM_CERTIFICATE_PINS))); + mHttpClient = HttpClientProvider.buildHttpClient(apiKey, + inputData.getString(ServiceConstants.WORKER_PARAM_CERTIFICATE_PINS), + inputData.getBoolean(ServiceConstants.WORKER_PARAM_USES_PROXY, false), mDatabase); } @NonNull @@ -60,32 +56,4 @@ public HttpClient getHttpClient() { public String getEndPoint() { return mEndpoint; } - - private static HttpClient buildHttpClient(String apiKey, @Nullable CertificatePinningConfiguration certificatePinningConfiguration) { - HttpClientImpl.Builder builder = new HttpClientImpl.Builder(); - - if (certificatePinningConfiguration != null) { - builder.setCertificatePinningConfiguration(certificatePinningConfiguration); - } - - HttpClient httpClient = builder - .build(); - - SplitHttpHeadersBuilder headersBuilder = new SplitHttpHeadersBuilder(); - headersBuilder.setClientVersion(BuildConfig.SPLIT_VERSION_NAME); - headersBuilder.setApiToken(apiKey); - headersBuilder.addJsonTypeHeaders(); - httpClient.addHeaders(headersBuilder.build()); - - return httpClient; - } - - @Nullable - private static CertificatePinningConfiguration buildCertPinningConfig(@Nullable String pinsJson) { - if (pinsJson == null || pinsJson.trim().isEmpty()) { - return null; - } - - return CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(pinsJson); - } } diff --git a/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java b/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java index d60e81967..b9eb7b5d3 100644 --- a/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java +++ b/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java @@ -14,11 +14,18 @@ class StorageProvider { private final SplitRoomDatabase mDatabase; private final boolean mShouldRecordTelemetry; private final SplitCipher mCipher; + // some values in general info storage require encryption always + private final SplitCipher mAlwaysEncryptedCipher; StorageProvider(SplitRoomDatabase database, String apiKey, boolean encryptionEnabled, boolean shouldRecordTelemetry) { mDatabase = database; mCipher = SplitCipherFactory.create(apiKey, encryptionEnabled); mShouldRecordTelemetry = shouldRecordTelemetry; + if (encryptionEnabled) { + mAlwaysEncryptedCipher = mCipher; + } else { + mAlwaysEncryptedCipher = SplitCipherFactory.create(apiKey, true); + } } SplitsStorage provideSplitsStorage() { @@ -40,6 +47,6 @@ RuleBasedSegmentStorageProducer provideRuleBasedSegmentStorage() { } GeneralInfoStorage provideGeneralInfoStorage() { - return StorageFactory.getGeneralInfoStorage(mDatabase); + return StorageFactory.getGeneralInfoStorage(mDatabase, mAlwaysEncryptedCipher); } } diff --git a/src/main/java/io/split/android/client/storage/db/StorageFactory.java b/src/main/java/io/split/android/client/storage/db/StorageFactory.java index 6dc6e4c4d..1d591e551 100644 --- a/src/main/java/io/split/android/client/storage/db/StorageFactory.java +++ b/src/main/java/io/split/android/client/storage/db/StorageFactory.java @@ -153,8 +153,8 @@ public static PersistentImpressionsObserverCacheStorage getImpressionsObserverCa return new SqlitePersistentImpressionsObserverCacheStorage(splitRoomDatabase.impressionsObserverCacheDao(), expirationPeriod, executorService); } - public static GeneralInfoStorage getGeneralInfoStorage(SplitRoomDatabase splitRoomDatabase) { - return new GeneralInfoStorageImpl(splitRoomDatabase.generalInfoDao()); + public static GeneralInfoStorage getGeneralInfoStorage(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher) { + return new GeneralInfoStorageImpl(splitRoomDatabase.generalInfoDao(), splitCipher); } public static PersistentRuleBasedSegmentStorage getPersistentRuleBasedSegmentStorage(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher, GeneralInfoStorage generalInfoStorage) { @@ -163,7 +163,7 @@ public static PersistentRuleBasedSegmentStorage getPersistentRuleBasedSegmentSto public static RuleBasedSegmentStorageProducer getRuleBasedSegmentStorageForWorker(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher) { PersistentRuleBasedSegmentStorage persistentRuleBasedSegmentStorage = - new SqLitePersistentRuleBasedSegmentStorageProvider(splitCipher, splitRoomDatabase, getGeneralInfoStorage(splitRoomDatabase)).get(); + new SqLitePersistentRuleBasedSegmentStorageProvider(splitCipher, splitRoomDatabase, getGeneralInfoStorage(splitRoomDatabase, null)).get(); return new RuleBasedSegmentStorageProducerImpl(persistentRuleBasedSegmentStorage, new ConcurrentHashMap<>(), new AtomicLong(-1)); } } diff --git a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java index 87a6a55ec..b3ad1215c 100644 --- a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java +++ b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java @@ -38,4 +38,9 @@ public interface GeneralInfoStorage { void setLastProxyUpdateTimestamp(long timestamp); long getLastProxyUpdateTimestamp(); + + @Nullable + String getProxyConfig(); + + void setProxyConfig(@Nullable String proxyConfig); } diff --git a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java index c351d9a48..304729d57 100644 --- a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java +++ b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java @@ -5,6 +5,7 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import io.split.android.client.storage.cipher.SplitCipher; import io.split.android.client.storage.db.GeneralInfoDao; import io.split.android.client.storage.db.GeneralInfoEntity; @@ -13,11 +14,14 @@ public class GeneralInfoStorageImpl implements GeneralInfoStorage { private static final String ROLLOUT_CACHE_LAST_CLEAR_TIMESTAMP = "rolloutCacheLastClearTimestamp"; private static final String RBS_CHANGE_NUMBER = "rbsChangeNumber"; private static final String LAST_PROXY_CHECK_TIMESTAMP = "lastProxyCheckTimestamp"; + private static final String PROXY_CONFIG = "proxyConfig"; private final GeneralInfoDao mGeneralInfoDao; + private final SplitCipher mAlwaysEncryptedSplitCipher; - public GeneralInfoStorageImpl(GeneralInfoDao generalInfoDao) { + public GeneralInfoStorageImpl(GeneralInfoDao generalInfoDao, @Nullable SplitCipher splitCipher) { mGeneralInfoDao = checkNotNull(generalInfoDao); + mAlwaysEncryptedSplitCipher = splitCipher; } @Override @@ -109,4 +113,27 @@ public long getLastProxyUpdateTimestamp() { GeneralInfoEntity entity = mGeneralInfoDao.getByName(LAST_PROXY_CHECK_TIMESTAMP); return entity != null ? entity.getLongValue() : 0L; } + + @Nullable + @Override + public String getProxyConfig() { + GeneralInfoEntity entity = mGeneralInfoDao.getByName(PROXY_CONFIG); + if (entity == null) { + return null; + } + + if (mAlwaysEncryptedSplitCipher != null) { + return mAlwaysEncryptedSplitCipher.decrypt(entity.getStringValue()); + } + + return entity.getStringValue(); + } + + @Override + public void setProxyConfig(@Nullable String proxyConfig) { + if (mAlwaysEncryptedSplitCipher != null) { + proxyConfig = mAlwaysEncryptedSplitCipher.encrypt(proxyConfig); + } + mGeneralInfoDao.update(new GeneralInfoEntity(PROXY_CONFIG, proxyConfig)); + } } diff --git a/src/main/java/io/split/android/client/utils/HttpProxySerializer.java b/src/main/java/io/split/android/client/utils/HttpProxySerializer.java new file mode 100644 index 000000000..889f9344a --- /dev/null +++ b/src/main/java/io/split/android/client/utils/HttpProxySerializer.java @@ -0,0 +1,42 @@ +package io.split.android.client.utils; + +import androidx.annotation.Nullable; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.storage.general.GeneralInfoStorage; + +/** + * Utility class for serializing and deserializing HttpProxy objects. + */ +public class HttpProxySerializer { + + private HttpProxySerializer() { + } + + @Nullable + public static String serialize(@Nullable HttpProxy httpProxy) { + if (httpProxy == null) { + return null; + } + HttpProxyDto dto = new HttpProxyDto(httpProxy); + return Json.toJson(dto); + } + + + @Nullable + public static HttpProxyDto deserialize(GeneralInfoStorage storage) { + if (storage == null) { + return null; + } + String json = storage.getProxyConfig(); + if (json == null || json.isEmpty()) { + return null; + } + try { + return Json.fromJson(json, HttpProxyDto.class); + } catch (Exception e) { + return null; + } + } +} diff --git a/src/test/java/io/split/android/client/SplitClientConfigTest.java b/src/test/java/io/split/android/client/SplitClientConfigTest.java index 1e69b9c12..8d796a758 100644 --- a/src/test/java/io/split/android/client/SplitClientConfigTest.java +++ b/src/test/java/io/split/android/client/SplitClientConfigTest.java @@ -6,6 +6,7 @@ import static junit.framework.TestCase.assertTrue; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; import org.junit.Test; @@ -14,6 +15,9 @@ import java.util.concurrent.TimeUnit; import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.ProxyConfiguration; +import io.split.android.client.network.SplitAuthenticatedRequest; +import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.utils.logger.LogPrinter; import io.split.android.client.utils.logger.Logger; import io.split.android.client.utils.logger.SplitLogLevel; @@ -256,6 +260,47 @@ public void nullRolloutCacheConfigurationSetsDefault() { assertEquals(1, logMessages.size()); } + @Test + public void proxyHostAndProxyConfigurationSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyHost("proxyHost") + .proxyConfiguration(ProxyConfiguration.builder().url("http://proxy.url").build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence.", logMessages.poll()); + } + + @Test + public void proxyAuthenticatorAndProxyConfigurationSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyAuthenticator(new SplitAuthenticator() { + @Nullable + @Override + public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { + return null; + } + }) + .proxyConfiguration(ProxyConfiguration.builder().url("http://proxy.url").build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence.", logMessages.poll()); + } + + @Test + public void proxyConfigurationWithNoUrlSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyConfiguration(ProxyConfiguration.builder().build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Proxy configuration with no URL. This will prevent SplitFactory from working.", logMessages.poll()); + } + @NonNull private static Queue getLogMessagesQueue() { Queue logMessages = new LinkedList<>(); diff --git a/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt b/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt index 04cc76c80..0a0f46ec4 100644 --- a/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt +++ b/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt @@ -2,10 +2,13 @@ package io.split.android.client import android.content.Context import io.split.android.client.SplitFactoryHelper.Initializer.Listener +import io.split.android.client.api.Key import io.split.android.client.events.EventsManagerCoordinator import io.split.android.client.events.SplitInternalEvent +import io.split.android.client.exceptions.SplitInstantiationException import io.split.android.client.lifecycle.SplitLifecycleManager -import io.split.android.client.service.CleanUpDatabaseTask +import io.split.android.client.network.HttpProxy +import io.split.android.client.network.ProxyConfiguration import io.split.android.client.service.executor.SplitSingleThreadTaskExecutor import io.split.android.client.service.executor.SplitTaskExecutionInfo import io.split.android.client.service.executor.SplitTaskExecutionListener @@ -13,7 +16,12 @@ import io.split.android.client.service.executor.SplitTaskExecutor import io.split.android.client.service.executor.SplitTaskType import io.split.android.client.service.synchronizer.RolloutCacheManager import io.split.android.client.service.synchronizer.SyncManager +import io.split.android.client.service.synchronizer.WorkManagerWrapper +import io.split.android.client.storage.general.GeneralInfoStorage +import io.split.android.client.utils.HttpProxySerializer import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertFalse +import junit.framework.TestCase.assertTrue import org.junit.After import org.junit.Before import org.junit.Test @@ -21,11 +29,12 @@ import org.mockito.Mock import org.mockito.Mockito import org.mockito.Mockito.any import org.mockito.Mockito.mock +import org.mockito.Mockito.mockStatic +import org.mockito.Mockito.never import org.mockito.Mockito.verify import org.mockito.Mockito.`when` import org.mockito.MockitoAnnotations import java.io.File -import java.lang.IllegalArgumentException import java.util.concurrent.locks.ReentrantLock class SplitFactoryHelperTest { @@ -184,4 +193,131 @@ class SplitFactoryHelperTest { verify(lifecycleManager).register(syncManager) verify(initLock).unlock() } + + @Test + fun `initializing with proxy config with null url throws`() { + var exceptionThrown = false + try { + SplitFactoryBuilder.build("sdk_key", Key("user"), SplitClientConfig.builder().proxyConfiguration( + ProxyConfiguration.builder().build()).build(), context) + } catch (splitInstantiationException: SplitInstantiationException) { + exceptionThrown = (splitInstantiationException.message ?: "").contains("When configured, proxy host cannot be null") + } + + assertTrue(exceptionThrown) + } + + @Test + fun `initializing with proxy config with valid url does not throw`() { + var exceptionThrown = false + try { + SplitFactoryBuilder.build("sdk_key", Key("user"), SplitClientConfig.builder().proxyConfiguration( + ProxyConfiguration.builder().url("http://localhost:8080").build()).build(), context) + } catch (splitInstantiationException: SplitInstantiationException) { + exceptionThrown = (splitInstantiationException.message ?: "").contains("When configured, proxy host cannot be null") + } + + assertFalse(exceptionThrown) + } + + @Test + fun `setupProxyForBackgroundSync should start thread when proxy is not null, not legacy, and background sync is enabled`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(false) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) + verify(proxyConfigSaveTask).run() + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when proxy is null`() { + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(null) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) + verify(proxyConfigSaveTask, never()).run() + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when proxy is legacy`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(true) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) // Give time to ensure no thread was started + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when background sync is disabled`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(false) + `when`(config.synchronizeInBackground()).thenReturn(false) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) // Give time to ensure no thread was started + } + + @Test + fun `getProxyConfigSaveTask should return runnable that saves proxy config`() { + val config = mock(SplitClientConfig::class.java) + val httpProxy = mock(HttpProxy::class.java) + val workManagerWrapper = mock(WorkManagerWrapper::class.java) + val generalInfoStorage = mock(GeneralInfoStorage::class.java) + val serializedProxy = "serialized_proxy_json" + + `when`(config.proxy()).thenReturn(httpProxy) + + mockStatic(HttpProxySerializer::class.java).use { mockedSerializer -> + mockedSerializer.`when` { HttpProxySerializer.serialize(httpProxy) }.thenReturn(serializedProxy) + + val runnable = SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, generalInfoStorage) + runnable.run() + + verify(generalInfoStorage).setProxyConfig(serializedProxy) + verify(workManagerWrapper, never()).removeWork() + } + } + + @Test + fun `getProxyConfigSaveTask should handle exceptions and disable background sync`() { + val config = mock(SplitClientConfig::class.java) + val httpProxy = mock(HttpProxy::class.java) + val workManagerWrapper = mock(WorkManagerWrapper::class.java) + val generalInfoStorage = mock(GeneralInfoStorage::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + + mockStatic(HttpProxySerializer::class.java).use { mockedSerializer -> + mockedSerializer.`when` { HttpProxySerializer.serialize(httpProxy) }.thenThrow(RuntimeException("Test exception")) + + val runnable = SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, generalInfoStorage) + runnable.run() + + verify(generalInfoStorage, never()).setProxyConfig(any()) + verify(workManagerWrapper).removeWork() + } + } } diff --git a/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java new file mode 100644 index 000000000..738300ce7 --- /dev/null +++ b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java @@ -0,0 +1,47 @@ +package io.split.android.client.network; + +import static org.mockito.Mockito.mockStatic; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockedStatic; + +import java.nio.charset.StandardCharsets; + +import io.split.android.client.utils.Base64Util; + +public class DefaultBase64EncoderTest { + + private DefaultBase64Encoder encoder; + private MockedStatic mockedBase64Util; + + @Before + public void setUp() { + encoder = new DefaultBase64Encoder(); + mockedBase64Util = mockStatic(Base64Util.class); + } + + @After + public void tearDown() { + mockedBase64Util.close(); + } + + @Test + public void encodeStringUsesBase64Util() { + String input = "test string"; + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } + + @Test + public void encodeByteArrayUsesBase64Util() { + byte[] input = "test bytes".getBytes(StandardCharsets.UTF_8); + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpClientTest.java b/src/test/java/io/split/android/client/network/HttpClientTest.java index c78b52696..3ecec1cc3 100644 --- a/src/test/java/io/split/android/client/network/HttpClientTest.java +++ b/src/test/java/io/split/android/client/network/HttpClientTest.java @@ -23,11 +23,14 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Type; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -52,12 +55,12 @@ public class HttpClientTest { private MockWebServer mWebServer; private MockWebServer mProxyServer; private HttpClient client; - private UrlSanitizer mUrlSanitizer; + private UrlSanitizer mUrlSanitizerMock; @Before public void setup() throws IOException { - mUrlSanitizer = mock(UrlSanitizer.class); - when(mUrlSanitizer.getUrl(any())).thenAnswer(new Answer() { + mUrlSanitizerMock = mock(UrlSanitizer.class); + when(mUrlSanitizerMock.getUrl(any())).thenAnswer(new Answer() { @Override public URL answer(InvocationOnMock invocation) throws Throwable { URI argument = invocation.getArgument(0); @@ -219,7 +222,7 @@ public void addHeaders() throws InterruptedException, URISyntaxException, HttpEx } @Test - public void addStreamingHeaders() throws InterruptedException, URISyntaxException, HttpException { + public void addStreamingHeaders() throws InterruptedException, HttpException, IOException { client.addStreamingHeaders(Collections.singletonMap("my_header", "my_header_value")); HttpUrl url = mWebServer.url("/test1/"); @@ -277,8 +280,8 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setUrlSanitizer(mUrlSanitizerMock) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).build()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.GET); @@ -312,7 +315,7 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .setProxyAuthenticator(new SplitAuthenticator() { @Override public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { @@ -322,7 +325,7 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest return request; } }) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).build()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.GET); @@ -367,7 +370,7 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .setProxyAuthenticator(new SplitAuthenticator() { @Override public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { @@ -377,7 +380,7 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest return request; } }) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).build()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.POST, "{}"); @@ -400,6 +403,79 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest mProxyServer.shutdown(); } + + @Test + public void copyStreamToByteArrayWithSimpleString() { + String testString = "Test string content"; + InputStream inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)); + + byte[] result = HttpClientImpl.copyStreamToByteArray(inputStream); + + assertNotNull("Result should not be null", result); + assertEquals("Result should match original string", testString, new String(result, StandardCharsets.UTF_8)); + + byte[] buffer = new byte[100]; + try { + int bytesRead = inputStream.read(buffer); + assertEquals("Stream should be readable and contain the same content", testString, + new String(buffer, 0, bytesRead, StandardCharsets.UTF_8)); + } catch (IOException e) { + Assert.fail("Should be able to read from stream after copying: " + e.getMessage()); + } + } + + @Test + public void copyStreamToByteArrayWithEmptyStream() { + InputStream emptyStream = new ByteArrayInputStream(new byte[0]); + + byte[] result = HttpClientImpl.copyStreamToByteArray(emptyStream); + + assertNotNull("Result should not be null even for empty stream", result); + assertEquals("Result should be empty array", 0, result.length); + } + + @Test + public void copyStreamToByteArrayWithNullStream() { + byte[] result = HttpClientImpl.copyStreamToByteArray(null); + + assertNull("Result should be null for null input", result); + } + + @Test + public void copyStreamToByteArrayWithNonMarkableStream() { + InputStream nonMarkableStream = new InputStream() { + private final byte[] data = "Test data".getBytes(StandardCharsets.UTF_8); + private int position = 0; + + @Override + public int read() { + if (position < data.length) { + return data[position++] & 0xff; + } + return -1; + } + + @Override + public boolean markSupported() { + return false; + } + }; + + byte[] result = HttpClientImpl.copyStreamToByteArray(nonMarkableStream); + + assertNotNull("Result should not be null", result); + assertEquals("Result should match original content", "Test data", + new String(result, StandardCharsets.UTF_8)); + + int nextByte = -1; + try { + nextByte = nonMarkableStream.read(); + } catch (IOException e) { + Assert.fail("Reading from stream should not throw exception"); + } + assertEquals("Stream should be at EOF", -1, nextByte); + } + @After public void tearDown() throws IOException { mWebServer.shutdown(); @@ -456,7 +532,7 @@ public MockResponse dispatch(RecordedRequest request) throws InterruptedExceptio mWebServer.start(); client = new HttpClientImpl.Builder() - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .build(); } diff --git a/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java new file mode 100644 index 000000000..f6f3454dd --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java @@ -0,0 +1,657 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.security.KeyStore; +import java.util.Base64; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okhttp3.tls.HeldCertificate; + +public class HttpClientTunnellingProxyTest { + + private UrlSanitizer mUrlSanitizerMock; + private Base64Decoder mBas64Decoder; + + @Before + public void setUp() { + + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); + mUrlSanitizerMock = mock(UrlSanitizer.class); + when(mUrlSanitizerMock.getUrl(any())).thenAnswer(new Answer() { + @Override + public URL answer(InvocationOnMock invocation) throws Throwable { + URI argument = invocation.getArgument(0); + + return new URL(argument.toString()); + } + }); + mBas64Decoder = new Base64Decoder() { + @Override + public byte[] decode(String base64) { + return Base64.getDecoder().decode(base64); + } + }; + } + + @Test + public void proxyCacertProxyTunnelling() throws Exception { + // 1. Create separate CA and server certs for proxy and origin + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // 2. Start HTTP origin server (not HTTPS to avoid SSL layering issues) + MockWebServer originServer = new MockWebServer(); + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from origin!"); + } + }); + // Use HTTP instead of HTTPS to test tunnel establishment without SSL layering issues + originServer.start(); + + // 3. Start SSL tunnel proxy (server-only SSL, no client cert required) + TunnelProxySslServerOnly tunnelProxy = new TunnelProxySslServerOnly(0, proxyServerCert); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Write BOTH proxy CA and origin CA certs to temp file (for combined trust store) + File caCertFile = File.createTempFile("proxy-ca", ".pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(proxyCa.certificatePem()); + writer.write(originCa.certificatePem()); + } + + // 5. Configure HttpProxy with PROXY_CACERT + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(caCertFile.toPath())) + .build(); + + // 6. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 7. Make a request to the origin server (should tunnel via SSL proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + originServer.shutdown(); + } + + @Test + public void proxyCacertProxyTunnelling_SslOverSsl() throws Exception { + // 1. Create separate CA and server certs for proxy and origin + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // 2. Start HTTPS origin server + MockWebServer originServer = new MockWebServer(); + originServer.useHttps(createSslSocketFactory(originCert), false); // Use HTTPS for SSL-over-SSL + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from https origin!"); + } + }); + originServer.start(); + + // 3. Start SSL tunnel proxy (server-only SSL, no client cert required) + TunnelProxySslServerOnly tunnelProxy = new TunnelProxySslServerOnly(0, proxyServerCert); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Write BOTH proxy CA and origin CA certs to temp file (for combined trust store) + File caCertFile = File.createTempFile("proxy-ca", ".pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(proxyCa.certificatePem()); + writer.write(originCa.certificatePem()); // Client needs to trust origin CA as well + } + + // 5. Configure HttpProxy with PROXY_CACERT + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(caCertFile.toPath())) + .build(); + + // 6. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 7. Make a request to the HTTPS origin server (should tunnel via SSL proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from https origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + originServer.shutdown(); + } + + /** + * Negative test: mTLS proxy requires client certificate, but client presents none. + * The proxy should reject the connection, and the client should throw SSLHandshakeException. + */ + @Test + public void proxyMtlsProxyTunnelling_rejectsNoClientCert() throws Exception { + // 1. Create CA, proxy/server, and origin certs + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // Write proxy server cert and key to temp files + File proxyCertFile = File.createTempFile("proxy-server", ".crt"); + File proxyKeyFile = File.createTempFile("proxy-server", ".key"); + try (FileWriter writer = new FileWriter(proxyCertFile)) { + writer.write(proxyServerCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(proxyKeyFile)) { + writer.write(proxyServerCert.privateKeyPkcs8Pem()); + } + // Write proxy CA cert (for client auth) to temp file + File proxyCaFile = File.createTempFile("proxy-ca", ".crt"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + // 2. Start HTTPS origin server + MockWebServer originServer = new MockWebServer(); + originServer.useHttps(createSslSocketFactory(originCert), false); + originServer.start(); + + // 3. Start mTLS tunnel proxy + TunnelProxySsl tunnelProxy = new TunnelProxySsl(0, proxyServerCert, proxyCa); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Configure HttpProxy WITHOUT client cert (should be rejected) + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(proxyCaFile.toPath())) // only trust, no client auth + .build(); + + // 5. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 6. Make a request to the origin server (should fail at proxy handshake) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + boolean handshakeFailed = false; + try { + req.execute(); + } catch (Exception e) { + handshakeFailed = true; + } + assertTrue("Expected SSL handshake to fail due to missing client certificate", handshakeFailed); + + tunnelProxy.stopProxy(); + tunnelProxy.join(); + originServer.shutdown(); + proxyCertFile.delete(); + proxyKeyFile.delete(); + proxyCaFile.delete(); + } + + /** + * Positive test: mTLS proxy requires client certificate, and client presents a valid certificate. + * The proxy should accept the connection, tunnel should be established, and the request should reach the origin. + */ + @Test + public void proxyMtlsProxyTunnelling() throws Exception { + // 1. Create CA, proxy/server, client, and origin certs + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate clientCert = new HeldCertificate.Builder() + .commonName("Test Client") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // Write proxy server cert and key to temp files + File proxyCertFile = File.createTempFile("proxy-server", ".crt"); + File proxyKeyFile = File.createTempFile("proxy-server", ".key"); + try (FileWriter writer = new FileWriter(proxyCertFile)) { + writer.write(proxyServerCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(proxyKeyFile)) { + writer.write(proxyServerCert.privateKeyPkcs8Pem()); + } + // Write proxy CA cert (for client auth) to temp file + File proxyCaFile = File.createTempFile("proxy-ca", ".crt"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + // Write client certificate and key to separate files (PEM format) + File clientCertFile = File.createTempFile("client", ".crt"); + File clientKeyFile = File.createTempFile("client", ".key"); + try (FileWriter writer = new FileWriter(clientCertFile)) { + writer.write(clientCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(clientKeyFile)) { + writer.write(clientCert.privateKeyPkcs8Pem()); + } + + // 2. Start HTTP origin server (not HTTPS to avoid SSL layering issues) + MockWebServer originServer = new MockWebServer(); + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from origin!"); + } + }); + // Use HTTP instead of HTTPS to test tunnel establishment without SSL layering issues + originServer.start(); + + // 3. Start mTLS tunnel proxy + TunnelProxySsl tunnelProxy = new TunnelProxySsl(0, proxyServerCert, proxyCa); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Configure HttpProxy with mTLS (client cert, key, and CA) + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .mtls( + Files.newInputStream(clientCertFile.toPath()), + Files.newInputStream(clientKeyFile.toPath()) + ) + .proxyCacert(Files.newInputStream(proxyCaFile.toPath())) + .build(); + + // 5. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 6. Make a request to the origin server (should tunnel via proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + tunnelProxy.join(); + originServer.shutdown(); + proxyCertFile.delete(); + proxyKeyFile.delete(); + proxyCaFile.delete(); + } + + // Helper to create SSLSocketFactory from HeldCertificate + private static SSLSocketFactory createSslSocketFactory(HeldCertificate cert) throws Exception { + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", cert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{cert.certificate()}); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(kmf.getKeyManagers(), null, null); + + return sslContext.getSocketFactory(); + } + + /** + * TunnelProxySslServerOnly is an SSL proxy that presents a server certificate but doesn't require client certificates. + * This is used for testing proxy_cacert functionality where the client validates the proxy's certificate. + */ + static class TunnelProxySslServerOnly extends TunnelProxy { + private final HeldCertificate mServerCert; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + public TunnelProxySslServerOnly(int port, HeldCertificate serverCert) { + super(port); + this.mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + + // No client certificate validation - use default trust manager + sslContext.init(kmf.getKeyManagers(), null, null); + + javax.net.ssl.SSLServerSocketFactory factory = sslContext.getServerSocketFactory(); + mServerSocket = factory.createServerSocket(mPort); + mPort = mServerSocket.getLocalPort(); // Update mPort with the actual assigned port + + // Don't require client auth - this is server-only SSL + ((javax.net.ssl.SSLServerSocket) mServerSocket).setWantClientAuth(false); + ((javax.net.ssl.SSLServerSocket) mServerSocket).setNeedClientAuth(false); + + System.out.println("[TunnelProxySslServerOnly] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxySslServerOnly] Accepted connection from: " + client.getRemoteSocketAddress()); + new Thread(() -> handle(client)).start(); + } + } catch (IOException e) { + System.out.println("[TunnelProxySslServerOnly] Server socket closed or error: " + e); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * TunnelProxySsl is a minimal SSL/TLS proxy supporting mTLS (client authentication). + * It uses an SSLServerSocket and requires client certificates signed by the provided CA. + * Only for use in mTLS proxy integration tests. + */ + private static class TunnelProxySsl extends TunnelProxy { + private final HeldCertificate mServerCert; + private final HeldCertificate mClientCa; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + public TunnelProxySsl(int port, HeldCertificate serverCert, HeldCertificate clientCa) { + super(port); + this.mServerCert = serverCert; + this.mClientCa = clientCa; + } + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", mClientCa.certificate()); + javax.net.ssl.TrustManagerFactory tmf = javax.net.ssl.TrustManagerFactory.getInstance(javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(trustStore); + sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + javax.net.ssl.SSLServerSocketFactory factory = sslContext.getServerSocketFactory(); + mServerSocket = factory.createServerSocket(mPort); + ((javax.net.ssl.SSLServerSocket) mServerSocket).setNeedClientAuth(true); + System.out.println("[TunnelProxySsl] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxySsl] Accepted connection from: " + client.getRemoteSocketAddress()); + new Thread(() -> handle(client)).start(); + } + } catch (IOException e) { + System.out.println("[TunnelProxySsl] Server socket closed or error: " + e); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * Minimal CONNECT-capable proxy for HTTPS tunneling in tests. + * Listens on a port, accepts CONNECT requests, and pipes bytes between the client and the requested target. + * Used to simulate a real HTTPS proxy for end-to-end CA trust validation. + */ + private static class TunnelProxy extends Thread { + // Latch to signal that a CONNECT tunnel was established + private final CountDownLatch mTunnelLatch = new CountDownLatch(1); + // Port to listen on (0 = auto-assign) + protected int mPort; + // The server socket for accepting connections + public ServerSocket mServerSocket; + // Flag to control proxy shutdown + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + /** + * Create a new TunnelProxy listening on the given port. + * @param port Port to listen on (0 = auto-assign) + */ + TunnelProxy(int port) { mPort = port; } + + /** + * Main accept loop. For each incoming client, start a handler thread. + */ + public void run() { + try { + mServerSocket = new ServerSocket(mPort); + System.out.println("[TunnelProxy] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxy] Accepted connection from: " + client.getRemoteSocketAddress()); + // Each client handled in its own thread + new Thread(() -> handle(client)).start(); + } + } catch (IOException ignored) { + System.out.println("[TunnelProxy] Server socket closed or error: " + ignored); + } + } + + /** + * Handles a single client connection. Waits for CONNECT, then establishes a tunnel. + */ + void handle(Socket client) { + try (BufferedReader in = new BufferedReader(new InputStreamReader(client.getInputStream())); + OutputStream out = client.getOutputStream()) { + String line = in.readLine(); + // Only handle CONNECT requests (as sent by HTTPS clients to a proxy) + if (line != null && line.startsWith("CONNECT")) { + mTunnelLatch.countDown(); + System.out.println("[TunnelProxy] Received CONNECT: " + line); + out.write("HTTP/1.1 200 Connection Established\r\n\r\n".getBytes()); + out.flush(); + String[] parts = line.split(" "); + String[] hostPort = parts[1].split(":"); + // Open a socket to the requested target (origin server) + Socket target = new Socket(hostPort[0], Integer.parseInt(hostPort[1])); + System.out.println("[TunnelProxy] Established tunnel to: " + hostPort[0] + ":" + hostPort[1]); + // Pipe bytes in both directions (client <-> target) until closed + Thread t1 = new Thread(() -> { + try { + pipe(client, target); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + Thread t2 = new Thread(() -> { + try { + pipe(target, client); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + t1.start(); t2.start(); + try { t1.join(); t2.join(); } catch (InterruptedException ignored) {} + System.out.println("[TunnelProxy] Tunnel closed for: " + hostPort[0] + ":" + hostPort[1]); + target.close(); + } + } catch (Exception ignored) { } + } + + /** + * Copies bytes from inSocket to outSocket until EOF. + * Used to relay data in both directions for the tunnel. + */ + private void pipe(Socket inSocket, Socket outSocket) throws IOException { + try (InputStream in = inSocket.getInputStream(); OutputStream out = outSocket.getOutputStream()) { + byte[] buf = new byte[1024]; + int len; + while ((len = in.read(buf)) != -1) { + out.write(buf, 0, len); + out.flush(); + } + } catch (IOException ignored) { } + } + + /** + * Stops the proxy by closing the server socket and setting the running flag to false. + */ + public void stopProxy() throws IOException { + mRunning.set(false); + if (mServerSocket != null && !mServerSocket.isClosed()) { + mServerSocket.close(); + System.out.println("[TunnelProxy] Proxy stopped."); + } + } + + public CountDownLatch getTunnelLatch() { + return mTunnelLatch; + } + } +} diff --git a/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java new file mode 100644 index 000000000..2fdc64e30 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java @@ -0,0 +1,212 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.when; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.net.URL; +import java.util.Collections; + +public class HttpOverTunnelExecutorTest { + + private static final String CRLF = "\r\n"; + private HttpOverTunnelExecutor mExecutor; + + @Mock + private Socket mSocket; + + private OutputStream mOutputStream; + private InputStream mInputStream; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + mExecutor = new HttpOverTunnelExecutor(); + mOutputStream = new ByteArrayOutputStream(); + + String httpResponse = "HTTP/1.1 200 OK" + CRLF + "Content-Length: 0\r\n\r\n"; + mInputStream = new ByteArrayInputStream(httpResponse.getBytes()); + + when(mSocket.getOutputStream()).thenReturn(mOutputStream); + when(mSocket.getInputStream()).thenReturn(mInputStream); + } + + @Test + public void postRequestWithBodyAndHeaders() throws IOException { + URL url = new URL("https://test.com/path"); + String body = "{\"key\":\"value\"}"; + java.util.Map headers = new java.util.HashMap<>(); + headers.put("Custom-Header", "CustomValue"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.POST, headers, body, null); + + String expectedRequest = "POST /path HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Custom-Header: CustomValue" + CRLF + + "Content-Length: 15" + CRLF + + "Connection: close" + CRLF + + CRLF + + body; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithQuery() throws IOException { + URL url = new URL("http://test.com/path?q=1&v=2"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path?q=1&v=2 HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithNonDefaultPort() throws IOException { + URL url = new URL("http://test.com:8080/path"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path HTTP/1.1" + CRLF + + "Host: test.com:8080" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithEmptyPath() throws IOException { + URL url = new URL("http://test.com"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET / HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test(expected = IOException.class) + public void requestThrowsIOException() throws IOException { + URL url = new URL("http://test.com/path"); + when(mSocket.getOutputStream()).thenThrow(new IOException("Socket error")); + + mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + } + + @Test + public void getRequest() throws IOException { + URL url = new URL("http://test.com/path"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @After + public void tearDown() throws IOException { + mOutputStream.close(); + mInputStream.close(); + } + + @Test + public void executeStreamRequestTest() throws IOException { + // Prepare HTTP response with headers and body + String httpResponse = "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json; charset=utf-8\r\n" + + "Content-Length: 16\r\n" + + "\r\n" + + "{\"data\":\"test\"}"; + + // Set up input stream with the HTTP response + ByteArrayInputStream inputStream = new ByteArrayInputStream(httpResponse.getBytes()); + when(mSocket.getInputStream()).thenReturn(inputStream); + + // Execute the stream request + URL url = new URL("https://test.com/stream"); + HttpStreamResponse response = mExecutor.executeStreamRequest( + mSocket, // finalSocket + mSocket, // tunnelSocket (using same mock for simplicity) + null, // originSocket + url, + HttpMethod.GET, + Collections.emptyMap(), + null // serverCertificates + ); + + // Verify the request was sent correctly + String expectedRequest = "GET /stream HTTP/1.1\r\n" + + "Host: test.com\r\n" + + "Connection: close\r\n" + + "\r\n"; + assertEquals(expectedRequest, mOutputStream.toString()); + + // Verify the response properties + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + + // Verify we can read from the response stream + BufferedReader reader = response.getBufferedReader(); + assertNotNull(reader); + StringBuilder responseBody = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + responseBody.append(line); + } + assertEquals("{\"data\":\"test\"}", responseBody.toString()); + + // Close the response + response.close(); + } + + @Test(expected = IOException.class) + public void executeStreamRequestWithSocketException() throws IOException { + URL url = new URL("http://test.com/stream"); + when(mSocket.getOutputStream()).thenThrow(new IOException("Socket error")); + + mExecutor.executeStreamRequest( + mSocket, + mSocket, + null, + url, + HttpMethod.GET, + Collections.emptyMap(), + null + ); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java b/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java new file mode 100644 index 000000000..b3f08fb55 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java @@ -0,0 +1,112 @@ +package io.split.android.client.network; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.lang.reflect.Method; +import java.net.HttpURLConnection; +import java.net.Proxy; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocketFactory; + +public class HttpRequestHelperTest { + + @Mock + private HttpURLConnection mockConnection; + @Mock + private HttpsURLConnection mockHttpsConnection; + @Mock + private URL mockUrl; + @Mock + private SplitUrlConnectionAuthenticator mockAuthenticator; + @Mock + private SSLSocketFactory mockSslSocketFactory; + @Mock + private DevelopmentSslConfig mockDevelopmentSslConfig; + @Mock + private CertificateChecker mockCertificateChecker; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + when(mockUrl.openConnection()).thenReturn(mockConnection); + when(mockUrl.openConnection(any(Proxy.class))).thenReturn(mockConnection); + when(mockAuthenticator.authenticate(any(HttpURLConnection.class))).thenReturn(mockConnection); + } + + @Test + public void addHeaders() throws Exception { + Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + headers.put("Authorization", "Bearer token123"); + headers.put(null, "This should be ignored"); + + Method addHeadersMethod = HttpRequestHelper.class.getDeclaredMethod( + "addHeaders", HttpURLConnection.class, Map.class); + addHeadersMethod.setAccessible(true); + addHeadersMethod.invoke(null, mockConnection, headers); + + verify(mockConnection).addRequestProperty("Content-Type", "application/json"); + verify(mockConnection).addRequestProperty("Authorization", "Bearer token123"); + verify(mockConnection, never()).addRequestProperty(null, "This should be ignored"); + } + + @Test + public void applyTimeouts() { + HttpRequestHelper.applyTimeouts(5000, 3000, mockConnection); + verify(mockConnection).setReadTimeout(5000); + verify(mockConnection).setConnectTimeout(3000); + + HttpRequestHelper.applyTimeouts(0, 0, mockConnection); + verify(mockConnection, times(1)).setReadTimeout(any(Integer.class)); + verify(mockConnection, times(1)).setConnectTimeout(any(Integer.class)); + + HttpRequestHelper.applyTimeouts(-1000, -500, mockConnection); + verify(mockConnection, times(1)).setReadTimeout(any(Integer.class)); + verify(mockConnection, times(1)).setConnectTimeout(any(Integer.class)); + } + + @Test + public void applySslConfigWithDevelopmentSslConfig() { + when(mockDevelopmentSslConfig.getSslSocketFactory()).thenReturn(mockSslSocketFactory); + + HttpRequestHelper.applySslConfig(null, mockDevelopmentSslConfig, mockHttpsConnection); + + verify(mockHttpsConnection).setSSLSocketFactory(mockSslSocketFactory); + verify(mockHttpsConnection).setHostnameVerifier(any()); + } + + @Test + public void pinsAreCheckedWithCertificateChecker() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockHttpsConnection, mockCertificateChecker); + + verify(mockCertificateChecker).checkPins(mockHttpsConnection); + } + + @Test + public void pinsAreNotCheckedWithoutCertificateChecker() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockHttpsConnection, null); + + verify(mockCertificateChecker, never()).checkPins(any()); + } + + @Test + public void pinsAreNotCheckedForNonHttpsConnections() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockConnection, mockCertificateChecker); + + verify(mockCertificateChecker, never()).checkPins(any()); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java new file mode 100644 index 000000000..0bc972444 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java @@ -0,0 +1,528 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.util.List; +import java.util.Map; + +public class HttpResponseConnectionAdapterTest { + + @Mock + private HttpResponse mMockResponse; + + @Mock + private Certificate mMockCertificate; + + private URL mTestUrl; + private Certificate[] mTestCertificates; + private HttpResponseConnectionAdapter mAdapter; + + @Before + public void setUp() throws MalformedURLException { + mMockCertificate = mock(Certificate.class); + mMockResponse = mock(HttpResponse.class); + mTestUrl = new URL("https://example.com/test"); + mTestCertificates = new Certificate[]{mMockCertificate}; + } + + @Test + public void responseCodeIsValueFromResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals(200, mAdapter.getResponseCode()); + } + + @Test + public void successfulResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("OK", mAdapter.getResponseMessage()); + } + + @Test + public void status400Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Bad Request", mAdapter.getResponseMessage()); + } + + @Test + public void status401Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(401); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Unauthorized", mAdapter.getResponseMessage()); + } + + @Test + public void status403Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(403); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Forbidden", mAdapter.getResponseMessage()); + } + + @Test + public void status404Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(404); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Not Found", mAdapter.getResponseMessage()); + } + + @Test + public void status500Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Internal Server Error", mAdapter.getResponseMessage()); + } + + @Test + public void statusUnknownResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(418); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("HTTP 418", mAdapter.getResponseMessage()); + } + + @Test + public void successfulInputStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + when(mMockResponse.getData()).thenReturn("test data"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream inputStream = mAdapter.getInputStream(); + assertNotNull(inputStream); + + byte[] buffer = new byte[1024]; + int bytesRead = inputStream.read(buffer); + String result = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8); + assertEquals("test data", result); + } + + @Test + public void nullDataInputStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream inputStream = mAdapter.getInputStream(); + assertNotNull(inputStream); + + byte[] buffer = new byte[1024]; + int bytesRead = inputStream.read(buffer); + assertEquals(-1, bytesRead); + } + + @Test(expected = IOException.class) + public void inputStreamErrorStatusThrows() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + mAdapter.getInputStream(); + } + + @Test(expected = IOException.class) + public void inputStream500Throws() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + mAdapter.getInputStream(); + } + + @Test + public void status400ErrorStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + when(mMockResponse.getData()).thenReturn("error message"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNotNull(errorStream); + + byte[] buffer = new byte[1024]; + int bytesRead = errorStream.read(buffer); + String result = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8); + assertEquals("error message", result); + } + + @Test + public void errorStreamStatus500() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNotNull(errorStream); + + byte[] buffer = new byte[1024]; + int bytesRead = errorStream.read(buffer); + assertEquals(-1, bytesRead); // Empty stream + } + + @Test + public void errorStreamIsNullForSuccessful() { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNull(errorStream); + } + + @Test + public void usingProxyIsAlwaysTrue() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + assertTrue(mAdapter.usingProxy()); // This is only used for Proxy + } + + @Test + public void getServerCertificatesReturnsCerts() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + Certificate[] certificates = mAdapter.getServerCertificates(); + assertSame(mTestCertificates, certificates); + assertEquals(1, certificates.length); + assertSame(mMockCertificate, certificates[0]); + } + + @Test + public void nullServerCertificates() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, null); + + Certificate[] certificates = mAdapter.getServerCertificates(); + assertNull(certificates); + } + + @Test + public void contentTypeIsJsonForJsonData() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getHeaderField("content-type"); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void nullHeaderField() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String result = mAdapter.getHeaderField(null); + assertNull(result); + } + + @Test + public void getHeaderIsCaseInsensitive() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType1 = mAdapter.getHeaderField("Content-Type"); + String contentType2 = mAdapter.getHeaderField("CONTENT-TYPE"); + String contentType3 = mAdapter.getHeaderField("content-type"); + + assertEquals("application/json; charset=utf-8", contentType1); + assertEquals("application/json; charset=utf-8", contentType2); + assertEquals("application/json; charset=utf-8", contentType3); + } + + @Test + public void generatedHeaderFieldsCanBeRetrieved() throws IOException { + when(mMockResponse.getData()).thenReturn("{\"test\": \"data\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + Map> headers = mAdapter.getHeaderFields(); + assertNotNull(headers); + + assertTrue(headers.containsKey("content-type")); + assertEquals("application/json; charset=utf-8", headers.get("content-type").get(0)); + + assertTrue(headers.containsKey("content-length")); + assertEquals("16", headers.get("content-length").get(0)); + + assertTrue(headers.containsKey("content-encoding")); + assertEquals("utf-8", headers.get("content-encoding").get(0)); + + assertTrue(headers.containsKey("status")); + assertEquals("200 OK", headers.get("status").get(0)); + } + + @Test + public void getContentLengthWithData() { + when(mMockResponse.getData()).thenReturn("Hello World"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(11, length); // "Hello World" is 11 bytes + } + + @Test + public void getContentLengthWithNullData() { + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(0, length); + } + + @Test + public void getContentLengthEmptyData() { + when(mMockResponse.getData()).thenReturn(""); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(0, length); + } + + @Test + public void getContentTypeJsonData() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void getContentTypeJsonArray() { + when(mMockResponse.getData()).thenReturn("[{\"key\": \"value\"}]"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void getContentTypeHtmlData() { + when(mMockResponse.getData()).thenReturn("Test"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("text/html; charset=utf-8", contentType); + } + + @Test + public void getContentTypeText() { + when(mMockResponse.getData()).thenReturn("Plain text content"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("text/plain; charset=utf-8", contentType); + } + + @Test + public void getContentTypeNullDataHasNoContentType() { + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertNull(contentType); + } + + @Test + public void testGetContentEncoding() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String encoding = mAdapter.getContentEncoding(); + assertEquals("utf-8", encoding); + } + + @Test + public void testGetDate() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long currentTime = System.currentTimeMillis(); + long date = mAdapter.getDate(); + + // Should be close to current time (within 1 second) + assertTrue(Math.abs(date - currentTime) < 1000); + } + + @Test + public void urlCanBeRetrieved() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + URL url = mAdapter.getURL(); + assertSame(mTestUrl, url); + } + + @Test(expected = IOException.class) + public void getOutputStreamThrowsWhenNotEnabled() throws IOException { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + // Should throw exception since doOutput is not enabled + mAdapter.getOutputStream(); + } + + @Test + public void setDoOutputEnablesOutput() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + // Initially doOutput should be false + assertEquals(false, mAdapter.getDoOutput()); + + // After setting doOutput to true, getDoOutput should return true + mAdapter.setDoOutput(true); + assertEquals(true, mAdapter.getDoOutput()); + } + + @Test + public void getOutputStreamAfterEnablingOutput() throws IOException { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + mAdapter.setDoOutput(true); + + assertNotNull("Output stream should not be null when doOutput is enabled", mAdapter.getOutputStream()); + } + + @Test + public void writeToOutputStream() throws IOException { + // Create a ByteArrayOutputStream to capture the written data + ByteArrayOutputStream testOutputStream = new ByteArrayOutputStream(); + + // Use the constructor that accepts a custom OutputStream + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Write test data to the output stream + String testData = "Test output data"; + mAdapter.getOutputStream().write(testData.getBytes(StandardCharsets.UTF_8)); + + // Verify that the data was written correctly + assertEquals("Written data should match the input", testData, testOutputStream.toString(StandardCharsets.UTF_8.name())); + } + + @Test + public void disconnectClosesOutputStream() throws IOException { + // Create a custom OutputStream that tracks if it's been closed + TestOutputStream testOutputStream = new TestOutputStream(); + + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Get the output stream and write some data + mAdapter.getOutputStream().write("Test".getBytes(StandardCharsets.UTF_8)); + + // Verify the stream is not closed yet + assertFalse("Output stream should not be closed before disconnect", testOutputStream.isClosed()); + + // Disconnect should close the output stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Output stream should be closed after disconnect", testOutputStream.isClosed()); + } + + @Test + public void disconnectClosesInputStream() throws IOException { + // Create a custom InputStream that tracks if it's been closed + TestInputStream testInputStream = new TestInputStream("Test response data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + // Create adapter with injected test input stream + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + testInputStream, + null); + + // Get the input stream and read some data to simulate usage + InputStream stream = mAdapter.getInputStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + // Verify the stream is not closed yet + assertFalse("Input stream should not be closed before disconnect", testInputStream.isClosed()); + + // Disconnect should close the input stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Input stream should be closed after disconnect", testInputStream.isClosed()); + } + + /** + * Custom OutputStream implementation for testing that tracks if it's been closed. + */ + private static class TestOutputStream extends ByteArrayOutputStream { + private boolean mClosed = false; + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + private static class TestInputStream extends ByteArrayInputStream { + private boolean mClosed = false; + + public TestInputStream(byte[] data) { + super(data); + } + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + @Test + public void disconnectClosesErrorStream() throws IOException { + TestInputStream testErrorStream = new TestInputStream("Error data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + when(mMockResponse.getHttpStatus()).thenReturn(404); // Error status + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + null, + testErrorStream); + + // Get the error stream and read some data to simulate usage + InputStream stream = mAdapter.getErrorStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + assertFalse("Error stream should not be closed before disconnect", testErrorStream.isClosed()); + + mAdapter.disconnect(); + + assertTrue("Error stream should be closed after disconnect", testErrorStream.isClosed()); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java new file mode 100644 index 000000000..aa60be76e --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java @@ -0,0 +1,148 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +public class HttpStreamResponseTest { + + private static final int HTTP_STATUS_OK = 200; + private static final int HTTP_STATUS_BAD_REQUEST = 400; + + @Mock + private BufferedReader mockBufferedReader; + + @Mock + private Socket mockTunnelSocket; + + @Mock + private Socket mockOriginSocket; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void createFromTunnelSocketReturnsValidResponse() { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_OK, response.getHttpStatus()); + } + + @Test + public void createFromTunnelSocketWithNullSocketsReturnsValidResponse() { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_BAD_REQUEST, + mockBufferedReader, + null, + null + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_BAD_REQUEST, response.getHttpStatus()); + } + + @Test + public void closeSuccessfullyClosesAllResources() throws IOException { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response + response.close(); + + // Verify all resources were closed in the correct order + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithNullSocketsOnlyClosesBufferedReader() throws IOException { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + null, + null + ); + + // Close the response + response.close(); + + // Verify only the BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + verifyNoMoreInteractions(mockTunnelSocket, mockOriginSocket); + } + + @Test + public void closeWithSameTunnelAndOriginSocketClosesSocketOnce() throws IOException { + // Create response with the same socket for tunnel and origin + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockTunnelSocket + ); + + // Close the response + response.close(); + + // Verify BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + + // Verify tunnel socket was closed only once (since it's the same as origin socket) + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithExceptionsSucceeds() throws IOException { + // Setup mocks to throw exceptions when closed + doThrow(new IOException("BufferedReader close error")).when(mockBufferedReader).close(); + doThrow(new IOException("Origin socket close error")).when(mockOriginSocket).close(); + doThrow(new IOException("Tunnel socket close error")).when(mockTunnelSocket).close(); + + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response - should not throw exceptions + response.close(); + + // Verify all resources were attempted to be closed despite exceptions + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } +} diff --git a/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java b/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java new file mode 100644 index 000000000..5730579b5 --- /dev/null +++ b/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java @@ -0,0 +1,142 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; + +public class ProxyConfigurationTest { + + private static final String VALID_URL = "http://proxy.example.com:8080"; + private static final String INVALID_URL = "invalid://\\url"; + private static final String URL_WITH_PATH = "https://proxy.example.com:8080/path/to/proxy"; + private static final String URL_WITH_PATH_NORMALIZED = "https://proxy.example.com:8080/path/to/proxy"; + + @Mock + private BearerCredentialsProvider mockBearerCredentialsProvider; + + @Mock + private BasicCredentialsProvider mockBasicCredentialsProvider; + + private InputStream clientCert; + private InputStream clientPk; + private InputStream caCert; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + + clientCert = new ByteArrayInputStream("client-cert-content".getBytes()); + clientPk = new ByteArrayInputStream("client-pk-content".getBytes()); + caCert = new ByteArrayInputStream("ca-cert-content".getBytes()); + } + + @Test + public void buildWithValidUrl() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .build(); + + assertNotNull("Configuration should be created with valid URL", config); + assertEquals("URL should match", VALID_URL, config.getUrl().toString()); + } + + @Test + public void buildWithInvalidUrlHasNoNullConfig() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(INVALID_URL) + .build(); + + assertNotNull(config); + } + + @Test + public void urlIsNormalized() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(URL_WITH_PATH) + .build(); + + assertNotNull("Configuration should be created with URL containing path", config); + assertEquals("URL should be normalized", URL_WITH_PATH_NORMALIZED, config.getUrl().toString()); + } + + @Test + public void bearerCredentialsProvider() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(mockBearerCredentialsProvider) + .build(); + + assertNotNull("Configuration should be created with bearer credentials", config); + assertSame("Credentials provider should match", mockBearerCredentialsProvider, config.getCredentialsProvider()); + } + + @Test + public void basicCredentialsProvider() { + BasicCredentialsProvider provider = mockBasicCredentialsProvider; + + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(provider) + .build(); + + assertNotNull("Configuration should be created with basic credentials", config); + assertSame("Credentials provider should match", provider, config.getCredentialsProvider()); + } + + @Test + public void mtlsValues() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .mtls(clientCert, clientPk) + .build(); + + assertNotNull("Configuration should be created with mTLS", config); + assertSame("Client certificate should match", clientCert, config.getClientCert()); + assertSame("Client private key should match", clientPk, config.getClientPk()); + } + + @Test + public void cacert() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .caCert(caCert) + .build(); + + assertNotNull("Configuration should be created with CA certificate", config); + assertSame("CA certificate should match", caCert, config.getCaCert()); + } + + @Test + public void allOptions() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(mockBearerCredentialsProvider) + .mtls(clientCert, clientPk) + .caCert(caCert) + .build(); + + assertNotNull("Configuration should be created with all options", config); + assertEquals("URL should match", VALID_URL, config.getUrl().toString()); + assertSame("Credentials provider should match", mockBearerCredentialsProvider, config.getCredentialsProvider()); + assertSame("Client certificate should match", clientCert, config.getClientCert()); + assertSame("Client private key should match", clientPk, config.getClientPk()); + assertSame("CA certificate should match", caCert, config.getCaCert()); + } + + @Test + public void buildWithoutUrlReturnsNonNullConfig() { + ProxyConfiguration config = ProxyConfiguration.builder() + .credentialsProvider(mockBearerCredentialsProvider) + .build(); + + assertNotNull("Configuration should be created with null URL", config); + } +} diff --git a/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java b/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java new file mode 100644 index 000000000..b4d88868a --- /dev/null +++ b/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java @@ -0,0 +1,139 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertNotNull; + +import androidx.annotation.NonNull; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileWriter; +import java.util.Base64; + +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class ProxySslSocketFactoryProviderImplTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private final Base64Decoder mBase64Decoder = new Base64Decoder() { + @Override + public byte[] decode(String base64) { + return Base64.getDecoder().decode(base64); + } + }; + + @Test + public void creatingWithValidCaCertCreatesSocketFactory() throws Exception { + HeldCertificate ca = getCaCert(); + File caCertFile = tempFolder.newFile("held-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(ca.certificatePem()); + } + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream fis = new FileInputStream(caCertFile)) { + SSLSocketFactory socketFactory = provider.create(fis); + assertNotNull(socketFactory); + } + } + + @Test(expected = Exception.class) + public void creatingWithInvalidCaCertThrows() throws Exception { + File caCertFile = tempFolder.newFile("invalid-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write("not a cert"); + } + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream fis = new FileInputStream(caCertFile)) { + provider.create(fis); + } + } + + @Test + public void creatingWithValidMtlsParamsCreatesSocketFactory() throws Exception { + // Create CA cert and client cert & key + HeldCertificate ca = getCaCert(); + HeldCertificate clientCert = getClientCert(ca); + File caCertFile = createCaCertFile(ca); + File clientCertFile = tempFolder.newFile("client.crt"); + File clientKeyFile = tempFolder.newFile("client.key"); + + // Write client certificate and key to separate files + try (FileWriter writer = new FileWriter(clientCertFile)) { + writer.write(clientCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(clientKeyFile)) { + writer.write(clientCert.privateKeyPkcs8Pem()); + } + + // Create socket factory + ProxySslSocketFactoryProviderImpl factory = new ProxySslSocketFactoryProviderImpl(mBase64Decoder); + SSLSocketFactory sslSocketFactory; + try (FileInputStream caCertStream = new FileInputStream(caCertFile); + FileInputStream clientCertStream = new FileInputStream(clientCertFile); + FileInputStream clientKeyStream = new FileInputStream(clientKeyFile)) { + sslSocketFactory = factory.create(caCertStream, clientCertStream, clientKeyStream); + } + + assertNotNull(sslSocketFactory); + } + + @Test(expected = Exception.class) + public void creatingWithInvalidMtlsParamsThrows() throws Exception { + // Create valid CA cert but invalid client cert/key files + HeldCertificate ca = getCaCert(); + File caCertFile = createCaCertFile(ca); + File invalidClientCertFile = tempFolder.newFile("invalid-client.crt"); + File invalidClientKeyFile = tempFolder.newFile("invalid-client.key"); + + // Write invalid data to cert and key files + try (FileWriter writer = new FileWriter(invalidClientCertFile)) { + writer.write("invalid certificate"); + } + try (FileWriter writer = new FileWriter(invalidClientKeyFile)) { + writer.write("invalid key"); + } + + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream caCertStream = new FileInputStream(caCertFile); + FileInputStream invalidClientCertStream = new FileInputStream(invalidClientCertFile); + FileInputStream invalidClientKeyStream = new FileInputStream(invalidClientKeyFile)) { + provider.create(caCertStream, invalidClientCertStream, invalidClientKeyStream); + } + } + + private File createCaCertFile(HeldCertificate ca) throws Exception { + File caCertFile = tempFolder.newFile("mtls-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(ca.certificatePem()); + } + return caCertFile; + } + + @NonNull + private static HeldCertificate getCaCert() { + return new HeldCertificate.Builder() + .commonName("Test CA") + .certificateAuthority(0) + .build(); + } + + @NonNull + private static HeldCertificate getClientCert(HeldCertificate ca) { + return new HeldCertificate.Builder() + .commonName("Test Client") + .signedBy(ca) + .build(); + } + + @NonNull + private ProxySslSocketFactoryProviderImpl getProvider() { + return new ProxySslSocketFactoryProviderImpl(mBase64Decoder); + } +} diff --git a/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java new file mode 100644 index 000000000..203e50570 --- /dev/null +++ b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java @@ -0,0 +1,250 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.Socket; +import java.security.cert.Certificate; +import java.util.Objects; + +public class RawHttpResponseParserTest { + + private final Certificate[] mServerCertificates = new Certificate[]{}; + + @Test + public void httpResponseWithValidResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json\r\n" + + "Content-Length: 25\r\n" + + "\r\n" + + "{\"message\":\"Hello World\"}"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertEquals("Response data should match", "{\"message\":\"Hello World\"}", response.getData()); + assertTrue("Response should be successful", response.isSuccess()); + } + + @Test + public void responseWithErrorStatusReturnsErrorResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 500 Internal Server Error\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Server Error!"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 500", 500, response.getHttpStatus()); + assertEquals("Response data should match", "Server Error!", response.getData()); + assertFalse("Response should not be successful", response.isSuccess()); + } + + @Test + public void responseWithNoContentLengthReadsUntilEnd() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Connection: close\r\n" + + "\r\n" + + "This is response data\r\n" + + "with multiple lines\r\n" + + "until connection closes"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertNotNull("Response data should not be null", response.getData()); + assertTrue("Response data should contain expected content", + response.getData().contains("This is response data")); + assertTrue("Response data should contain multiple lines", + response.getData().contains("with multiple lines")); + } + + @Test + public void responseWithNoBodyReturnsEmptyData() throws Exception { + String rawHttpResponse = + "HTTP/1.1 204 No Content\r\n" + + "Content-Length: 0\r\n" + + "\r\n"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 204", 204, response.getHttpStatus()); + assertTrue("Response data should be null or empty", + response.getData() == null || response.getData().isEmpty()); + } + + @Test + public void responseWithInvalidStatusLineThrowsException() throws Exception { + String rawHttpResponse = "INVALID STATUS LINE\r\n\r\n"; + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpResponse(inputStream, mServerCertificates); + fail("Should have thrown exception for invalid status line"); + } catch (IOException e) { + assertTrue("Exception should mention invalid status", + Objects.requireNonNull(e.getMessage()).contains("Invalid HTTP status")); + } + } + + @Test + public void responseWithEmptyStreamThrowsException() throws Exception { + InputStream inputStream = new ByteArrayInputStream(new byte[0]); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpResponse(inputStream, mServerCertificates); + fail("Should have thrown exception for empty stream"); + } catch (IOException e) { + assertTrue("Exception should mention no response", + e.getMessage().contains("No HTTP response")); + } + } + + @Test + public void responseWithChunkedEncodingHandlesCorrectly() throws Exception { + String rawHttpResponse = + // headers + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + // 1st chunk size + "15\r\n" + + // 1st chunk data + "This is chunked data!" + + "\r\n" + + + // 2nd chunk size + "0\r\n" + + "\r\n"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertNotNull("Response data should not be null", response.getData()); + assertTrue("Response data should contain expected content", + response.getData().contains("This is chunked data!")); + } + + // Tests for parseHttpStreamResponse method + + @Test + public void parseHttpStreamResponseWithValidInputReturnsCorrectResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json\r\n" + + "Content-Length: 25\r\n" + + "\r\n" + + "{\"message\":\"Hello World\"}"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + Socket mockTunnelSocket = mock(Socket.class); + Socket mockOriginSocket = mock(Socket.class); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, mockTunnelSocket, mockOriginSocket); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithNullSocketsReturnsValidResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Hello, World!"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithDifferentContentTypeUsesCorrectCharset() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/html; charset=ISO-8859-1\r\n" + + "Content-Length: 20\r\n" + + "\r\n" + + "Test Page"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("ISO-8859-1")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithEmptyStreamThrowsException() throws Exception { + InputStream inputStream = new ByteArrayInputStream(new byte[0]); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for empty stream"); + } catch (IOException e) { + assertTrue("Exception should mention no response", + e.getMessage().contains("No HTTP response")); + } + } + + @Test + public void parseHttpStreamResponseWithInvalidStatusLineThrowsException() throws Exception { + String rawHttpResponse = "INVALID STATUS LINE\r\n\r\n"; + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for invalid status line"); + } catch (IOException e) { + assertTrue("Exception should mention invalid status", + Objects.requireNonNull(e.getMessage()).contains("Invalid HTTP status")); + } + } +} diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java new file mode 100644 index 000000000..f6a1157a0 --- /dev/null +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -0,0 +1,465 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.Socket; +import java.security.KeyStore; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class SslProxyTunnelEstablisherTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private TestSslProxy testProxy; + private SSLSocketFactory clientSslSocketFactory; + + @Before + public void setUp() throws Exception { + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); + + // Create test certificates + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + + // Create SSL socket factory that trusts the proxy CA + File proxyCaFile = tempFolder.newFile("proxy-ca.pem"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + ProxySslSocketFactoryProvider factory = new ProxySslSocketFactoryProviderImpl(); + try (java.io.FileInputStream caInput = new java.io.FileInputStream(proxyCaFile)) { + clientSslSocketFactory = factory.create(caInput); + } + + // Start test SSL proxy + testProxy = new TestSslProxy(0, proxyServerCert); + testProxy.start(); + + // Wait for proxy to start + while (testProxy.getPort() == 0) { + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception { + if (testProxy != null) { + testProxy.stop(); + } + } + + @Test + public void establishTunnelWithValidSslProxySucceeds() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + String targetHost = "example.com"; + int targetPort = 443; + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + targetHost, + targetPort, + clientSslSocketFactory, + proxyCredentialsProvider, + false); + + assertNotNull("Tunnel socket should not be null", tunnelSocket); + assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); + + // Verify CONNECT request was sent and successful + assertTrue("Proxy should have received CONNECT request", + testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + assertEquals("CONNECT example.com:443 HTTP/1.1", testProxy.getReceivedConnectLine()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { + SSLContext untrustedContext = SSLContext.getInstance("TLS"); + untrustedContext.init(null, null, null); + SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + try { + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + untrustedSocketFactory, + proxyCredentialsProvider, + false); + fail("Should have thrown exception for untrusted certificate"); + } catch (IOException e) { + assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); + } + } + + @Test + public void establishTunnelWithFailingProxyConnectionThrows() { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + try { + establisher.establishTunnel( + "localhost", + -1234, + "example.com", + 443, + clientSslSocketFactory, + proxyCredentialsProvider, + false); + fail("Should have thrown exception for connection failure"); + } catch (IOException e) { + // The implementation wraps the original exception with a descriptive message + assertTrue(e.getMessage().contains("Failed to establish SSL tunnel")); + } + } + + @Test + public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedException { + // For Bearer token, we don't need to mock the Base64Encoder since it's not used + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return "token"; + } + }, + false); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + assertEquals("Proxy-Authorization: Bearer token", testProxy.getReceivedAuthHeader()); + } + + @Test + public void basicAuthIsPassedWhenSet() throws IOException, InterruptedException { + // Create a mock Base64Encoder + Base64Encoder mockEncoder = mock(Base64Encoder.class); + String mockEncodedCredentials = "MOCK_ENCODED_CREDENTIALS"; + when(mockEncoder.encode("username:password")).thenReturn(mockEncodedCredentials); + + // Create SslProxyTunnelEstablisher with the mock encoder + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(mockEncoder); + + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BasicCredentialsProvider() { + @Override + public String getUsername() { + return "username"; + } + + @Override + public String getPassword() { + return "password"; + } + }, + false); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + + // The expected header should contain the mock encoded credentials + String expectedHeader = "Proxy-Authorization: Basic " + mockEncodedCredentials; + assertEquals(expectedHeader, testProxy.getReceivedAuthHeader()); + } + + @Test + public void establishTunnelWithNullCredentialsProviderDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return null; + } + }, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithEmptyBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return ""; + } + }, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullStatusLineThrowsIOException() { + testProxy.setConnectResponse(null); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, false)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithMalformedStatusLineThrowsIOException() { + testProxy.setConnectResponse("HTTP/1.1"); // Malformed, missing status code + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithProxyAuthRequiredThrowsHttpRetryException() { + testProxy.setConnectResponse("HTTP/1.1 407 Proxy Authentication Required"); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + HttpRetryException exception = assertThrows(HttpRetryException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false)); + + assertEquals(407, exception.responseCode()); + } + + /** + * Test SSL proxy that accepts SSL connections and handles CONNECT requests. + */ + private static class TestSslProxy extends Thread { + private final int mPort; + private final HeldCertificate mServerCert; + private SSLServerSocket mServerSocket; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); + private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); + private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + private final AtomicReference mReceivedAuthHeader = new AtomicReference<>(); + private final AtomicReference mConnectResponse = new AtomicReference<>("HTTP/1.1 200 Connection established"); + + public TestSslProxy(int port, HeldCertificate serverCert) { + mPort = port; + mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), + new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + sslContext.init(kmf.getKeyManagers(), null, null); + + mServerSocket = (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(mPort); + mServerSocket.setWantClientAuth(false); + mServerSocket.setNeedClientAuth(false); + + while (mRunning.get()) { + try { + Socket client = mServerSocket.accept(); + handleClient(client); + } catch (IOException e) { + if (mRunning.get()) { + System.err.println("Error accepting client: " + e.getMessage()); + } + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to start test SSL proxy", e); + } + } + + private void handleClient(Socket client) { + try { + BufferedReader reader = new BufferedReader( + new InputStreamReader(client.getInputStream())); + PrintWriter writer = new PrintWriter(client.getOutputStream(), true); + + // Read CONNECT request + String line = reader.readLine(); + if (line != null && line.startsWith("CONNECT")) { + mReceivedConnectLine.set(line); + mConnectRequestReceived.countDown(); + + while((line = reader.readLine()) != null && !line.isEmpty()) { + if (line.contains("Authorization") && (line.contains("Bearer") || line.contains("Basic"))) { + mAuthorizationHeaderReceived.countDown(); + mReceivedAuthHeader.set(line); + } + } + + // Send configured CONNECT response + String response = mConnectResponse.get(); + if (response != null) { + writer.println(response); + writer.println(); + writer.flush(); + } + + // Keep connection open for tunnel + Thread.sleep(100); + } + } catch (Exception e) { + System.err.println("Error handling client: " + e.getMessage()); + } finally { + try { + client.close(); + } catch (IOException e) { + // Ignore + } + } + } + + public int getPort() { + return mServerSocket != null ? mServerSocket.getLocalPort() : 0; + } + + public CountDownLatch getConnectRequestReceived() { + return mConnectRequestReceived; + } + + public CountDownLatch getAuthorizationHeaderReceived() { + return mAuthorizationHeaderReceived; + } + + public String getReceivedConnectLine() { + return mReceivedConnectLine.get(); + } + + public String getReceivedAuthHeader() { + return mReceivedAuthHeader.get(); + } + + public void setConnectResponse(String connectResponse) { + mConnectResponse.set(connectResponse); + } + } +} diff --git a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java index 1ab36656e..eeb53f2e1 100644 --- a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java +++ b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java @@ -1,5 +1,13 @@ package io.split.android.client.service.sseclient; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -29,14 +37,6 @@ import io.split.android.client.service.sseclient.sseclient.SseHandler; import io.split.sharedtest.fake.HttpStreamResponseMock; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class SseClientTest { @Mock @@ -66,7 +66,7 @@ public void setup() throws URISyntaxException { } @Test - public void onConnect() throws InterruptedException, HttpException { + public void onConnect() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -87,7 +87,7 @@ public void onConnect() throws InterruptedException, HttpException { } @Test - public void onConnectNotConfirmed() throws InterruptedException, HttpException { + public void onConnectNotConfirmed() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -268,7 +268,7 @@ public void onConnectionSuccess() { } @Test - public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException { + public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class); @@ -286,7 +286,7 @@ public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() thr } @Test - public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException { + public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class); diff --git a/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java b/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java new file mode 100644 index 000000000..2f215febe --- /dev/null +++ b/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java @@ -0,0 +1,177 @@ +package io.split.android.client.service.workmanager; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mockStatic; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.MockitoJUnitRunner; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.CertificatePinningConfigurationProvider; +import io.split.android.client.network.HttpClient; +import io.split.android.client.storage.cipher.SplitCipher; +import io.split.android.client.storage.cipher.SplitCipherFactory; +import io.split.android.client.storage.db.SplitRoomDatabase; +import io.split.android.client.storage.db.StorageFactory; +import io.split.android.client.storage.general.GeneralInfoStorage; +import io.split.android.client.utils.HttpProxySerializer; + +@RunWith(MockitoJUnitRunner.class) +public class HttpClientProviderTest { + + @Mock + private SplitRoomDatabase mockDatabase; + + @Mock + private GeneralInfoStorage mockGeneralInfoStorage; + + @Mock + private SplitCipher mockSplitCipher; + + @Mock + private CertificatePinningConfiguration mockCertPinningConfig; + + @Mock + private HttpProxyDto mockHttpProxyDto; + + private static final String TEST_API_KEY = "test-api-key"; + private static final String TEST_CERT_PINNING_CONFIG = "{\"pins\":[]}"; + + @Test + public void shouldBuildHttpClientWithNullCertificatePinningConfig() { + HttpClient result = buildHttpClientWithMocks(null, false, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithValidCertificatePinningConfig() { + HttpClient result = buildHttpClientWithCertPinningMocks(false); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithValidProxyConfig() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWhenProxyConfigProvidedButDtoIsNull() { + HttpClient result = buildHttpClientWithMocks(null, true, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyBasicAuth() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.username = "testuser"; + mockHttpProxyDto.password = "testpass"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyBearerToken() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.bearerToken = "test-bearer-token"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyMtlsAuth() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.clientCert = "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"; + mockHttpProxyDto.clientKey = "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWhenProxyHostIsNull() { + mockHttpProxyDto.host = null; + mockHttpProxyDto.port = 8080; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithEmptyCertificatePinningConfig() { + HttpClient result = buildHttpClientWithMocks("", false, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyCaCert() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.caCert = "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + private void setupCommonMocks(MockedStatic storageFactoryMock, + MockedStatic cipherFactoryMock, + MockedStatic serializerMock, + HttpProxyDto proxyDto) { + cipherFactoryMock.when(() -> SplitCipherFactory.create(TEST_API_KEY, true)) + .thenReturn(mockSplitCipher); + storageFactoryMock.when(() -> StorageFactory.getGeneralInfoStorage(mockDatabase, mockSplitCipher)) + .thenReturn(mockGeneralInfoStorage); + serializerMock.when(() -> HttpProxySerializer.deserialize(mockGeneralInfoStorage)) + .thenReturn(proxyDto); + } + + private HttpClient buildHttpClientWithMocks(String certPinningConfig, boolean usingProxy, HttpProxyDto proxyDto) { + try (MockedStatic storageFactoryMock = mockStatic(StorageFactory.class); + MockedStatic cipherFactoryMock = mockStatic(SplitCipherFactory.class); + MockedStatic serializerMock = mockStatic(HttpProxySerializer.class)) { + + setupCommonMocks(storageFactoryMock, cipherFactoryMock, serializerMock, proxyDto); + + return HttpClientProvider.buildHttpClient( + TEST_API_KEY, + certPinningConfig, + usingProxy, + mockDatabase + ); + } + } + + private HttpClient buildHttpClientWithCertPinningMocks(boolean usingProxy) { + try (MockedStatic storageFactoryMock = mockStatic(StorageFactory.class); + MockedStatic cipherFactoryMock = mockStatic(SplitCipherFactory.class); + MockedStatic serializerMock = mockStatic(HttpProxySerializer.class); + MockedStatic certProviderMock = mockStatic(CertificatePinningConfigurationProvider.class)) { + + setupCommonMocks(storageFactoryMock, cipherFactoryMock, serializerMock, null); + certProviderMock.when(() -> CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(HttpClientProviderTest.TEST_CERT_PINNING_CONFIG)) + .thenReturn(mockCertPinningConfig); + + HttpClient result = HttpClientProvider.buildHttpClient( + TEST_API_KEY, + HttpClientProviderTest.TEST_CERT_PINNING_CONFIG, + usingProxy, + mockDatabase + ); + + certProviderMock.verify(() -> CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(HttpClientProviderTest.TEST_CERT_PINNING_CONFIG)); + return result; + } + } +} diff --git a/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java b/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java index 9a207f2b4..53cae12a2 100644 --- a/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java +++ b/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java @@ -1,6 +1,9 @@ package io.split.android.client.storage.general; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -8,19 +11,45 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyCredentialsProvider; +import io.split.android.client.storage.cipher.SplitCipher; import io.split.android.client.storage.db.GeneralInfoDao; import io.split.android.client.storage.db.GeneralInfoEntity; +import io.split.android.client.utils.HttpProxySerializer; public class GeneralInfoStorageImplTest { private GeneralInfoDao mGeneralInfoDao; + private SplitCipher mAlwaysEncryptedSplitCipher; private GeneralInfoStorageImpl mGeneralInfoStorage; @Before public void setUp() { + mAlwaysEncryptedSplitCipher = mock(SplitCipher.class); + when(mAlwaysEncryptedSplitCipher.encrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return "encrypted_" + invocation.getArgument(0); + } + }); + when(mAlwaysEncryptedSplitCipher.decrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return "decrypted_" + invocation.getArgument(0); + } + }); + mGeneralInfoDao = mock(GeneralInfoDao.class); - mGeneralInfoStorage = new GeneralInfoStorageImpl(mGeneralInfoDao); + mGeneralInfoStorage = new GeneralInfoStorageImpl(mGeneralInfoDao, mAlwaysEncryptedSplitCipher); } @Test @@ -190,4 +219,125 @@ public void setRbsChangeNumberSetsValueOnDao() { verify(mGeneralInfoDao).update(argThat(entity -> entity.getName().equals("rbsChangeNumber") && entity.getLongValue() == 123L)); } + + @Test + public void getProxyConfigReturnsValueFromDao() { + when(mGeneralInfoDao.getByName("proxyConfig")) + .thenReturn(new GeneralInfoEntity("proxyConfig", "encrypted_proxyConfigValue")); + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertEquals("decrypted_encrypted_proxyConfigValue", proxyConfig); + verify(mGeneralInfoDao).getByName("proxyConfig"); + verify(mAlwaysEncryptedSplitCipher).decrypt("encrypted_proxyConfigValue"); + } + + @Test + public void getProxyConfigReturnsNullIfEntityIsNull() { + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(null); + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull(proxyConfig); + } + + @Test + public void setProxyConfigSetsValueOnDao() { + mGeneralInfoStorage.setProxyConfig("proxyConfigValue"); + + verify(mAlwaysEncryptedSplitCipher).encrypt("proxyConfigValue"); + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue().equals("encrypted_proxyConfigValue"))); + } + + @Test + public void testSerializeAndStoreHttpProxy() { + String testHost = "proxy.example.com"; + int testPort = 8080; + String testUsername = "testuser"; + String testPassword = "testpass"; + String testClientCert = "-----BEGIN CERTIFICATE-----\nMIICertificateContent\n-----END CERTIFICATE-----"; + String testClientKey = "-----BEGIN PRIVATE KEY-----\nMIIKeyContent\n-----END PRIVATE KEY-----"; + String testCaCert = "-----BEGIN CA CERTIFICATE-----\nMIICACertContent\n-----END CA CERTIFICATE-----"; + + InputStream clientCertStream = new ByteArrayInputStream(testClientCert.getBytes(StandardCharsets.UTF_8)); + InputStream clientKeyStream = new ByteArrayInputStream(testClientKey.getBytes(StandardCharsets.UTF_8)); + InputStream caCertStream = new ByteArrayInputStream(testCaCert.getBytes(StandardCharsets.UTF_8)); + + ProxyCredentialsProvider credentialsProvider = mock(ProxyCredentialsProvider.class); + + HttpProxy httpProxy = HttpProxy.newBuilder(testHost, testPort) + .basicAuth(testUsername, testPassword) + .mtls(clientCertStream, clientKeyStream) + .proxyCacert(caCertStream) + .credentialsProvider(credentialsProvider) + .build(); + + String jsonProxy = HttpProxySerializer.serialize(httpProxy); + mGeneralInfoStorage.setProxyConfig(jsonProxy); + + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue().startsWith("encrypted_"))); + } + + @Test + public void testGetProxyConfig() { + String jsonContent = "{\"host\":\"proxy.example.com\",\"port\":8080,\"username\":\"testuser\",\"password\":\"testpass\",\"client_cert\":\"cert-data\",\"client_key\":\"key-data\",\"ca_cert\":\"ca-data\",\"bearer_token\":\"token\"}"; + when(mAlwaysEncryptedSplitCipher.encrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return invocation.getArgument(0); + } + }); + when(mAlwaysEncryptedSplitCipher.decrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return invocation.getArgument(0); + } + }); + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(new GeneralInfoEntity("proxyConfig", jsonContent)); + + String proxyConfigJson = mGeneralInfoStorage.getProxyConfig(); + + assertNotNull("Proxy config JSON should not be null", proxyConfigJson); + + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNotNull("Deserialized DTO should not be null", dto); + assertEquals("Host should match", "proxy.example.com", dto.host); + assertEquals("Port should match", 8080, dto.port); + assertEquals("Username should match", "testuser", dto.username); + assertEquals("Password should match", "testpass", dto.password); + assertEquals("Client cert should match", "cert-data", dto.clientCert); + assertEquals("Client key should match", "key-data", dto.clientKey); + assertEquals("CA cert should match", "ca-data", dto.caCert); + assertEquals("token", dto.bearerToken); + } + + @Test + public void proxyConfigIsNullWhenStoredDataIsNull() { + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(null); + + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull("Proxy config should be null when entity is null", proxyConfig); + } + + @Test + public void proxyConfigIsNullWhenTheStoredValueIsNull() { + GeneralInfoEntity entity = new GeneralInfoEntity("proxyConfig", (String) null); + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(entity); + + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull("Proxy config should be null when entity value is null", proxyConfig); + } + + @Test + public void proxyConfigCanBeSetToNull() { + mGeneralInfoStorage.setProxyConfig(null); + + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue() == null)); + } } diff --git a/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java b/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java new file mode 100644 index 000000000..3f1052be0 --- /dev/null +++ b/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java @@ -0,0 +1,115 @@ +package io.split.android.client.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyCredentialsProvider; +import io.split.android.client.storage.general.GeneralInfoStorage; + +public class HttpProxySerializerTest { + + private HttpProxy mHttpProxy; + private GeneralInfoStorage mGeneralInfoStorage; + private final String TEST_HOST = "proxy.example.com"; + private final int TEST_PORT = 8080; + private final String TEST_USERNAME = "testuser"; + private final String TEST_PASSWORD = "testpass"; + private final String TEST_CLIENT_CERT = "-----BEGIN CERTIFICATE-----\nMIICertificateContent\n-----END CERTIFICATE-----"; + private final String TEST_CLIENT_KEY = "-----BEGIN PRIVATE KEY-----\nMIIKeyContent\n-----END PRIVATE KEY-----"; + private final String TEST_CA_CERT = "-----BEGIN CA CERTIFICATE-----\nMIICACertContent\n-----END CA CERTIFICATE-----"; + + @Before + public void setUp() { + mGeneralInfoStorage = mock(GeneralInfoStorage.class); + + // Create input streams from test strings + InputStream clientCertStream = new ByteArrayInputStream(TEST_CLIENT_CERT.getBytes(StandardCharsets.UTF_8)); + InputStream clientKeyStream = new ByteArrayInputStream(TEST_CLIENT_KEY.getBytes(StandardCharsets.UTF_8)); + InputStream caCertStream = new ByteArrayInputStream(TEST_CA_CERT.getBytes(StandardCharsets.UTF_8)); + + // Mock the credentials provider + ProxyCredentialsProvider credentialsProvider = new BasicCredentialsProvider() { + @Override + public String getUsername() { + return TEST_USERNAME; + } + + @Override + public String getPassword() { + return TEST_PASSWORD; + } + }; + + // Create the HttpProxy object + mHttpProxy = HttpProxy.newBuilder(TEST_HOST, TEST_PORT) + .basicAuth(TEST_USERNAME, TEST_PASSWORD) + .mtls(clientCertStream, clientKeyStream) + .proxyCacert(caCertStream) + .credentialsProvider(credentialsProvider) + .build(); + } + + @Test + public void serializeHttpProxyWorks() { + // Serialize the HttpProxy object + String json = HttpProxySerializer.serialize(mHttpProxy); + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(json); + + // Verify the serialization result + assertNotNull("Serialized JSON should not be null", json); + + // Deserialize back to HttpProxyDto + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + + // Verify the deserialized object + assertNotNull("Deserialized DTO should not be null", dto); + assertEquals("Host should match", TEST_HOST, dto.host); + assertEquals("Port should match", TEST_PORT, dto.port); + assertEquals("Username should match", TEST_USERNAME, dto.username); + assertEquals("Password should match", TEST_PASSWORD, dto.password); + assertEquals("Client cert should match", TEST_CLIENT_CERT, dto.clientCert); + assertEquals("Client key should match", TEST_CLIENT_KEY, dto.clientKey); + assertEquals("CA cert should match", TEST_CA_CERT, dto.caCert); + assertNull("Bearer token should be null", dto.bearerToken); + } + + @Test + public void testSerializeNullHttpProxy() { + String json = HttpProxySerializer.serialize(null); + assertNull("Serializing null should return null", json); + } + + @Test + public void deserializeNullJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(null); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing null should return null", dto); + } + + @Test + public void deserializeEmptyJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(""); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing empty string should return null", dto); + } + + @Test + public void deserializeInvalidJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn("{ invalid json }"); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing invalid JSON should return null", dto); + } +}