diff --git a/CHANGES.txt b/CHANGES.txt index b9ead323f..4052f3f16 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,7 @@ +5.4.0 (Sep 12, 2025) +- Added new configuration for Fallback Treatments, which allows setting a treatment value and optional config to be returned in place of "control", either globally or by flag. Read more in our docs. +- Added ProxyConfiguration parameter to support proxies, including Harness Forward Proxy, allowing also for more secured authentication options: MTLS, Bearer token and user/password authentication. Read more in our docs. + 5.3.2 (Aug 20, 2025) - Fixed issue with uncaught exception on flags preloading. diff --git a/build.gradle b/build.gradle index 4df504395..a0030dcd5 100644 --- a/build.gradle +++ b/build.gradle @@ -19,7 +19,7 @@ apply from: 'spec.gradle' apply from: 'jacoco.gradle' ext { - splitVersion = '5.3.2' + splitVersion = '5.4.0' jacocoVersion = '0.8.8' } @@ -284,6 +284,3 @@ tasks.withType(Test) { forkEvery = 100 maxHeapSize = "1024m" } - - - diff --git a/src/androidTest/java/fake/HttpResponseMock.java b/src/androidTest/java/fake/HttpResponseMock.java index 38fbc5ba6..ba0dd982b 100644 --- a/src/androidTest/java/fake/HttpResponseMock.java +++ b/src/androidTest/java/fake/HttpResponseMock.java @@ -1,8 +1,6 @@ package fake; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; -import java.util.concurrent.BlockingQueue; +import java.security.cert.Certificate; import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -25,4 +23,9 @@ public HttpResponseMock(int status, String data) { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/fake/HttpResponseStub.java b/src/androidTest/java/fake/HttpResponseStub.java index 085ea5114..a23c08a17 100644 --- a/src/androidTest/java/fake/HttpResponseStub.java +++ b/src/androidTest/java/fake/HttpResponseStub.java @@ -1,5 +1,7 @@ package fake; +import java.security.cert.Certificate; + import io.split.android.client.network.BaseHttpResponseImpl; import io.split.android.client.network.HttpResponse; @@ -29,4 +31,9 @@ public boolean isSuccess() { public String getData() { return data; } + + @Override + public Certificate[] getServerCertificates() { + return new Certificate[0]; + } } diff --git a/src/androidTest/java/helper/TestableSplitConfigBuilder.java b/src/androidTest/java/helper/TestableSplitConfigBuilder.java index 34449f445..8b21cfd49 100644 --- a/src/androidTest/java/helper/TestableSplitConfigBuilder.java +++ b/src/androidTest/java/helper/TestableSplitConfigBuilder.java @@ -6,9 +6,11 @@ import io.split.android.client.ServiceEndpoints; import io.split.android.client.SplitClientConfig; import io.split.android.client.SyncConfig; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.network.CertificatePinningConfiguration; import io.split.android.client.network.DevelopmentSslConfig; +import io.split.android.client.network.ProxyConfiguration; import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.impressions.ImpressionsMode; @@ -66,6 +68,8 @@ public class TestableSplitConfigBuilder { private CertificatePinningConfiguration mCertificatePinningConfiguration; private long mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; private RolloutCacheConfiguration mRolloutCacheConfiguration = RolloutCacheConfiguration.builder().build(); + private FallbackTreatmentsConfiguration mFallbackTreatments; + private ProxyConfiguration mProxyConfiguration = null; public TestableSplitConfigBuilder() { mServiceEndpoints = ServiceEndpoints.builder().build(); @@ -281,6 +285,16 @@ public TestableSplitConfigBuilder rolloutCacheConfiguration(RolloutCacheConfigur return this; } + public TestableSplitConfigBuilder fallbackTreatments(FallbackTreatmentsConfiguration fallbackTreatments) { + this.mFallbackTreatments = fallbackTreatments; + return this; + } + + public TestableSplitConfigBuilder logger(ProxyConfiguration proxyConfiguration) { + this.mProxyConfiguration = proxyConfiguration; + return this; + } + public SplitClientConfig build() { Constructor constructor = SplitClientConfig.class.getDeclaredConstructors()[0]; constructor.setAccessible(true); @@ -337,7 +351,9 @@ public SplitClientConfig build() { mObserverCacheExpirationPeriod, mCertificatePinningConfiguration, mImpressionsDedupeTimeInterval, - mRolloutCacheConfiguration); + mRolloutCacheConfiguration, + mProxyConfiguration, + mFallbackTreatments); Logger.instance().setLevel(mLogLevel); return config; diff --git a/src/androidTest/java/tests/integration/fallback/FallbackTreatmentsTest.java b/src/androidTest/java/tests/integration/fallback/FallbackTreatmentsTest.java new file mode 100644 index 000000000..53c853af6 --- /dev/null +++ b/src/androidTest/java/tests/integration/fallback/FallbackTreatmentsTest.java @@ -0,0 +1,557 @@ +package tests.integration.fallback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import android.content.Context; + +import androidx.test.platform.app.InstrumentationRegistry; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import helper.IntegrationHelper; +import io.split.android.client.ServiceEndpoints; +import io.split.android.client.SplitClient; +import io.split.android.client.SplitClientConfig; +import io.split.android.client.SplitFactory; +import io.split.android.client.SplitResult; +import io.split.android.client.api.Key; +import io.split.android.client.events.SplitEvent; +import io.split.android.client.events.SplitEventTask; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.impressions.Impression; +import io.split.android.client.impressions.ImpressionListener; +import io.split.android.client.service.impressions.ImpressionsMode; +import io.split.android.client.utils.logger.SplitLogLevel; +import io.split.android.grammar.Treatments; +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +public class FallbackTreatmentsTest { + + private Context mContext; + private MockWebServer mWebServer; + private int mCurSplitReqId; + + private ServiceEndpoints endpoints() { + final String url = mWebServer.url("/").url().toString(); + return ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + } + + // Helpers + private static ImpressionListener createImpressionCapturingListener(final List sink) { + return new ImpressionListener() { + @Override + public void log(Impression impression) { sink.add(impression); } + @Override + public void close() { } + }; + } + + private static SplitClientConfig buildDebugConfigWithListener(ServiceEndpoints endpoints, + FallbackTreatmentsConfiguration fbConfig, + ImpressionListener listener) { + return SplitClientConfig.builder() + .serviceEndpoints(endpoints) + .ready(30000) + .featuresRefreshRate(3) + .segmentsRefreshRate(3) + .trafficType("account") + .impressionsRefreshRate(1) + .impressionsMode(ImpressionsMode.DEBUG) + .fallbackTreatments(fbConfig) + .impressionListener(listener) + .build(); + } + + private static void assertPayloadHasOnlyKnownFlagNoDnf(String body) { + boolean hasKnown = body.contains("\"f\":\"real_flag\"") || body.contains("real_flag"); + boolean hasUnknownFlag = body.contains("\"f\":\"dnf_flag\""); + boolean hasDnfLabel = body.contains("\"r\":\"definition not found\""); + boolean hasFallbackDnfLabel = body.contains("fallback - definition not found"); + + assertTrue("Expected at least one impression for real_flag", hasKnown); + assertFalse("Unknown flag should not produce impressions", hasUnknownFlag); + assertFalse("Label 'definition not found' should not appear in impressions", hasDnfLabel); + assertFalse("Label 'fallback - definition not found' should not appear in impressions", hasFallbackDnfLabel); + } + + private static void assertLocalNoUnknownOrDnf(List captured) { + assertEquals("Expected exactly one impression locally (real_flag)", 1, captured.size()); + Impression imp = captured.get(0); + assertEquals("real_flag", imp.split()); + String label = imp.appliedRule(); + assertFalse("Label 'definition not found' should not appear in impressions (listener)", + "definition not found".equals(label)); + assertFalse("Label 'fallback - definition not found' should not appear in impressions (listener)", + label != null && label.contains("fallback - definition not found")); + } + + private SplitClientConfig buildConfig(FallbackTreatmentsConfiguration fbConfig) { + return buildConfig(fbConfig, false, null); + } + + private SplitClientConfig buildConfig(FallbackTreatmentsConfiguration fbConfig, boolean debugImpressions, Integer impressionsRefreshRate) { + SplitClientConfig.Builder builder = SplitClientConfig.builder() + .serviceEndpoints(endpoints()) + .ready(30000) + .featuresRefreshRate(3) + .segmentsRefreshRate(3) + .logLevel(SplitLogLevel.VERBOSE) + .trafficType("account"); + if (impressionsRefreshRate != null) { + builder.impressionsRefreshRate(impressionsRefreshRate); + } else { + builder.impressionsRefreshRate(3); + } + if (debugImpressions) { + builder.impressionsMode(ImpressionsMode.DEBUG); + } + if (fbConfig != null) { + builder.fallbackTreatments(fbConfig); + } + return builder.build(); + } + + private SplitFactory buildFactory(SplitClientConfig config) { + return IntegrationHelper.buildFactory( + IntegrationHelper.dummyApiKey(), new Key("DEFAULT_KEY"), config, mContext, null); + } + + private void awaitReady(SplitClient client) throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + client.on(SplitEvent.SDK_READY, new SplitEventTask() { + @Override + public void onPostExecution(SplitClient client) { + latch.countDown(); + } + }); + latch.await(30, TimeUnit.SECONDS); + } + + @Before + public void setup() { + mWebServer = new MockWebServer(); + mCurSplitReqId = 1; + final Dispatcher dispatcher = new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + final String path = request.getPath(); + if (path.contains("/" + IntegrationHelper.ServicePath.MEMBERSHIPS)) { + return new MockResponse().setResponseCode(200).setBody(IntegrationHelper.dummyAllSegments()); + } else if (path.contains("/splitChanges")) { + // Return empty changes to keep no real flags available + long id = mCurSplitReqId++; + return new MockResponse().setResponseCode(200) + .setBody(IntegrationHelper.emptyTargetingRulesChanges(id, id)); + } else if (path.contains("/testImpressions/bulk")) { + return new MockResponse().setResponseCode(200); + } + return new MockResponse().setResponseCode(404); + } + }; + mWebServer.setDispatcher(dispatcher); + mContext = InstrumentationRegistry.getInstrumentation().getContext(); + } + + @After + public void tearDown() throws Exception { + if (mWebServer != null) mWebServer.shutdown(); + } + + @Test + public void case1_controlTreatment_noFallbacks_returnsControlForUnknownFlags_andTwoKeys() throws Exception { + SplitClientConfig config = buildConfig(null); + + SplitFactory factory = buildFactory(config); + + SplitClient clientKey1 = factory.client(new Key("key_1")); + SplitClient clientKey2 = factory.client(new Key("key_2")); + + awaitReady(clientKey1); + + String t1_flag1 = clientKey1.getTreatment("non_existent_flag"); + String t1_flag2 = clientKey1.getTreatment("non_existent_flag_2"); + String t2_flag1 = clientKey2.getTreatment("non_existent_flag"); + String t2_flag2 = clientKey2.getTreatment("non_existent_flag_2"); + + // Assert + assertEquals(Treatments.CONTROL, t1_flag1); + assertEquals(Treatments.CONTROL, t1_flag2); + assertEquals(Treatments.CONTROL, t2_flag1); + assertEquals(Treatments.CONTROL, t2_flag2); + + factory.destroy(); + } + + @Test + public void case6_impressionsCorrectnessWithFallbackLabelsPrefixedForOverriddenFlagOnlyNotReadyForOthers() throws Exception { + final String url = mWebServer.url("/").url().toString(); + ServiceEndpoints endpoints = ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + + final StringBuilder postedImpressions = new StringBuilder(); + final CountDownLatch impressionsLatch = new CountDownLatch(1); + mCurSplitReqId = 1; + final Dispatcher dispatcher = new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + final String path = request.getPath(); + if (path.contains("/" + IntegrationHelper.ServicePath.MEMBERSHIPS)) { + return new MockResponse().setResponseCode(200).setBody(IntegrationHelper.dummyAllSegments()); + } else if (path.contains("/splitChanges")) { + long id = mCurSplitReqId++; + // Keep no real flags to ensure not-ready path applies before SDK ready + return new MockResponse().setResponseCode(200) + .setBody(IntegrationHelper.emptyTargetingRulesChanges(id, id)); + } else if (path.contains("/testImpressions/bulk")) { + try { + // Capture body for assertions + postedImpressions.append(request.getBody().readUtf8()); + } catch (Exception ignore) { } + impressionsLatch.countDown(); + return new MockResponse().setResponseCode(200); + } + return new MockResponse().setResponseCode(404); + } + }; + mWebServer.setDispatcher(dispatcher); + + Map byFlag = new HashMap<>(); + byFlag.put("any_flag", new FallbackTreatment("OFF_FALLBACK")); + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .byFlag(byFlag) + .build(); + + SplitClientConfig config = buildConfig(fbConfig, true, 1); + + SplitFactory factory = buildFactory(config); + + SplitClient c = factory.client(new Key("key_1")); + + String t_overridden = c.getTreatment("any_flag"); + String t_other = c.getTreatment("other_flag"); + Thread.sleep(1000); + c.flush(); + + impressionsLatch.await(5, TimeUnit.SECONDS); + + assertEquals("OFF_FALLBACK", t_overridden); + + String body = postedImpressions.toString(); + System.out.println("IMPRESSIONS BODY: " + body); + boolean hasPrefixed = body.contains("\"f\":\"any_flag\"") && body.contains("\"r\":\"fallback - not ready\""); + boolean hasPlain = body.contains("\"f\":\"other_flag\"") && body.contains("\"r\":\"not ready\""); + if (!hasPrefixed || !hasPlain) { + hasPrefixed = body.contains("fallback - not ready"); + hasPlain = body.contains("\"r\":\"not ready\""); + } + assertTrue("Expected impression with label 'fallback - not ready' for any_flag", hasPrefixed); + assertTrue("Expected impression with label 'not ready' for other_flag", hasPlain); + + factory.destroy(); + } + + @Test + public void case5_overrideAppliesOnlyWhenOriginalWouldBeControlRealFlagUnaffectedUnknownGetsFallback() throws Exception { + final String url = mWebServer.url("/").url().toString(); + ServiceEndpoints endpoints = ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + + final Dispatcher dispatcher = new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + final String path = request.getPath(); + if (path.contains("/" + IntegrationHelper.ServicePath.MEMBERSHIPS)) { + return new MockResponse().setResponseCode(200).setBody(IntegrationHelper.dummyAllSegments()); + } else if (path.contains("/splitChanges")) { + String change = IntegrationHelper.loadSplitChanges(mContext, "simple_split.json"); + change = change.replace("\"workm\"", "\"real_flag\""); + return new MockResponse().setResponseCode(200).setBody(change); + } else if (path.contains("/testImpressions/bulk")) { + return new MockResponse().setResponseCode(200); + } + return new MockResponse().setResponseCode(404); + } + }; + mWebServer.setDispatcher(dispatcher); + + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("OFF_FALLBACK")) + .build(); + + SplitClientConfig config = buildConfig(fbConfig); + + SplitFactory factory = buildFactory(config); + + SplitClient clientKey1 = factory.client(new Key("key_1")); + + awaitReady(clientKey1); + + String realFlag = clientKey1.getTreatment("real_flag"); + String unknown = clientKey1.getTreatment("non_existent_flag"); + + assertEquals("on", realFlag); + assertEquals("OFF_FALLBACK", unknown); + + factory.destroy(); + } + + @Test + public void case4_FlagOverrideBeatsFactoryDefaultReturnsOnFallbackForOverriddenAndOffFallbackForOthers() throws Exception { + Map byFlag = new HashMap<>(); + byFlag.put("my_flag", new FallbackTreatment("ON_FALLBACK")); + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("OFF_FALLBACK")) + .byFlag(byFlag) + .build(); + + SplitClientConfig config = buildConfig(fbConfig); + + SplitFactory factory = buildFactory(config); + + SplitClient clientKey1 = factory.client(new Key("key_1")); + SplitClient clientKey2 = factory.client(new Key("key_2")); + + awaitReady(clientKey1); + + String t1_myFlag = clientKey1.getTreatment("my_flag"); + String t1_other = clientKey1.getTreatment("non_existent_flag_2"); + String t2_myFlag = clientKey2.getTreatment("my_flag"); + String t2_other = clientKey2.getTreatment("non_existent_flag_2"); + + assertEquals("ON_FALLBACK", t1_myFlag); + assertEquals("OFF_FALLBACK", t1_other); + assertEquals("ON_FALLBACK", t2_myFlag); + assertEquals("OFF_FALLBACK", t2_other); + + factory.destroy(); + } + + @Test + public void case2_factoryWideOverrideReturnsFallbackForUnknownFlagsAndTwoKeys() throws Exception { + // endpoints provided by helper in buildConfig + + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("FALLBACK_TREATMENT")) + .build(); + + SplitClientConfig config = buildConfig(fbConfig); + + SplitFactory factory = buildFactory(config); + + SplitClient clientKey1 = factory.client(new Key("key_1")); + SplitClient clientKey2 = factory.client(new Key("key_2")); + + CountDownLatch readyLatch = new CountDownLatch(1); + clientKey1.on(SplitEvent.SDK_READY, new SplitEventTask() { + @Override + public void onPostExecution(SplitClient client) { + readyLatch.countDown(); + } + }); + readyLatch.await(5, TimeUnit.SECONDS); + + String t1_flag1 = clientKey1.getTreatment("non_existent_flag"); + String t1_flag2 = clientKey1.getTreatment("non_existent_flag_2"); + String t2_flag1 = clientKey2.getTreatment("non_existent_flag"); + String t2_flag2 = clientKey2.getTreatment("non_existent_flag_2"); + + assertEquals("FALLBACK_TREATMENT", t1_flag1); + assertEquals("FALLBACK_TREATMENT", t1_flag2); + assertEquals("FALLBACK_TREATMENT", t2_flag1); + assertEquals("FALLBACK_TREATMENT", t2_flag2); + + factory.destroy(); + } + + @Test + public void case3_factorySpecificOverrideReturnsFallbackForOneFlagAndControlForOthersAndTwoKeys() throws Exception { + final String url = mWebServer.url("/").url().toString(); + ServiceEndpoints endpoints = ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + + Map byFlag = new HashMap<>(); + byFlag.put("non_existent_flag", new FallbackTreatment("FALLBACK_TREATMENT")); + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .byFlag(byFlag) + .build(); + + SplitClientConfig config = SplitClientConfig.builder() + .serviceEndpoints(endpoints) + .ready(30000) + .featuresRefreshRate(3) + .segmentsRefreshRate(3) + .impressionsRefreshRate(3) + .logLevel(SplitLogLevel.DEBUG) + .trafficType("account") + .fallbackTreatments(fbConfig) + .build(); + + SplitFactory factory = IntegrationHelper.buildFactory( + IntegrationHelper.dummyApiKey(), new Key("DEFAULT_KEY"), config, mContext, null); + + SplitClient clientKey1 = factory.client(new Key("key_1")); + SplitClient clientKey2 = factory.client(new Key("key_2")); + + CountDownLatch readyLatch = new CountDownLatch(1); + clientKey1.on(SplitEvent.SDK_READY, new SplitEventTask() { + @Override + public void onPostExecution(SplitClient client) { + readyLatch.countDown(); + } + }); + readyLatch.await(5, TimeUnit.SECONDS); + + String t1_flag1 = clientKey1.getTreatment("non_existent_flag"); + String t1_flag2 = clientKey1.getTreatment("non_existent_flag_2"); + String t2_flag1 = clientKey2.getTreatment("non_existent_flag"); + String t2_flag2 = clientKey2.getTreatment("non_existent_flag_2"); + + assertEquals("FALLBACK_TREATMENT", t1_flag1); + assertEquals(Treatments.CONTROL, t1_flag2); + assertEquals("FALLBACK_TREATMENT", t2_flag1); + assertEquals(Treatments.CONTROL, t2_flag2); + + factory.destroy(); + } + + @Test + public void case7_fallbackDynamicConfigPropagationTreatmentAndConfigReturned() throws Exception { + final String url = mWebServer.url("/").url().toString(); + ServiceEndpoints endpoints = ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + + + Map byFlag = new HashMap<>(); + byFlag.put("my_flag", new FallbackTreatment("ON_FALLBACK", "{\"flag\":true}")); + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("OFF_FALLBACK", "{\"global\":true}")) + .byFlag(byFlag) + .build(); + + SplitClientConfig config = SplitClientConfig.builder() + .serviceEndpoints(endpoints) + .ready(30000) + .featuresRefreshRate(3) + .segmentsRefreshRate(3) + .impressionsRefreshRate(3) + .logLevel(SplitLogLevel.DEBUG) + .trafficType("account") + .fallbackTreatments(fbConfig) + .build(); + + SplitFactory factory = IntegrationHelper.buildFactory( + IntegrationHelper.dummyApiKey(), new Key("DEFAULT_KEY"), config, mContext, null); + + SplitClient client = factory.client(new Key("key_1")); + + CountDownLatch readyLatch = new CountDownLatch(1); + client.on(SplitEvent.SDK_READY, new SplitEventTask() { + @Override + public void onPostExecution(SplitClient client) { + readyLatch.countDown(); + } + }); + readyLatch.await(5, TimeUnit.SECONDS); + + SplitResult rMy = client.getTreatmentWithConfig("my_flag", null); + SplitResult rUnknown = client.getTreatmentWithConfig("non_existent_flag", null); + + assertEquals("ON_FALLBACK", rMy.treatment()); + assertEquals("{\"flag\":true}", rMy.config()); + assertEquals("OFF_FALLBACK", rUnknown.treatment()); + assertEquals("{\"global\":true}", rUnknown.config()); + + factory.destroy(); + } + + @Test + public void case8_noImpressionsForDefinitionNotFoundOrFallbackDefinitionNotFoundAfterReady() throws Exception { + final String url = mWebServer.url("/").url().toString(); + ServiceEndpoints endpoints = ServiceEndpoints.builder() + .apiEndpoint(url) + .eventsEndpoint(url) + .build(); + + final StringBuilder postedImpressions = new StringBuilder(); + final CountDownLatch impressionsLatch = new CountDownLatch(1); + mCurSplitReqId = 1; + final Dispatcher dispatcher = new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + final String path = request.getPath(); + if (path.contains("/" + IntegrationHelper.ServicePath.MEMBERSHIPS)) { + return new MockResponse().setResponseCode(200).setBody(IntegrationHelper.dummyAllSegments()); + } else if (path.contains("/splitChanges")) { + // Serve a real flag so we do generate impressions in DEBUG mode + String change = IntegrationHelper.loadSplitChanges(mContext, "simple_split.json"); + change = change.replace("\"workm\"", "\"real_flag\""); + return new MockResponse().setResponseCode(200).setBody(change); + } else if (path.contains("/testImpressions/bulk")) { + try { + postedImpressions.append(request.getBody().readUtf8()); + } catch (Exception ignore) { } + impressionsLatch.countDown(); + return new MockResponse().setResponseCode(200); + } + return new MockResponse().setResponseCode(404); + } + }; + mWebServer.setDispatcher(dispatcher); + + // Configure global fallback so unknown flags return a fallback treatment + FallbackTreatmentsConfiguration fbConfig = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("OFF_FALLBACK")) + .build(); + + final List capturedImpressions = Collections.synchronizedList(new ArrayList<>()); + ImpressionListener listener = createImpressionCapturingListener(capturedImpressions); + + SplitClientConfig config = buildDebugConfigWithListener(endpoints, fbConfig, listener); + SplitFactory factory = buildFactory(config); + + SplitClient client = factory.client(new Key("key_1")); + awaitReady(client); + + // Evaluate a real flag (will log impression) and an unknown flag (should not log impression) + String tUnknown = client.getTreatment("dnf_flag"); + String tKnown = client.getTreatment("real_flag"); + + // Push impressions + Thread.sleep(1000); + client.flush(); + impressionsLatch.await(5, TimeUnit.SECONDS); + + String body = postedImpressions.toString(); + assertPayloadHasOnlyKnownFlagNoDnf(body); + assertLocalNoUnknownOrDnf(capturedImpressions); + + factory.destroy(); + } +} diff --git a/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java b/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java index 91361df4c..f214b523d 100644 --- a/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java +++ b/src/androidTest/java/tests/storage/GeneralInfoStorageTest.java @@ -21,7 +21,7 @@ public class GeneralInfoStorageTest { @Before public void setUp() { mDb = DatabaseHelper.getTestDatabase(InstrumentationRegistry.getInstrumentation().getContext()); - mGeneralInfoStorage = new GeneralInfoStorageImpl(mDb.generalInfoDao()); + mGeneralInfoStorage = new GeneralInfoStorageImpl(mDb.generalInfoDao(), null); } @After diff --git a/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java b/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java index a98e6a0e2..8f3539423 100644 --- a/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java +++ b/src/androidTest/java/tests/workmanager/WorkManagerWrapperTest.java @@ -295,6 +295,7 @@ private Data buildInputData(Data customData) { dataBuilder.putString("databaseName", "test_database_name"); dataBuilder.putString("apiKey", "api_key"); dataBuilder.putBoolean("encryptionEnabled", false); + dataBuilder.putBoolean("usesProxy", false); if (customData != null) { dataBuilder.putAll(customData); } diff --git a/src/main/java/io/split/android/client/EvaluatorImpl.java b/src/main/java/io/split/android/client/EvaluatorImpl.java index 28166977b..7eb3db19d 100644 --- a/src/main/java/io/split/android/client/EvaluatorImpl.java +++ b/src/main/java/io/split/android/client/EvaluatorImpl.java @@ -4,6 +4,10 @@ import io.split.android.client.dtos.ConditionType; import io.split.android.client.exceptions.ChangeNumberExceptionWrapper; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.storage.splits.SplitsStorage; import io.split.android.client.utils.logger.Logger; import io.split.android.engine.experiments.ParsedCondition; @@ -11,16 +15,21 @@ import io.split.android.engine.experiments.SplitParser; import io.split.android.engine.matchers.PrerequisitesMatcher; import io.split.android.engine.splitter.Splitter; -import io.split.android.grammar.Treatments; public class EvaluatorImpl implements Evaluator { private final SplitsStorage mSplitsStorage; private final SplitParser mSplitParser; + private final FallbackTreatmentsCalculator mFallbackCalculator; public EvaluatorImpl(SplitsStorage splitsStorage, SplitParser splitParser) { + this(splitsStorage, splitParser, new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); + } + + public EvaluatorImpl(SplitsStorage splitsStorage, SplitParser splitParser, FallbackTreatmentsCalculator fallbackCalculator) { mSplitsStorage = splitsStorage; mSplitParser = splitParser; + mFallbackCalculator = fallbackCalculator; } @Override @@ -29,16 +38,19 @@ public EvaluationResult getTreatment(String matchingKey, String bucketingKey, St try { ParsedSplit parsedSplit = mSplitParser.parse(mSplitsStorage.get(splitName), matchingKey); if (parsedSplit == null) { - return new EvaluationResult(Treatments.CONTROL, TreatmentLabels.DEFINITION_NOT_FOUND, true); + FallbackTreatment fallback = mFallbackCalculator.resolve(splitName, TreatmentLabels.DEFINITION_NOT_FOUND); + return new EvaluationResult(fallback.getTreatment(), fallback.getLabel(), null, fallback.getConfig(), true); } return getTreatment(matchingKey, bucketingKey, parsedSplit, attributes); } catch (ChangeNumberExceptionWrapper ex) { Logger.e(ex, "Catch Change Number Exception"); - return new EvaluationResult(Treatments.CONTROL, TreatmentLabels.EXCEPTION, ex.changeNumber(), true); + FallbackTreatment fallback = mFallbackCalculator.resolve(splitName, TreatmentLabels.EXCEPTION); + return new EvaluationResult(fallback.getTreatment(), fallback.getLabel(), ex.changeNumber(), fallback.getConfig(), true); } catch (Exception e) { Logger.e(e, "Catch All Exception"); - return new EvaluationResult(Treatments.CONTROL, TreatmentLabels.EXCEPTION, true); + FallbackTreatment fallback = mFallbackCalculator.resolve(splitName, TreatmentLabels.EXCEPTION); + return new EvaluationResult(fallback.getTreatment(), fallback.getLabel(), null, fallback.getConfig(), true); } } @@ -95,7 +107,7 @@ private EvaluationResult getTreatment(String matchingKey, String bucketingKey, P } if (parsedCondition.matcher().match(matchingKey, bucketingKey, attributes, this)) { - String treatment = Splitter.getTreatment(bk, parsedSplit.seed(), parsedCondition.partitions(), parsedSplit.algo()); + String treatment = Splitter.getTreatment(bk, parsedSplit.seed(), parsedCondition.partitions(), parsedSplit.algo(), mFallbackCalculator); return new EvaluationResult(treatment, parsedCondition.label(), parsedSplit.changeNumber(), configForTreatment(parsedSplit, treatment), parsedSplit.impressionsDisabled()); } } diff --git a/src/main/java/io/split/android/client/SplitClientConfig.java b/src/main/java/io/split/android/client/SplitClientConfig.java index 00af53115..3660ff761 100644 --- a/src/main/java/io/split/android/client/SplitClientConfig.java +++ b/src/main/java/io/split/android/client/SplitClientConfig.java @@ -4,6 +4,7 @@ import static io.split.android.client.utils.Utils.checkNotNull; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; import java.net.URI; import java.util.concurrent.TimeUnit; @@ -17,6 +18,7 @@ import io.split.android.client.network.CertificatePinningConfiguration; import io.split.android.client.network.DevelopmentSslConfig; import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyConfiguration; import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.impressions.ImpressionsMode; @@ -27,6 +29,7 @@ import io.split.android.client.utils.logger.SplitLogLevel; import io.split.android.client.validators.PrefixValidatorImpl; import io.split.android.client.validators.ValidationErrorInfo; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; /** * Configurations for the SplitClient. @@ -132,6 +135,10 @@ public class SplitClientConfig { private final long mImpressionsDedupeTimeInterval; @NonNull private final RolloutCacheConfiguration mRolloutCacheConfiguration; + @Nullable + private final ProxyConfiguration mProxyConfiguration; + @Nullable + private final FallbackTreatmentsConfiguration mFallbackTreatments; public static Builder builder() { return new Builder(); @@ -187,7 +194,9 @@ private SplitClientConfig(String endpoint, long observerCacheExpirationPeriod, CertificatePinningConfiguration certificatePinningConfiguration, long impressionsDedupeTimeInterval, - RolloutCacheConfiguration rolloutCacheConfiguration) { + @NonNull RolloutCacheConfiguration rolloutCacheConfiguration, + @Nullable ProxyConfiguration proxyConfiguration, + @Nullable FallbackTreatmentsConfiguration fallbackTreatments) { mEndpoint = endpoint; mEventsEndpoint = eventsEndpoint; mTelemetryEndpoint = telemetryEndpoint; @@ -246,6 +255,8 @@ private SplitClientConfig(String endpoint, mCertificatePinningConfiguration = certificatePinningConfiguration; mImpressionsDedupeTimeInterval = impressionsDedupeTimeInterval; mRolloutCacheConfiguration = rolloutCacheConfiguration; + mProxyConfiguration = proxyConfiguration; + mFallbackTreatments = fallbackTreatments; } public String trafficType() { @@ -436,7 +447,9 @@ public boolean persistentAttributesEnabled() { return mIsPersistentAttributesEnabled; } - public int offlineRefreshRate() { return mOfflineRefreshRate; } + public int offlineRefreshRate() { + return mOfflineRefreshRate; + } public boolean shouldRecordTelemetry() { return mShouldRecordTelemetry; @@ -446,7 +459,9 @@ public long telemetryRefreshRate() { return mTelemetryRefreshRate; } - public boolean syncEnabled() { return mSyncEnabled; } + public boolean syncEnabled() { + return mSyncEnabled; + } public int mtkPerPush() { return mMtkPerPush; @@ -476,7 +491,9 @@ public int sseDisconnectionDelay() { return mSSEDisconnectionDelayInSecs; } - private void enableTelemetry() { mShouldRecordTelemetry = true; } + private void enableTelemetry() { + mShouldRecordTelemetry = true; + } public long observerCacheExpirationPeriod() { return Math.max(mImpressionsDedupeTimeInterval, mObserverCacheExpirationPeriod); @@ -494,6 +511,11 @@ public RolloutCacheConfiguration rolloutCacheConfiguration() { return mRolloutCacheConfiguration; } + @Nullable + public FallbackTreatmentsConfiguration fallbackTreatments() { + return mFallbackTreatments; + } + public static final class Builder { static final int PROXY_PORT_DEFAULT = 80; @@ -571,11 +593,20 @@ public static final class Builder { private long mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; private RolloutCacheConfiguration mRolloutCacheConfiguration = RolloutCacheConfiguration.builder().build(); + @Nullable + private FallbackTreatmentsConfiguration mFallbackTreatments = null; + + private ProxyConfiguration mProxyConfiguration = null; public Builder() { mServiceEndpoints = ServiceEndpoints.builder().build(); } + public Builder fallbackTreatments(@Nullable FallbackTreatmentsConfiguration fallbackTreatments) { + mFallbackTreatments = fallbackTreatments; + return this; + } + /** * Default Traffic Type to use in .track method * @@ -806,7 +837,9 @@ public Builder ready(int milliseconds) { * * @param proxyHost proxy URI * @return this builder + * @deprecated use {@link #proxyConfiguration(ProxyConfiguration)} */ + @Deprecated public Builder proxyHost(String proxyHost) { if (proxyHost != null && proxyHost.endsWith("/")) { mProxyHost = proxyHost.substring(0, proxyHost.length() - 1); @@ -823,6 +856,7 @@ public Builder proxyHost(String proxyHost) { * @param proxyAuthenticator * @return this builder */ + @Deprecated public Builder proxyAuthenticator(SplitAuthenticator proxyAuthenticator) { mProxyAuthenticator = proxyAuthenticator; return this; @@ -1030,6 +1064,7 @@ public Builder offlineRefreshRate(int offlineRefreshRate) { *

* This is an ADVANCED parameter *

+ * * @param telemetryRefreshRate Rate in seconds for telemetry refresh. * @return This builder * @default 3600 seconds @@ -1101,10 +1136,9 @@ public Builder certificatePinningConfiguration(CertificatePinningConfiguration c /** * This configuration is used to control the size of the impressions deduplication window. * + * @param impressionsDedupeTimeInterval The time interval in milliseconds. * @Experimental This method is experimental and may change or be removed in future versions. * To be used upon Split team recommendation. - * - * @param impressionsDedupeTimeInterval The time interval in milliseconds. */ @Deprecated public Builder impressionsDedupeTimeInterval(long impressionsDedupeTimeInterval) { @@ -1128,6 +1162,17 @@ public Builder rolloutCacheConfiguration(@NonNull RolloutCacheConfiguration roll return this; } + /** + * Sets the proxy configuration + * + * @param proxyConfiguration + * @return this builder + */ + public Builder proxyConfiguration(ProxyConfiguration proxyConfiguration) { + mProxyConfiguration = proxyConfiguration; + return this; + } + public SplitClientConfig build() { Logger.instance().setLevel(mLogLevel); @@ -1207,7 +1252,7 @@ public SplitClientConfig build() { mImpressionsDedupeTimeInterval = ServiceConstants.DEFAULT_IMPRESSIONS_DEDUPE_TIME_INTERVAL; } - HttpProxy proxy = parseProxyHost(mProxyHost); + HttpProxy proxy = parseProxyHost(mProxyHost, mProxyConfiguration); return new SplitClientConfig( mServiceEndpoints.getSdkEndpoint(), @@ -1260,10 +1305,34 @@ public SplitClientConfig build() { mObserverCacheExpirationPeriod, mCertificatePinningConfiguration, mImpressionsDedupeTimeInterval, - mRolloutCacheConfiguration); + mRolloutCacheConfiguration, + mProxyConfiguration, + mFallbackTreatments); } - private HttpProxy parseProxyHost(String proxyUri) { + private HttpProxy parseProxyHost(String proxyUri, ProxyConfiguration proxyConfiguration) { + // Use legacy proxy behavior if proxyConfiguration is null + if (proxyConfiguration == null) { + return legacyProxyBehavior(proxyUri); + } + + if (mProxyHost != null || mProxyAuthenticator != null) { + Logger.w("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence."); + } + + // Initialize internal config with null url. This will be verified when building the factory. + HttpProxy.Builder builder = HttpProxy.newBuilder(null, -1); + if (proxyConfiguration.getUrl() != null) { + builder = HttpProxy.newBuilder(proxyConfiguration.getUrl().getHost(), proxyConfiguration.getUrl().getPort()) + .mtls(proxyConfiguration.getClientCert(), proxyConfiguration.getClientPk()) + .proxyCacert(proxyConfiguration.getCaCert()) + .credentialsProvider(proxyConfiguration.getCredentialsProvider()); + } + return builder.build(); + } + + @Nullable + private HttpProxy legacyProxyBehavior(String proxyUri) { if (!Utils.isNullOrEmpty(proxyUri)) { try { String username = null; @@ -1271,15 +1340,19 @@ private HttpProxy parseProxyHost(String proxyUri) { URI uri = URI.create(proxyUri); int port = uri.getPort() != -1 ? uri.getPort() : PROXY_PORT_DEFAULT; String userInfo = uri.getUserInfo(); - if(!Utils.isNullOrEmpty(userInfo)) { + if (!Utils.isNullOrEmpty(userInfo)) { String[] userInfoComponents = userInfo.split(":"); - if(userInfoComponents.length > 1) { + if (userInfoComponents.length > 1) { username = userInfoComponents[0]; password = userInfoComponents[1]; } } String host = String.format("%s%s", uri.getHost(), uri.getPath()); - return new HttpProxy(host, port, username, password); + if (username != null && password != null) { + return HttpProxy.newBuilder(host, port).basicAuth(username, password).buildLegacy(); + } else { + return HttpProxy.newBuilder(host, port).buildLegacy(); + } } catch (IllegalArgumentException e) { Logger.e("Proxy URI not valid: " + e.getLocalizedMessage()); throw new IllegalArgumentException(); diff --git a/src/main/java/io/split/android/client/SplitClientFactoryImpl.java b/src/main/java/io/split/android/client/SplitClientFactoryImpl.java index d2ab82416..79ab93f9d 100644 --- a/src/main/java/io/split/android/client/SplitClientFactoryImpl.java +++ b/src/main/java/io/split/android/client/SplitClientFactoryImpl.java @@ -84,7 +84,8 @@ public SplitClientFactoryImpl(@NonNull SplitFactory splitFactory, mStorageContainer.getTelemetryStorage(), mSplitParser, flagSetsFilter, - splitsStorage + splitsStorage, + config.fallbackTreatments() ); } diff --git a/src/main/java/io/split/android/client/SplitFactoryBuilder.java b/src/main/java/io/split/android/client/SplitFactoryBuilder.java index ca31ee935..ab2bc3108 100644 --- a/src/main/java/io/split/android/client/SplitFactoryBuilder.java +++ b/src/main/java/io/split/android/client/SplitFactoryBuilder.java @@ -67,6 +67,9 @@ public static synchronized SplitFactory build(@NonNull String sdkKey, @NonNull K return new SplitFactoryImpl(sdkKey, key, config, context); } } catch (Exception ex) { + if (ex instanceof SplitInstantiationException) { + throw (SplitInstantiationException) ex; + } throw new SplitInstantiationException("Could not instantiate SplitFactory", ex); } } @@ -97,5 +100,9 @@ private static void checkPreconditions(@NonNull String sdkKey, @NonNull Key key, if (context == null) { throw new SplitInstantiationException("Could not instantiate SplitFactory. Context cannot be null"); } + + if (config.proxy() != null && config.proxy().getHost() == null) { + throw new SplitInstantiationException("Could not instantiate SplitFactory. When configured, proxy host cannot be null"); + } } } diff --git a/src/main/java/io/split/android/client/SplitFactoryHelper.java b/src/main/java/io/split/android/client/SplitFactoryHelper.java index 2c1c33d95..fa1831a09 100644 --- a/src/main/java/io/split/android/client/SplitFactoryHelper.java +++ b/src/main/java/io/split/android/client/SplitFactoryHelper.java @@ -49,7 +49,6 @@ import io.split.android.client.service.sseclient.notifications.MySegmentsV2PayloadDecoder; import io.split.android.client.service.sseclient.notifications.NotificationParser; import io.split.android.client.service.sseclient.notifications.NotificationProcessor; -import io.split.android.client.service.sseclient.notifications.SplitsChangeNotification; import io.split.android.client.service.sseclient.notifications.mysegments.MembershipsNotificationProcessorFactory; import io.split.android.client.service.sseclient.notifications.mysegments.MembershipsNotificationProcessorFactoryImpl; import io.split.android.client.service.sseclient.reactor.MySegmentsUpdateWorkerRegistry; @@ -94,6 +93,7 @@ import io.split.android.client.telemetry.TelemetrySynchronizerStub; import io.split.android.client.telemetry.storage.TelemetryRuntimeProducer; import io.split.android.client.telemetry.storage.TelemetryStorage; +import io.split.android.client.utils.HttpProxySerializer; import io.split.android.client.utils.Utils; import io.split.android.client.utils.logger.Logger; @@ -165,14 +165,15 @@ SplitStorageContainer buildStorageContainer(UserConsent userConsentStatus, TelemetryStorage telemetryStorage, long observerCacheExpirationPeriod, ScheduledThreadPoolExecutor impressionsObserverExecutor, - SplitsStorage splitsStorage) { + SplitsStorage splitsStorage, + SplitCipher alwaysEncryptedSplitCipher) { boolean isPersistenceEnabled = userConsentStatus == UserConsent.GRANTED; PersistentEventsStorage persistentEventsStorage = StorageFactory.getPersistentEventsStorage(splitRoomDatabase, splitCipher); PersistentImpressionsStorage persistentImpressionsStorage = StorageFactory.getPersistentImpressionsStorage(splitRoomDatabase, splitCipher); - GeneralInfoStorage generalInfoStorage = StorageFactory.getGeneralInfoStorage(splitRoomDatabase); + GeneralInfoStorage generalInfoStorage = StorageFactory.getGeneralInfoStorage(splitRoomDatabase, alwaysEncryptedSplitCipher); return new SplitStorageContainer( splitsStorage, StorageFactory.getMySegmentsStorage(splitRoomDatabase, splitCipher), @@ -228,6 +229,29 @@ WorkManagerWrapper buildWorkManagerWrapper(Context context, SplitClientConfig sp } + static void setupProxyForBackgroundSync(@NonNull SplitClientConfig config, Runnable proxyConfigSaveTask) { + if (config.proxy() != null && !config.proxy().isLegacy() && config.synchronizeInBackground()) { + // Store proxy config for background sync usage + new Thread(proxyConfigSaveTask).start(); + } + } + + // Visible to inject for testing + @NonNull + static Runnable getProxyConfigSaveTask(@NonNull SplitClientConfig config, WorkManagerWrapper workManagerWrapper, GeneralInfoStorage generalInfoStorage) { + return new Runnable() { + @Override + public void run() { + try { + generalInfoStorage.setProxyConfig(HttpProxySerializer.serialize(config.proxy())); + } catch (Exception ex) { + Logger.w("Failed to store proxy config for background sync. Disabling background sync", ex); + workManagerWrapper.removeWork(); + } + } + }; + } + SyncManager buildSyncManager(SplitClientConfig config, SplitTaskExecutor splitTaskExecutor, Synchronizer synchronizer, diff --git a/src/main/java/io/split/android/client/SplitFactoryImpl.java b/src/main/java/io/split/android/client/SplitFactoryImpl.java index 4aea0bb13..2ac7d457d 100644 --- a/src/main/java/io/split/android/client/SplitFactoryImpl.java +++ b/src/main/java/io/split/android/client/SplitFactoryImpl.java @@ -155,13 +155,17 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp mConfig = config; SplitCipher splitCipher = factoryHelper.getCipher(apiToken, config.encryptionEnabled()); - SplitsStorage splitsStorage = getSplitsStorage(splitDatabase, splitCipher); + // At the moment this cipher is only used for proxy config + SplitCipher alwaysEncryptedSplitCipher = (config.synchronizeInBackground() && config.proxy() != null && !config.proxy().isLegacy()) ? + factoryHelper.getCipher(apiToken, true) : null; + + SplitsStorage splitsStorage = StorageFactory.getSplitsStorage(splitDatabase, splitCipher); ScheduledThreadPoolExecutor impressionsObserverExecutor = new ScheduledThreadPoolExecutor(1, new ThreadPoolExecutor.CallerRunsPolicy()); mStorageContainer = factoryHelper.buildStorageContainer(config.userConsent(), - splitDatabase, config.shouldRecordTelemetry(), splitCipher, telemetryStorage, config.observerCacheExpirationPeriod(), impressionsObserverExecutor, splitsStorage); + splitDatabase, config.shouldRecordTelemetry(), splitCipher, telemetryStorage, config.observerCacheExpirationPeriod(), impressionsObserverExecutor, splitsStorage, alwaysEncryptedSplitCipher); mSplitTaskExecutor = new SplitTaskExecutorImpl(); mSplitTaskExecutor.pause(); @@ -173,20 +177,27 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp String splitsFilterQueryStringFromConfig = filtersConfig.second; String flagsSpec = getFlagsSpec(testingConfig); + FlagSetsFilter flagSetsFilter = factoryHelper.getFlagSetsFilter(filters); + WorkManagerWrapper workManagerWrapper = factoryHelper.buildWorkManagerWrapper(context, config, apiToken, databaseName, filters); + HttpClient defaultHttpClient; if (httpClient == null) { HttpClientImpl.Builder builder = new HttpClientImpl.Builder() .setConnectionTimeout(config.connectionTimeout()) .setReadTimeout(config.readTimeout()) - .setProxy(config.proxy()) .setDevelopmentSslConfig(config.developmentSslConfig()) .setContext(context) .setProxyAuthenticator(config.authenticator()); + if (config.proxy() != null) { + builder.setProxy(config.proxy()); + } if (config.certificatePinningConfiguration() != null) { builder.setCertificatePinningConfiguration(config.certificatePinningConfiguration()); } defaultHttpClient = builder.build(); + + SplitFactoryHelper.setupProxyForBackgroundSync(config, SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, mStorageContainer.getGeneralInfoStorage())); } else { defaultHttpClient = httpClient; } @@ -195,26 +206,23 @@ private SplitFactoryImpl(@NonNull String apiToken, @NonNull Key key, @NonNull Sp SplitApiFacade splitApiFacade = factoryHelper.buildApiFacade( config, defaultHttpClient, splitsFilterQueryStringFromConfig); - FlagSetsFilter flagSetsFilter = factoryHelper.getFlagSetsFilter(filters); - SplitTaskFactory splitTaskFactory = new SplitTaskFactoryImpl( config, splitApiFacade, mStorageContainer, splitsFilterQueryStringFromConfig, getFlagsSpec(testingConfig), mEventsManagerCoordinator, filters, flagSetsFilter, testingConfig); - WorkManagerWrapper workManagerWrapper = factoryHelper.buildWorkManagerWrapper(context, config, apiToken, databaseName, filters); - + SplitSingleThreadTaskExecutor splitSingleThreadTaskExecutor = new SplitSingleThreadTaskExecutor(); splitSingleThreadTaskExecutor.pause(); ImpressionStrategyProvider impressionStrategyProvider = factoryHelper.getImpressionStrategyProvider(mSplitTaskExecutor, splitTaskFactory, mStorageContainer, config); Pair noneComponents = impressionStrategyProvider.getNoneComponents(); - + mImpressionManager = new StrategyImpressionManager(noneComponents, impressionStrategyProvider.getStrategy(config.impressionsMode())); final RetryBackoffCounterTimerFactory retryBackoffCounterTimerFactory = new RetryBackoffCounterTimerFactory(); StreamingComponents streamingComponents = factoryHelper.buildStreamingComponents(mSplitTaskExecutor, splitTaskFactory, config, defaultHttpClient, splitApiFacade, mStorageContainer, flagsSpec); - + Synchronizer mSynchronizer = new SynchronizerImpl( config, mSplitTaskExecutor, @@ -379,11 +387,6 @@ public void run() { new SplitValidatorImpl(), splitParser); } - @NonNull - private static SplitsStorage getSplitsStorage(SplitRoomDatabase splitDatabase, SplitCipher splitCipher) { - return StorageFactory.getSplitsStorage(splitDatabase, splitCipher); - } - private static String getFlagsSpec(@Nullable TestingConfig testingConfig) { if (testingConfig == null) { return BuildConfig.FLAGS_SPEC; diff --git a/src/main/java/io/split/android/client/dtos/HttpProxyDto.java b/src/main/java/io/split/android/client/dtos/HttpProxyDto.java new file mode 100644 index 000000000..323b32191 --- /dev/null +++ b/src/main/java/io/split/android/client/dtos/HttpProxyDto.java @@ -0,0 +1,117 @@ +package io.split.android.client.dtos; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import com.google.gson.annotations.SerializedName; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.BearerCredentialsProvider; + +/** + * DTO for HttpProxy serialization to JSON for storage in GeneralInfoStorage. + */ +public class HttpProxyDto { + + @SerializedName("host") + public String host; + + @SerializedName("port") + public int port; + + @SerializedName("username") + public String username; + + @SerializedName("password") + public String password; + + @SerializedName("client_cert") + public String clientCert; + + @SerializedName("client_key") + public String clientKey; + + @SerializedName("ca_cert") + public String caCert; + + @SerializedName("bearer_token") + public String bearerToken; + + public HttpProxyDto() { + // Default constructor for deserialization + } + + /** + * Constructor that creates a DTO from an HttpProxy instance. + * Note that we don't store the actual stream data, only whether they exist. + * + * @param httpProxy The HttpProxy instance to convert + */ + public HttpProxyDto(@NonNull io.split.android.client.network.HttpProxy httpProxy) { + this.host = httpProxy.getHost(); + this.port = httpProxy.getPort(); + if (httpProxy.getCredentialsProvider() instanceof BasicCredentialsProvider) { + BasicCredentialsProvider provider = (BasicCredentialsProvider) httpProxy.getCredentialsProvider(); + this.username = provider.getUsername(); + this.password = provider.getPassword(); + } else if (httpProxy.getCredentialsProvider() instanceof BearerCredentialsProvider) { + BearerCredentialsProvider provider = (BearerCredentialsProvider) httpProxy.getCredentialsProvider(); + this.bearerToken = provider.getToken(); + } + + this.clientCert = streamToString(httpProxy.getClientCertStream()); + this.clientKey = streamToString(httpProxy.getClientKeyStream()); + this.caCert = streamToString(httpProxy.getCaCertStream()); + } + + /** + * Converts an InputStream to a String. + * + * @param inputStream The InputStream to convert + * @return String representation of the InputStream contents, or null if the stream is null + */ + @Nullable + private String streamToString(@Nullable InputStream inputStream) { + if (inputStream == null) { + return null; + } + + try { + StringBuilder content = getStringBuilder(inputStream); + + // Reset the stream if possible to allow reuse + try { + inputStream.reset(); + } catch (IOException ignored) { + + } + return content.toString(); + } catch (Exception e) { + return null; + } + } + + @NonNull + private static StringBuilder getStringBuilder(@NonNull InputStream inputStream) throws IOException { + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)); + StringBuilder content = new StringBuilder(); + String line; + boolean firstLine = true; + + while ((line = reader.readLine()) != null) { + if (!firstLine) { + content.append("\n"); + } else { + firstLine = false; + } + content.append(line); + } + return content; + } +} diff --git a/src/main/java/io/split/android/client/fallback/FallbackTreatment.java b/src/main/java/io/split/android/client/fallback/FallbackTreatment.java new file mode 100644 index 000000000..b596bd37b --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbackTreatment.java @@ -0,0 +1,66 @@ +package io.split.android.client.fallback; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.util.Objects; + +/** + * Represents the fallback treatment, with an optional config and a fixed label. + */ +public final class FallbackTreatment { + + @NonNull + private final String mTreatment; + @Nullable + private final String mConfig; + @Nullable + private final String mLabel; + + public FallbackTreatment(@NonNull String treatment) { + this(treatment, null); + } + + public FallbackTreatment(@NonNull String treatment, @Nullable String config) { + this(treatment, config, null); + } + + FallbackTreatment(@NonNull String treatment, @Nullable String config, @Nullable String label) { + mTreatment = treatment; + mConfig = config; + mLabel = label; + } + + public String getTreatment() { + return mTreatment; + } + + @Nullable + public String getConfig() { + return mConfig; + } + + @Nullable + public String getLabel() { + return mLabel; + } + + FallbackTreatment copyWithLabel(String label) { + return new FallbackTreatment(mTreatment, mConfig, label); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FallbackTreatment that = (FallbackTreatment) o; + return Objects.equals(mTreatment, that.mTreatment) && + Objects.equals(mConfig, that.mConfig) && + Objects.equals(mLabel, that.mLabel); + } + + @Override + public int hashCode() { + return Objects.hash(mTreatment, mConfig, mLabel); + } +} diff --git a/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculator.java b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculator.java new file mode 100644 index 000000000..5f9d097c5 --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculator.java @@ -0,0 +1,29 @@ +package io.split.android.client.fallback; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +/** + * Resolves a fallback treatment for a given flag name. + * Returns null if no fallback applies (caller should use control). + */ +public interface FallbackTreatmentsCalculator { + + /** + * Resolve a fallback for a given flag name. + * @param flagName non-null flag name + * @return a fallback treatment with a null label, if configured; otherwise "control" + */ + @NonNull + FallbackTreatment resolve(@NonNull String flagName); + + /** + * Resolve a fallback for a given flag name and label. + * + * @param flagName non-null flag name + * @param label nullable label + * @return a fallback treatment if configured, with a prefixed label if provided; otherwise "control" + */ + @NonNull + FallbackTreatment resolve(@NonNull String flagName, @Nullable String label); +} diff --git a/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorImpl.java b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorImpl.java new file mode 100644 index 000000000..0eb727a1e --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorImpl.java @@ -0,0 +1,54 @@ +package io.split.android.client.fallback; + +import static io.split.android.client.utils.Utils.checkNotNull; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.util.Map; + +import io.split.android.grammar.Treatments; + +public final class FallbackTreatmentsCalculatorImpl implements FallbackTreatmentsCalculator { + + private static final String LABEL_PREFIX = "fallback - "; + + @NonNull + private final FallbackTreatmentsConfiguration mConfig; + + public FallbackTreatmentsCalculatorImpl(@NonNull FallbackTreatmentsConfiguration config) { + mConfig = checkNotNull(config); + } + + @NonNull + @Override + public FallbackTreatment resolve(@NonNull String flagName) { + return resolve(flagName, null); + } + + @NonNull + @Override + public FallbackTreatment resolve(@NonNull String flagName, @Nullable String label) { + Map byFlag = mConfig.getByFlag(); + if (byFlag != null) { + FallbackTreatment flagTreatment = byFlag.get(flagName); + if (flagTreatment != null) { + return flagTreatment.copyWithLabel(resolveLabel(label)); + } + } + FallbackTreatment global = mConfig.getGlobal(); + if (global != null) { + return global.copyWithLabel(resolveLabel(label)); + } + return new FallbackTreatment(Treatments.CONTROL, null, label); + } + + @Nullable + private static String resolveLabel(@Nullable String label) { + if (label == null) { + return null; + } + + return LABEL_PREFIX + label; + } +} diff --git a/src/main/java/io/split/android/client/fallback/FallbackTreatmentsConfiguration.java b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsConfiguration.java new file mode 100644 index 000000000..992ebae2e --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbackTreatmentsConfiguration.java @@ -0,0 +1,176 @@ +package io.split.android.client.fallback; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import io.split.android.client.utils.logger.Logger; + +public final class FallbackTreatmentsConfiguration { + + @Nullable + private final FallbackTreatment mGlobal; + private final Map mByFlag; + + private FallbackTreatmentsConfiguration(@Nullable FallbackTreatment global, + @Nullable Map byFlag) { + mGlobal = global; + if (byFlag == null || byFlag.isEmpty()) { + mByFlag = Collections.emptyMap(); + } else { + mByFlag = Collections.unmodifiableMap(new HashMap<>(byFlag)); + } + } + + @Nullable + public FallbackTreatment getGlobal() { + return mGlobal; + } + + public Map getByFlag() { + return mByFlag; + } + + /** + * Creates a new {@link Builder} for {@link FallbackTreatmentsConfiguration}. + * Use this to provide an optional global fallback and flag-specific fallbacks. + */ + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + @Nullable + private FallbackTreatment mGlobal; + @Nullable + private Map mByFlag; + private FallbacksSanitizer mSanitizer; + + private Builder() { + mGlobal = null; + mByFlag = null; + mSanitizer = new FallbacksSanitizerImpl(); + } + + /** + * Sets an optional global fallback treatment to be used when no flag-specific + * fallback exists for a given flag. This value is returned only in place of + * the "control" treatment. + * + * @param global optional global {@link FallbackTreatment} + * @return this builder instance + */ + public Builder global(@Nullable FallbackTreatment global) { + if (mGlobal != null && global != null) { + Logger.w("Fallback treatments - You had previously set a global fallback. The new value will replace it"); + } + mGlobal = global; + return this; + } + + /** + * Sets an optional global fallback treatment to be used when no flag-specific + * fallback exists for a given flag. This value is returned only in place of + * the "control" treatment. + * + * @param treatment the treatment string to use as global + * @return this builder instance + */ + public Builder global(String treatment) { + if (mGlobal != null) { + Logger.w("Fallback treatments - You had previously set a global fallback. The new value will replace it"); + } + mGlobal = new FallbackTreatment(treatment); + return this; + } + + /** + * Sets optional flag-specific fallback treatments, where keys are flag names. + * These take precedence over the global fallback. + * + * @param byFlag map of flag name to {@link FallbackTreatment}; may be null or empty + * @return this builder instance + */ + public Builder byFlag(@Nullable Map byFlag) { + if (byFlag == null || byFlag.isEmpty()) { + return this; + } + if (mByFlag == null) { + mByFlag = new HashMap<>(); + } + for (Map.Entry e : byFlag.entrySet()) { + String key = e.getKey(); + if (mByFlag.containsKey(key)) { + Logger.w(getDuplicateFlagMessage(key)); + } + mByFlag.put(key, e.getValue()); + } + return this; + } + + /** + * Sets optional flag-specific fallback treatments, where keys are flag names. + * These take precedence over the global fallback. + * + * @param byFlag map of flag name to treatment string; may be null or empty + * @return this builder instance + */ + public Builder byFlagStrings(@Nullable Map byFlag) { + if (byFlag == null || byFlag.isEmpty()) { + return this; + } + if (mByFlag == null) { + mByFlag = new HashMap<>(); + } + for (Map.Entry e : byFlag.entrySet()) { + String key = e.getKey(); + if (mByFlag.containsKey(key)) { + Logger.w(getDuplicateFlagMessage(key)); + } + mByFlag.put(key, new FallbackTreatment(e.getValue())); + } + return this; + } + + /** + * Builds a {@link FallbackTreatmentsConfiguration} for the configured values. + * + * @return a new immutable {@link FallbackTreatmentsConfiguration} + */ + public FallbackTreatmentsConfiguration build() { + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(mGlobal); + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(mByFlag); + return new FallbackTreatmentsConfiguration(sanitizedGlobal, sanitizedByFlag); + } + + @VisibleForTesting + Builder sanitizer(FallbacksSanitizer sanitizer) { + mSanitizer = sanitizer; + return this; + } + + @NonNull + private static String getDuplicateFlagMessage(String key) { + return "Fallback treatments - Duplicate fallback for flag '" + key + "'. Overriding existing value."; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FallbackTreatmentsConfiguration that = (FallbackTreatmentsConfiguration) o; + return Objects.equals(mGlobal, that.mGlobal) && + Objects.equals(mByFlag, that.mByFlag); + } + + @Override + public int hashCode() { + return Objects.hash(mGlobal, mByFlag); + } +} diff --git a/src/main/java/io/split/android/client/fallback/FallbacksSanitizer.java b/src/main/java/io/split/android/client/fallback/FallbacksSanitizer.java new file mode 100644 index 000000000..29aaf0307 --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbacksSanitizer.java @@ -0,0 +1,12 @@ +package io.split.android.client.fallback; + +import androidx.annotation.Nullable; +import java.util.Map; + +interface FallbacksSanitizer { + + @Nullable + FallbackTreatment sanitizeGlobal(@Nullable FallbackTreatment global); + + Map sanitizeByFlag(@Nullable Map byFlag); +} diff --git a/src/main/java/io/split/android/client/fallback/FallbacksSanitizerImpl.java b/src/main/java/io/split/android/client/fallback/FallbacksSanitizerImpl.java new file mode 100644 index 000000000..70e34194d --- /dev/null +++ b/src/main/java/io/split/android/client/fallback/FallbacksSanitizerImpl.java @@ -0,0 +1,81 @@ +package io.split.android.client.fallback; + +import androidx.annotation.Nullable; + +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Pattern; + +import io.split.android.client.utils.logger.Logger; + +/** + * Validates and sanitizes fallback configurations by applying validation rules. + * Invalid entries are dropped and warnings are logged. + */ +class FallbacksSanitizerImpl implements FallbacksSanitizer { + + private static final int MAX_FLAG_NAME_LENGTH = 100; + private static final int MAX_TREATMENT_LENGTH = 100; + private static final String TREATMENT_REGEXP = "^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$"; + private static final Pattern TREATMENT_PATTERN = Pattern.compile(TREATMENT_REGEXP); + + + + @Override + @Nullable + public FallbackTreatment sanitizeGlobal(@Nullable FallbackTreatment global) { + if (global == null) { + return null; + } + + if (!isValidTreatment(global)) { + Logger.e("Fallback treatments - Discarded global fallback: Invalid treatment (max " + MAX_TREATMENT_LENGTH + " chars and comply with " + TREATMENT_REGEXP + ")"); + return null; + } + + return global; + } + + @Override + public Map sanitizeByFlag(@Nullable Map byFlag) { + if (byFlag == null || byFlag.isEmpty()) { + return new HashMap<>(); + } + + Map sanitized = new HashMap<>(); + + for (Map.Entry entry : byFlag.entrySet()) { + String flagName = entry.getKey(); + FallbackTreatment treatment = entry.getValue(); + + if (!isValidFlagName(flagName)) { + Logger.e("Fallback treatments - Discarded flag '" + flagName + "': Invalid flag name (max " + MAX_FLAG_NAME_LENGTH + " chars, no spaces)"); + continue; + } + + if (!isValidTreatment(treatment)) { + Logger.e("Fallback treatments - Discarded treatment for flag '" + flagName + "': Invalid treatment (max " + MAX_TREATMENT_LENGTH + " chars and comply with " + TREATMENT_REGEXP + ")"); + continue; + } + + sanitized.put(flagName, treatment); + } + + return sanitized; + } + + private static boolean isValidFlagName(String flagName) { + if (flagName == null) { + return false; + } + return flagName.length() <= MAX_FLAG_NAME_LENGTH && !flagName.contains(" "); + } + + private static boolean isValidTreatment(FallbackTreatment treatment) { + if (treatment == null || treatment.getTreatment() == null) { + return false; + } + String value = treatment.getTreatment(); + return value.length() <= MAX_TREATMENT_LENGTH && TREATMENT_PATTERN.matcher(value).matches(); + } +} diff --git a/src/main/java/io/split/android/client/localhost/LocalhostSplitClient.java b/src/main/java/io/split/android/client/localhost/LocalhostSplitClient.java index 83b1bd4fe..dacead934 100644 --- a/src/main/java/io/split/android/client/localhost/LocalhostSplitClient.java +++ b/src/main/java/io/split/android/client/localhost/LocalhostSplitClient.java @@ -32,6 +32,9 @@ import io.split.android.client.telemetry.storage.TelemetryStorageProducer; import io.split.android.client.utils.logger.Logger; import io.split.android.client.validators.FlagSetsValidatorImpl; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; import io.split.android.client.validators.KeyValidatorImpl; import io.split.android.client.validators.SplitValidatorImpl; import io.split.android.client.validators.TreatmentManager; @@ -72,12 +75,13 @@ public LocalhostSplitClient(@NonNull LocalhostSplitFactory container, mKey = checkNotNull(key); mEventsManager = checkNotNull(eventsManager); mSplitsStorage = splitsStorage; + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build()); mTreatmentManager = new TreatmentManagerImpl(mKey.matchingKey(), mKey.bucketingKey(), - new EvaluatorImpl(splitsStorage, splitParser), new KeyValidatorImpl(), + new EvaluatorImpl(splitsStorage, splitParser, calculator), new KeyValidatorImpl(), new SplitValidatorImpl(), getImpressionsListener(splitClientConfig), splitClientConfig.labelsEnabled(), eventsManager, attributesManager, attributesMerger, telemetryStorageProducer, flagSetsFilter, splitsStorage, new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), - new PropertyValidatorImpl()); + new PropertyValidatorImpl(), calculator); } @Override diff --git a/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java b/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java new file mode 100644 index 000000000..b68a6659b --- /dev/null +++ b/src/main/java/io/split/android/client/network/BasicCredentialsProvider.java @@ -0,0 +1,13 @@ +package io.split.android.client.network; + +/** + * Interface for providing basic credentials. + *

+ * The username and password will be used to create a Proxy-Authorization header using Basic authentication + */ +public interface BasicCredentialsProvider extends ProxyCredentialsProvider { + + String getUsername(); + + String getPassword(); +} diff --git a/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java b/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java new file mode 100644 index 000000000..d372ce5e7 --- /dev/null +++ b/src/main/java/io/split/android/client/network/BearerCredentialsProvider.java @@ -0,0 +1,11 @@ +package io.split.android.client.network; + +/** + * Interface for providing proxy credentials. + *

+ * The token will be sent in the header "Proxy-Authorization: Bearer " + */ +public interface BearerCredentialsProvider extends ProxyCredentialsProvider { + + String getToken(); +} diff --git a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java index f733e28f9..7871f6949 100644 --- a/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java +++ b/src/main/java/io/split/android/client/network/CertificateCheckerImpl.java @@ -17,7 +17,6 @@ import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.X509TrustManager; -import io.split.android.client.utils.Base64Util; import io.split.android.client.utils.logger.Logger; class CertificateCheckerImpl implements CertificateChecker { @@ -100,17 +99,4 @@ private String certificateChainInfo(List cleanCertificates) { return builder.toString(); } - - private static class DefaultBase64Encoder implements Base64Encoder { - - @Override - public String encode(String value) { - return Base64Util.encode(value); - } - - @Override - public String encode(byte[] bytes) { - return Base64Util.encode(bytes); - } - } } diff --git a/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java b/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java index d7549bb6f..23ec94d5c 100644 --- a/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java +++ b/src/main/java/io/split/android/client/network/CertificatePinningConfiguration.java @@ -209,12 +209,5 @@ private Set getInitializedPins(String host) { } return pins; } - - private static class DefaultBase64Decoder implements Base64Decoder { - @Override - public byte[] decode(String base64) { - return Base64Util.bytesDecode(base64); - } - } } } diff --git a/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java b/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java new file mode 100644 index 000000000..c84903fb6 --- /dev/null +++ b/src/main/java/io/split/android/client/network/DefaultBase64Decoder.java @@ -0,0 +1,11 @@ +package io.split.android.client.network; + +import io.split.android.client.utils.Base64Util; + +class DefaultBase64Decoder implements Base64Decoder { + + @Override + public byte[] decode(String base64) { + return Base64Util.bytesDecode(base64); + } +} diff --git a/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java new file mode 100644 index 000000000..e1333ca80 --- /dev/null +++ b/src/main/java/io/split/android/client/network/DefaultBase64Encoder.java @@ -0,0 +1,16 @@ +package io.split.android.client.network; + +import io.split.android.client.utils.Base64Util; + +class DefaultBase64Encoder implements Base64Encoder { + + @Override + public String encode(String value) { + return Base64Util.encode(value); + } + + @Override + public String encode(byte[] bytes) { + return Base64Util.encode(bytes); + } +} diff --git a/src/main/java/io/split/android/client/network/HttpClientImpl.java b/src/main/java/io/split/android/client/network/HttpClientImpl.java index ca0a1d46d..3b2a4be33 100644 --- a/src/main/java/io/split/android/client/network/HttpClientImpl.java +++ b/src/main/java/io/split/android/client/network/HttpClientImpl.java @@ -6,6 +6,10 @@ import androidx.annotation.Nullable; import androidx.annotation.VisibleForTesting; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.URI; @@ -28,7 +32,11 @@ public class HttpClientImpl implements HttpClient { @Nullable private final Proxy mProxy; @Nullable + private final HttpProxy mHttpProxy; + @Nullable private final SplitUrlConnectionAuthenticator mProxyAuthenticator; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; private final long mReadTimeout; private final long mConnectionTimeout; @Nullable @@ -39,17 +47,22 @@ public class HttpClientImpl implements HttpClient { private final UrlSanitizer mUrlSanitizer; @Nullable private final CertificateChecker mCertificateChecker; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpClientImpl(@Nullable HttpProxy proxy, @Nullable SplitAuthenticator proxyAuthenticator, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, long readTimeout, long connectionTimeout, @Nullable DevelopmentSslConfig developmentSslConfig, @Nullable SSLSocketFactory sslSocketFactory, @NonNull UrlSanitizer urlSanitizer, @Nullable CertificateChecker certificateChecker) { + mHttpProxy = proxy; mProxy = initializeProxy(proxy); mProxyAuthenticator = initializeProxyAuthenticator(proxy, proxyAuthenticator); + mProxyCredentialsProvider = proxyCredentialsProvider; mReadTimeout = readTimeout; mConnectionTimeout = connectionTimeout; mDevelopmentSslConfig = developmentSslConfig; @@ -58,6 +71,8 @@ public class HttpClientImpl implements HttpClient { mSslSocketFactory = sslSocketFactory; mUrlSanitizer = urlSanitizer; mCertificateChecker = certificateChecker; + mConnectionHandler = mHttpProxy != null && mSslSocketFactory != null ? + new ProxyCacertConnectionHandler() : null; } @Override @@ -73,7 +88,9 @@ public HttpRequest request(URI uri, HttpMethod requestMethod, String body, Map + * This class is responsible for executing HTTP requests through tunnel sockets that have been + * created by the SSL tunnel establisher using the custom tunneling approach. + */ +class HttpOverTunnelExecutor { + + public static final int HTTP_PORT = 80; + public static final int HTTPS_PORT = 443; + public static final int UNSET_PORT = -1; + private static final String CRLF = "\r\n"; + + private final RawHttpResponseParser mResponseParser; + + public HttpOverTunnelExecutor() { + mResponseParser = new RawHttpResponseParser(); + } + + @NonNull + HttpResponse executeRequest( + @NonNull Socket tunnelSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body, + @Nullable Certificate[] serverCertificates) throws IOException { + + Logger.v("Executing request through tunnel to: " + targetUrl); + + try { + sendHttpRequest(tunnelSocket, targetUrl, method, headers, body); + + return readHttpResponse(tunnelSocket, serverCertificates); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; + } catch (Exception e) { + // Wrap other exceptions in IOException + Logger.e("Failed to execute request through tunnel: " + e.getMessage()); + throw new IOException("Failed to execute HTTP request through tunnel to " + targetUrl, e); + } + } + + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull Socket finalSocket, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable Certificate[] serverCertificates) throws IOException { + Logger.v("Executing stream request through tunnel to: " + targetUrl); + + try { + sendHttpRequest(finalSocket, targetUrl, method, headers, null); + return readHttpStreamResponse(finalSocket, originSocket); + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped + // This ensures consistent behavior with non-proxy flows + throw e; + } catch (Exception e) { + // Wrap other exceptions in IOException + Logger.e("Failed to execute stream request through tunnel: " + e.getMessage()); + throw new IOException("Failed to execute HTTP stream request through tunnel to " + targetUrl, e); + } + } + + /** + * Sends the HTTP request through the tunnel socket. + */ + private void sendHttpRequest( + @NonNull Socket tunnelSocket, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body) throws IOException { + + PrintWriter writer = new PrintWriter(tunnelSocket.getOutputStream(), true); + + // 1. Send request line + String path = targetUrl.getPath(); + if (path.isEmpty()) { + path = "/"; + } + if (targetUrl.getQuery() != null) { + path += "?" + targetUrl.getQuery(); + } + + String requestLine = method.name() + " " + path + " HTTP/1.1"; + writer.write(requestLine + CRLF); + + // 2. Send Host header (required for HTTP/1.1) + String host = targetUrl.getHost(); + int port = getTargetPort(targetUrl); + + // Add port to Host header if it's not the default port for the protocol + if (!isIsDefaultPort(targetUrl, port)) { + host += ":" + port; + } + + writer.write("Host: " + host + CRLF); + + // 3. Send custom headers (excluding Host and Content-Length) + for (Map.Entry header : headers.entrySet()) { + if (header.getKey() != null && header.getValue() != null && + !"content-length".equalsIgnoreCase(header.getKey()) && + !"host".equalsIgnoreCase(header.getKey())) { + String headerLine = header.getKey() + ": " + header.getValue(); + writer.write(headerLine + CRLF); + } + } + + // 4. Send Content-Length header if body is present + if (body != null) { + String contentLengthHeader = "Content-Length: " + body.getBytes("UTF-8").length; + writer.write(contentLengthHeader + CRLF); + } + + // 5. Send Connection: close to ensure response completion + writer.write("Connection: close" + CRLF); + + // 6. End headers with empty line + writer.write(CRLF); + + // 7. Send body if present + if (body != null) { + Logger.v("Sending request body: '" + body + "'"); + writer.write(body); + } + + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Failed to send HTTP request through tunnel"); + } + } + + private static boolean isIsDefaultPort(@NonNull URL targetUrl, int port) { + return ("http".equalsIgnoreCase(targetUrl.getProtocol()) && port == HTTP_PORT) || + ("https".equalsIgnoreCase(targetUrl.getProtocol()) && port == HTTPS_PORT); + } + + /** + * Reads HTTP response from the tunnel socket. + * + * @param tunnelSocket The socket to read from + * @param serverCertificates The server certificates to include in the response + * @return HttpResponse with server certificates + */ + private HttpResponse readHttpResponse(@NonNull Socket tunnelSocket, @Nullable Certificate[] serverCertificates) throws IOException { + return mResponseParser.parseHttpResponse(tunnelSocket.getInputStream(), serverCertificates); + } + + private HttpStreamResponse readHttpStreamResponse(@NonNull Socket tunnelSocket, @Nullable Socket originSocket) throws IOException { + return mResponseParser.parseHttpStreamResponse(tunnelSocket.getInputStream(), tunnelSocket, originSocket); + } + + /** + * Gets the target port from URL, defaulting based on protocol. + */ + private int getTargetPort(@NonNull URL targetUrl) { + int port = targetUrl.getPort(); + if (port == UNSET_PORT) { + if ("https".equalsIgnoreCase(targetUrl.getProtocol())) { + return HTTPS_PORT; + } else if ("http".equalsIgnoreCase(targetUrl.getProtocol())) { + return HTTP_PORT; + } + } + return port; + } +} diff --git a/src/main/java/io/split/android/client/network/HttpProxy.java b/src/main/java/io/split/android/client/network/HttpProxy.java index 8f9734fca..a6dc011fa 100644 --- a/src/main/java/io/split/android/client/network/HttpProxy.java +++ b/src/main/java/io/split/android/client/network/HttpProxy.java @@ -1,47 +1,118 @@ package io.split.android.client.network; -import static io.split.android.client.utils.Utils.checkNotNull; - import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import java.io.InputStream; + public class HttpProxy { - final private String host; - final private int port; - final private String username; - final private String password; + private final @NonNull String mHost; + private final int mPort; + private final @Nullable String mUsername; + private final @Nullable String mPassword; + private final @Nullable InputStream mClientCertStream; + private final @Nullable InputStream mClientKeyStream; + private final @Nullable InputStream mCaCertStream; + private final @Nullable ProxyCredentialsProvider mCredentialsProvider; + private final boolean mIsLegacy; - public HttpProxy(@NonNull String host, int port) { - this(host, port, null, null); + private HttpProxy(Builder builder, boolean isLegacy) { + mHost = builder.mHost; + mPort = builder.mPort; + mUsername = builder.mUsername; + mPassword = builder.mPassword; + mClientCertStream = builder.mClientCertStream; + mClientKeyStream = builder.mClientKeyStream; + mCaCertStream = builder.mCaCertStream; + mCredentialsProvider = builder.mCredentialsProvider; + mIsLegacy = isLegacy; } - public HttpProxy(@NonNull String host, int port, @Nullable String username, @Nullable String password) { - checkNotNull(host); + public @Nullable String getHost() { + return mHost; + } - this.host = host; - this.port = port; - this.username = username; - this.password = password; + public int getPort() { + return mPort; } - public String getHost() { - return host; + public @Nullable String getUsername() { + return mUsername; } - public int getPort() { - return port; + public @Nullable String getPassword() { + return mPassword; + } + + public @Nullable InputStream getClientCertStream() { + return mClientCertStream; + } + + public @Nullable InputStream getClientKeyStream() { + return mClientKeyStream; + } + + public @Nullable InputStream getCaCertStream() { + return mCaCertStream; + } + + public @Nullable ProxyCredentialsProvider getCredentialsProvider() { + return mCredentialsProvider; } - public String getUsername() { - return username; + public static Builder newBuilder(@Nullable String host, int port) { + return new Builder(host, port); } - public String getPassword() { - return password; + public boolean isLegacy() { + return mIsLegacy; } - public boolean usesCredentials() { - return username == null; + public static class Builder { + private final @Nullable String mHost; + private final int mPort; + private @Nullable String mUsername; + private @Nullable String mPassword; + private @Nullable InputStream mClientCertStream; + private @Nullable InputStream mClientKeyStream; + private @Nullable InputStream mCaCertStream; + @Nullable + private ProxyCredentialsProvider mCredentialsProvider; + + private Builder(@Nullable String host, int port) { + mHost = host; + mPort = port; + } + + public Builder basicAuth(@NonNull String username, @NonNull String password) { + mUsername = username; + mPassword = password; + return this; + } + + public Builder proxyCacert(@NonNull InputStream caCertStream) { + mCaCertStream = caCertStream; + return this; + } + + public Builder mtls(@NonNull InputStream clientCertStream, @NonNull InputStream keyStream) { + mClientCertStream = clientCertStream; + mClientKeyStream = keyStream; + return this; + } + + public Builder credentialsProvider(@NonNull ProxyCredentialsProvider credentialsProvider) { + mCredentialsProvider = credentialsProvider; + return this; + } + + public HttpProxy build() { + return new HttpProxy(this, false); + } + + public HttpProxy buildLegacy() { + return new HttpProxy(this, true); + } } } diff --git a/src/main/java/io/split/android/client/network/HttpRequestHelper.java b/src/main/java/io/split/android/client/network/HttpRequestHelper.java index 8deca2548..e18bda179 100644 --- a/src/main/java/io/split/android/client/network/HttpRequestHelper.java +++ b/src/main/java/io/split/android/client/network/HttpRequestHelper.java @@ -19,12 +19,55 @@ class HttpRequestHelper { - static HttpURLConnection openConnection(@Nullable Proxy proxy, - @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, - @NonNull URL url, - @NonNull HttpMethod method, - @NonNull Map headers, - boolean useProxyAuthentication) throws IOException { + private static final ProxyCacertConnectionHandler mConnectionHandler = new ProxyCacertConnectionHandler(); + + static HttpURLConnection createConnection(@NonNull URL url, + @Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, + @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @NonNull HttpMethod method, + @NonNull Map headers, + boolean useProxyAuthentication, + @Nullable SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + @Nullable String body) throws IOException { + + // Use the new tunnel path only when there is no legacy authenticator present. + // If a legacy authenticator proxy, we prefer the legacy path to preserve 407 retry behavior. + if (httpProxy != null && sslSocketFactory != null && !httpProxy.isLegacy()) { + try { + HttpResponse response = mConnectionHandler.executeRequest( + httpProxy, + url, + method, + headers, + body, + sslSocketFactory, + proxyCredentialsProvider + ); + + return new HttpResponseConnectionAdapter(url, response, response.getServerCertificates()); + } catch (UnsupportedOperationException e) { + // Fall through to standard handling + } + } + + return openConnection(proxy, httpProxy, proxyAuthenticator, url, method, headers, useProxyAuthentication); + } + + private static HttpURLConnection openConnection(@Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, + @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @NonNull URL url, + @NonNull HttpMethod method, + @NonNull Map headers, + boolean useProxyAuthentication) throws IOException { + + if (httpProxy != null && !httpProxy.isLegacy() && (httpProxy.getCaCertStream() != null || httpProxy.getClientCertStream() != null)) { + throw new IOException("SSL proxy scenarios require custom handling - use executeRequest method instead"); + } + + // Standard HttpURLConnection proxy handling HttpURLConnection connection; if (proxy != null) { connection = (HttpURLConnection) url.openConnection(proxy); @@ -84,7 +127,7 @@ static void checkPins(HttpURLConnection connection, @Nullable CertificateChecker private static void addHeaders(HttpURLConnection request, Map headers) { for (Map.Entry entry : headers.entrySet()) { - if (entry == null) { + if (entry == null || entry.getKey() == null) { continue; } diff --git a/src/main/java/io/split/android/client/network/HttpRequestImpl.java b/src/main/java/io/split/android/client/network/HttpRequestImpl.java index 6db18413d..1f2a0c402 100644 --- a/src/main/java/io/split/android/client/network/HttpRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpRequestImpl.java @@ -4,7 +4,8 @@ import static io.split.android.client.network.HttpRequestHelper.applySslConfig; import static io.split.android.client.network.HttpRequestHelper.applyTimeouts; -import static io.split.android.client.network.HttpRequestHelper.openConnection; +import static io.split.android.client.network.HttpRequestHelper.createConnection; + import androidx.annotation.NonNull; import androidx.annotation.Nullable; @@ -14,6 +15,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; +import java.net.HttpRetryException; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.ProtocolException; @@ -43,7 +45,11 @@ public class HttpRequestImpl implements HttpRequest { @Nullable private final Proxy mProxy; @Nullable + private final HttpProxy mHttpProxy; + @Nullable private final SplitUrlConnectionAuthenticator mProxyAuthenticator; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; private final long mReadTimeout; private final long mConnectionTimeout; @Nullable @@ -58,7 +64,9 @@ public class HttpRequestImpl implements HttpRequest { @Nullable String body, @NonNull Map headers, @Nullable Proxy proxy, + @Nullable HttpProxy httpProxy, @Nullable SplitUrlConnectionAuthenticator proxyAuthenticator, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, long readTimeout, long connectionTimeout, @Nullable DevelopmentSslConfig developmentSslConfig, @@ -71,7 +79,9 @@ public class HttpRequestImpl implements HttpRequest { mUrlSanitizer = checkNotNull(urlSanitizer); mHeaders = new HashMap<>(checkNotNull(headers)); mProxy = proxy; + mHttpProxy = httpProxy; mProxyAuthenticator = proxyAuthenticator; + mProxyCredentialsProvider = proxyCredentialsProvider; mReadTimeout = readTimeout; mConnectionTimeout = connectionTimeout; mDevelopmentSslConfig = developmentSslConfig; @@ -178,7 +188,15 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti throw new IOException("Error parsing URL"); } - HttpURLConnection connection = openConnection(mProxy, mProxyAuthenticator, url, mHttpMethod, mHeaders, authenticate); + HttpURLConnection connection; + try { + connection = getConnection(authenticate, url); + } catch (HttpRetryException e) { + if (mProxyAuthenticator == null) { + throw e; + } + connection = getConnection(authenticate, url); + } applyTimeouts(mReadTimeout, mConnectionTimeout, connection); applySslConfig(mSslSocketFactory, mDevelopmentSslConfig, connection); @@ -197,6 +215,21 @@ private HttpURLConnection setUpConnection(boolean authenticate) throws IOExcepti return connection; } + @NonNull + private HttpURLConnection getConnection(boolean authenticate, URL url) throws IOException { + return createConnection( + url, + mProxy, + mHttpProxy, + mProxyAuthenticator, + mHttpMethod, + mHeaders, + authenticate, + mSslSocketFactory, + mProxyCredentialsProvider, + mBody); + } + private static HttpResponse buildResponse(HttpURLConnection connection) throws IOException { int responseCode = connection.getResponseCode(); diff --git a/src/main/java/io/split/android/client/network/HttpResponse.java b/src/main/java/io/split/android/client/network/HttpResponse.java index 417dfb289..42a3d0a93 100644 --- a/src/main/java/io/split/android/client/network/HttpResponse.java +++ b/src/main/java/io/split/android/client/network/HttpResponse.java @@ -1,5 +1,9 @@ package io.split.android.client.network; +import java.security.cert.Certificate; + public interface HttpResponse extends BaseHttpResponse { String getData(); + + Certificate[] getServerCertificates(); } diff --git a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java new file mode 100644 index 000000000..fb269cbdf --- /dev/null +++ b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java @@ -0,0 +1,455 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.Permission; +import java.security.cert.Certificate; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLSocketFactory; + +/** + * Adapter that wraps an HttpResponse as an HttpURLConnection. + *

+ * This is only used to adapt the response from request through the TLS tunnel. + */ +class HttpResponseConnectionAdapter extends HttpsURLConnection { + + private final HttpResponse mResponse; + private final URL mUrl; + private final Certificate[] mServerCertificates; + private final OutputStream mOutputStream; + private InputStream mInputStream; + private InputStream mErrorStream; + private boolean mDoOutput = false; + + /** + * Creates an adapter that wraps an HttpResponse as an HttpURLConnection. + * + * @param url The URL of the request + * @param response The HTTP response from the SSL proxy + * @param serverCertificates The server certificates from the SSL connection + */ + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates) { + this(url, response, serverCertificates, new ByteArrayOutputStream()); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream) { + this(url, response, serverCertificates, outputStream, null, null); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream, + @Nullable InputStream inputStream, + @Nullable InputStream errorStream) { + super(url); + mUrl = url; + mResponse = response; + mServerCertificates = serverCertificates; + mOutputStream = outputStream; + mInputStream = inputStream; + mErrorStream = errorStream; + } + + @Override + public int getResponseCode() throws IOException { + return mResponse.getHttpStatus(); + } + + @Override + public String getResponseMessage() throws IOException { + // Map common HTTP status codes to messages + switch (mResponse.getHttpStatus()) { + case 200: + return "OK"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 500: + return "Internal Server Error"; + default: + return "HTTP " + mResponse.getHttpStatus(); + } + } + + @Override + public InputStream getInputStream() throws IOException { + if (mResponse.getHttpStatus() >= 400) { + throw new IOException("HTTP " + mResponse.getHttpStatus()); + } + if (mInputStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mInputStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + } + return mInputStream; + } + + @Override + public InputStream getErrorStream() { + if (mResponse.getHttpStatus() >= 400) { + if (mErrorStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mErrorStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + } + return mErrorStream; + } + return null; + } + + @Override + public void connect() throws IOException { + // Already connected + } + + @Override + public boolean usingProxy() { + return true; + } + + @Override + public void disconnect() { + // Close output stream if it exists + try { + if (mOutputStream != null) { + mOutputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close input stream if it exists + try { + if (mInputStream != null) { + mInputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close error stream if it exists + try { + if (mErrorStream != null) { + mErrorStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + } + + // Required abstract method implementations for HTTPS connection + @Override + public String getCipherSuite() { + return null; + } + + @Override + public Certificate[] getLocalCertificates() { + return null; + } + + @Override + public Certificate[] getServerCertificates() { + // Return the server certificates from the SSL connection + return mServerCertificates; + } + + // Minimal implementations for other required methods + @Override + public void setRequestMethod(String method) { + } + + @Override + public String getRequestMethod() { + return "GET"; + } + + @Override + public void setInstanceFollowRedirects(boolean followRedirects) { + } + + @Override + public boolean getInstanceFollowRedirects() { + return true; + } + + @Override + public void setDoOutput(boolean doOutput) { + mDoOutput = doOutput; + } + + @Override + public boolean getDoOutput() { + return mDoOutput; + } + + @Override + public void setDoInput(boolean doInput) { + } + + @Override + public boolean getDoInput() { + return true; + } + + @Override + public void setUseCaches(boolean useCaches) { + } + + @Override + public boolean getUseCaches() { + return false; + } + + @Override + public void setIfModifiedSince(long ifModifiedSince) { + } + + @Override + public long getIfModifiedSince() { + return 0; + } + + @Override + public void setDefaultUseCaches(boolean defaultUseCaches) { + } + + @Override + public boolean getDefaultUseCaches() { + return false; + } + + @Override + public void setRequestProperty(String key, String value) { + } + + @Override + public void addRequestProperty(String key, String value) { + } + + @Override + public String getRequestProperty(String key) { + return null; + } + + @Override + public Map> getRequestProperties() { + return null; + } + + @Override + public String getHeaderField(String name) { + if (name == null) { + return null; + } + Map> headers = getHeaderFields(); + List values = headers.get(name.toLowerCase()); + + return (values != null && !values.isEmpty()) ? values.get(0) : null; + } + + @Override + public Map> getHeaderFields() { + Map> headers = new HashMap<>(); + + // Add synthetic headers based on response data + String contentType = getContentType(); + if (contentType != null) { + headers.put("content-type", Collections.singletonList(contentType)); + } + + long contentLength = getContentLengthLong(); + if (contentLength >= 0) { + headers.put("content-length", Collections.singletonList(String.valueOf(contentLength))); + } + + String contentEncoding = getContentEncoding(); + if (contentEncoding != null) { + headers.put("content-encoding", Collections.singletonList(contentEncoding)); + } + + try { + headers.put("status", Collections.singletonList(getResponseCode() + " " + getResponseMessage())); + } catch (IOException e) { + // Ignore if we can't get response code + } + + return headers; + } + + @Override + public int getHeaderFieldInt(String name, int defaultValue) { + String value = getHeaderField(name); + if (value != null) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + // Fall through to default + } + } + return defaultValue; + } + + @Override + public long getHeaderFieldDate(String name, long defaultValue) { + // We don't have actual date headers + if ("date".equalsIgnoreCase(name)) { + return System.currentTimeMillis(); + } + return defaultValue; + } + + @Override + public String getHeaderFieldKey(int n) { + Map> headers = getHeaderFields(); + if (n >= 0 && n < headers.size()) { + return (String) headers.keySet().toArray()[n]; + } + return null; + } + + @Override + public String getHeaderField(int n) { + String key = getHeaderFieldKey(n); + return key != null ? getHeaderField(key) : null; + } + + @Override + public long getContentLengthLong() { + String data = mResponse.getData(); + if (data == null) { + return 0; + } + return data.getBytes(StandardCharsets.UTF_8).length; + } + + @Override + public String getContentType() { + // Try to detect content type from response data, default to JSON for API responses + String data = mResponse.getData(); + if (data == null || data.trim().isEmpty()) { + return null; + } + String trimmed = data.trim(); + if (trimmed.startsWith("{") || trimmed.startsWith("[")) { + return "application/json; charset=utf-8"; + } + if (trimmed.startsWith("<")) { + return "text/html; charset=utf-8"; + } + return "text/plain; charset=utf-8"; + } + + @Override + public String getContentEncoding() { + return "utf-8"; + } + + @Override + public long getExpiration() { + return 0; + } + + @Override + public long getDate() { + return System.currentTimeMillis(); + } + + @Override + public long getLastModified() { + return 0; + } + + @Override + public URL getURL() { + return mUrl; + } + + @Override + public int getContentLength() { + long length = getContentLengthLong(); + return length > Integer.MAX_VALUE ? -1 : (int) length; + } + + @Override + public Permission getPermission() throws IOException { + return null; + } + + @Override + public OutputStream getOutputStream() throws IOException { + if (!mDoOutput) { + throw new IOException("Output not enabled for this connection. Call setDoOutput(true) first."); + } + return mOutputStream; + } + + @Override + public void setConnectTimeout(int timeout) { + } + + @Override + public int getConnectTimeout() { + return 0; + } + + @Override + public void setReadTimeout(int timeout) { + } + + @Override + public int getReadTimeout() { + return 0; + } + + @Override + public void setHostnameVerifier(HostnameVerifier v) { + } + + @Override + public HostnameVerifier getHostnameVerifier() { + return null; + } + + @Override + public void setSSLSocketFactory(SSLSocketFactory sf) { + } + + @Override + public SSLSocketFactory getSSLSocketFactory() { + return null; + } +} diff --git a/src/main/java/io/split/android/client/network/HttpResponseImpl.java b/src/main/java/io/split/android/client/network/HttpResponseImpl.java index 1e1f05a0d..07c970d46 100644 --- a/src/main/java/io/split/android/client/network/HttpResponseImpl.java +++ b/src/main/java/io/split/android/client/network/HttpResponseImpl.java @@ -1,20 +1,37 @@ package io.split.android.client.network; -public class HttpResponseImpl extends BaseHttpResponseImpl implements HttpResponse { +import java.security.cert.Certificate; + +public class HttpResponseImpl extends BaseHttpResponseImpl implements HttpResponse { private final String mData; + private final Certificate[] mServerCertificates; HttpResponseImpl(int httpStatus) { - this(httpStatus, null); + this(httpStatus, (String) null); + } + + HttpResponseImpl(int httpStatus, Certificate[] serverCertificates) { + this(httpStatus, null, serverCertificates); } public HttpResponseImpl(int httpStatus, String data) { + this(httpStatus, data, null); + } + + public HttpResponseImpl(int httpStatus, String data, Certificate[] serverCertificates) { super(httpStatus); mData = data; + mServerCertificates = serverCertificates; } @Override public String getData() { return mData; } + + @Override + public Certificate[] getServerCertificates() { + return mServerCertificates; + } } diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequest.java b/src/main/java/io/split/android/client/network/HttpStreamRequest.java index 0b11d6844..f0bb28c67 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequest.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequest.java @@ -1,7 +1,9 @@ package io.split.android.client.network; +import java.io.IOException; + public interface HttpStreamRequest { void addHeader(String name, String value); - HttpStreamResponse execute() throws HttpException; + HttpStreamResponse execute() throws HttpException, IOException; void close(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java index 60453013b..d6f48b8d9 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamRequestImpl.java @@ -1,11 +1,11 @@ package io.split.android.client.network; import static io.split.android.client.network.HttpRequestHelper.checkPins; +import static io.split.android.client.network.HttpRequestHelper.createConnection; import static io.split.android.client.utils.Utils.checkNotNull; import static io.split.android.client.network.HttpRequestHelper.applySslConfig; import static io.split.android.client.network.HttpRequestHelper.applyTimeouts; -import static io.split.android.client.network.HttpRequestHelper.openConnection; import androidx.annotation.NonNull; import androidx.annotation.Nullable; @@ -18,6 +18,7 @@ import java.net.MalformedURLException; import java.net.ProtocolException; import java.net.Proxy; +import java.net.SocketException; import java.net.URI; import java.net.URL; import java.util.HashMap; @@ -52,6 +53,12 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { @Nullable private final CertificateChecker mCertificateChecker; private final AtomicBoolean mWasRetried = new AtomicBoolean(false); + @Nullable + private final HttpProxy mHttpProxy; + @Nullable + private final ProxyCredentialsProvider mProxyCredentialsProvider; + @Nullable + private final ProxyCacertConnectionHandler mConnectionHandler; HttpStreamRequestImpl(@NonNull URI uri, @NonNull Map headers, @@ -61,7 +68,10 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { @Nullable DevelopmentSslConfig developmentSslConfig, @Nullable SSLSocketFactory sslSocketFactory, @NonNull UrlSanitizer urlSanitizer, - @Nullable CertificateChecker certificateChecker) { + @Nullable CertificateChecker certificateChecker, + @Nullable HttpProxy httpProxy, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + @Nullable ProxyCacertConnectionHandler proxyCacertConnectionHandler) { mUri = checkNotNull(uri); mHttpMethod = HttpMethod.GET; mProxy = proxy; @@ -72,10 +82,13 @@ public class HttpStreamRequestImpl implements HttpStreamRequest { mDevelopmentSslConfig = developmentSslConfig; mSslSocketFactory = sslSocketFactory; mCertificateChecker = certificateChecker; + mHttpProxy = httpProxy; + mProxyCredentialsProvider = proxyCredentialsProvider; + mConnectionHandler = proxyCacertConnectionHandler; } @Override - public HttpStreamResponse execute() throws HttpException { + public HttpStreamResponse execute() throws HttpException, IOException { return getRequest(); } @@ -107,14 +120,18 @@ private void closeBufferedReader() { } } - private HttpStreamResponse getRequest() throws HttpException { + private HttpStreamResponse getRequest() throws HttpException, IOException { HttpStreamResponse response; try { - mConnection = setUpConnection(false); - response = buildResponse(mConnection); + if (mConnectionHandler != null && mHttpProxy != null && mSslSocketFactory != null) { + response = mConnectionHandler.executeStreamRequest(mHttpProxy, getUrl(), mHttpMethod, mHeaders, mSslSocketFactory, mProxyCredentialsProvider); + } else { + mConnection = setUpConnection(false); + response = buildResponse(mConnection); - if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { - response = handleAuthentication(response); + if (response.getHttpStatus() == HttpURLConnection.HTTP_PROXY_AUTH) { + response = handleAuthentication(response); + } } } catch (MalformedURLException e) { disconnect(); @@ -125,6 +142,11 @@ private HttpStreamResponse getRequest() throws HttpException { } catch (SSLPeerUnverifiedException e) { disconnect(); throw new HttpException("SSL peer not verified: " + e.getLocalizedMessage(), HttpStatus.INTERNAL_NON_RETRYABLE.getCode()); + } catch (SocketException e) { + disconnect(); + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + // This ensures socket closures are treated the same in both direct and proxy flows + throw e; } catch (IOException e) { disconnect(); throw new HttpException("Something happened while retrieving data: " + e.getLocalizedMessage()); @@ -134,12 +156,20 @@ private HttpStreamResponse getRequest() throws HttpException { } private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws IOException { - URL url = mUrlSanitizer.getUrl(mUri); - if (url == null) { - throw new IOException("Error parsing URL"); - } - - HttpURLConnection connection = openConnection(mProxy, mProxyAuthenticator, url, mHttpMethod, mHeaders, useProxyAuthenticator); + URL url = getUrl(); + + HttpURLConnection connection = createConnection( + url, + mProxy, + mHttpProxy, + mProxyAuthenticator, + mHttpMethod, + mHeaders, + useProxyAuthenticator, + mSslSocketFactory, + mProxyCredentialsProvider, + null + ); applyTimeouts(HttpStreamRequestImpl.STREAMING_READ_TIMEOUT_IN_MILLISECONDS, mConnectionTimeout, connection); applySslConfig(mSslSocketFactory, mDevelopmentSslConfig, connection); connection.connect(); @@ -148,6 +178,15 @@ private HttpURLConnection setUpConnection(boolean useProxyAuthenticator) throws return connection; } + @NonNull + private URL getUrl() throws IOException { + URL url = mUrlSanitizer.getUrl(mUri); + if (url == null) { + throw new IOException("Error parsing URL"); + } + return url; + } + private HttpStreamResponse handleAuthentication(HttpStreamResponse response) throws HttpException { if (!mWasRetried.getAndSet(true)) { try { @@ -171,11 +210,11 @@ private HttpStreamResponse buildResponse(HttpURLConnection connection) throws IO } mBufferedReader = new BufferedReader(new InputStreamReader(inputStream)); - return new HttpStreamResponseImpl(responseCode, mBufferedReader); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, mBufferedReader); } } - return new HttpStreamResponseImpl(responseCode); + return HttpStreamResponseImpl.createFromHttpUrlConnection(responseCode, null); } private void disconnect() { diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponse.java b/src/main/java/io/split/android/client/network/HttpStreamResponse.java index e20096041..f733bd3bc 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponse.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponse.java @@ -3,8 +3,9 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.Closeable; -public interface HttpStreamResponse extends BaseHttpResponse { +public interface HttpStreamResponse extends BaseHttpResponse, Closeable { @Nullable BufferedReader getBufferedReader(); } diff --git a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java index 175433c44..bf24d0e74 100644 --- a/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java +++ b/src/main/java/io/split/android/client/network/HttpStreamResponseImpl.java @@ -3,18 +3,39 @@ import androidx.annotation.Nullable; import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +import io.split.android.client.utils.logger.Logger; public class HttpStreamResponseImpl extends BaseHttpResponseImpl implements HttpStreamResponse { private final BufferedReader mData; - HttpStreamResponseImpl(int httpStatus) { - this(httpStatus, null); - } + // Sockets are referenced when using Proxy tunneling, in order to close them + @Nullable + private final Socket mTunnelSocket; + @Nullable + private final Socket mOriginSocket; - public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { + private HttpStreamResponseImpl(int httpStatus, BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { super(httpStatus); mData = data; + mTunnelSocket = tunnelSocket; + mOriginSocket = originSocket; + } + + static HttpStreamResponseImpl createFromTunnelSocket(int httpStatus, + BufferedReader data, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) { + return new HttpStreamResponseImpl(httpStatus, data, tunnelSocket, originSocket); + } + + static HttpStreamResponseImpl createFromHttpUrlConnection(int httpStatus, BufferedReader data) { + return new HttpStreamResponseImpl(httpStatus, data, null, null); } @Override @@ -22,4 +43,37 @@ public HttpStreamResponseImpl(int httpStatus, BufferedReader data) { public BufferedReader getBufferedReader() { return mData; } + + @Override + public void close() throws IOException { + + // Close the BufferedReader first + if (mData != null) { + try { + mData.close(); + } catch (IOException e) { + Logger.w("Failed to close BufferedReader: " + e.getMessage()); + } + } + + // Close origin socket if it exists and is different from tunnel socket + if (mOriginSocket != null && mOriginSocket != mTunnelSocket) { + try { + mOriginSocket.close(); + Logger.v("Origin socket closed"); + } catch (IOException e) { + Logger.w("Failed to close origin socket: " + e.getMessage()); + } + } + + // Close tunnel socket + if (mTunnelSocket != null) { + try { + mTunnelSocket.close(); + Logger.v("Tunnel socket closed"); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } } diff --git a/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java new file mode 100644 index 000000000..65cdcb6e4 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyCacertConnectionHandler.java @@ -0,0 +1,242 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; +import java.net.URL; +import java.security.cert.Certificate; +import java.util.Map; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import io.split.android.client.utils.logger.Logger; + +/** + * Handles PROXY_CACERT SSL proxy connections. + *

+ * This handler establishes SSL tunnels through SSL proxies using custom CA certificates + * for proxy authentication, then executes HTTP requests through the SSL tunnel. + */ +class ProxyCacertConnectionHandler { + + public static final String HTTPS = "https"; + public static final String HTTP = "http"; + public static final int PORT_HTTPS = 443; + public static final int PORT_HTTP = 80; + private final HttpOverTunnelExecutor mTunnelExecutor; + + public ProxyCacertConnectionHandler() { + mTunnelExecutor = new HttpOverTunnelExecutor(); + } + + /** + * Executes an HTTP request through an SSL proxy tunnel. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param method The HTTP method to use + * @param headers The HTTP headers to include + * @param body The request body (if any) + * @param sslSocketFactory The SSL socket factory for proxy and origin connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @return The HTTP response + * @throws IOException if the request fails + */ + @NonNull + HttpResponse executeRequest(@NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @Nullable String body, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + try { + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, false); + + try { + return mTunnelExecutor.executeRequest( + connection.finalSocket, + targetUrl, + method, + headers, + body, + connection.serverCertificates); + } finally { + // Close all sockets for non-streaming requests + closeConnection(connection); + } + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; + } catch (Exception e) { + Logger.e("Failed to execute request through custom tunnel: " + e.getMessage()); + throw new IOException("Failed to execute request through custom tunnel", e); + } + } + + @NonNull + HttpStreamResponse executeStreamRequest(@NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull HttpMethod method, + @NonNull Map headers, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + try { + TunnelConnection connection = establishTunnelConnection( + httpProxy, targetUrl, sslSocketFactory, proxyCredentialsProvider, true); + + // For streaming requests, pass socket references to the response for later cleanup + Socket originSocket = (connection.finalSocket != connection.tunnelSocket) ? connection.finalSocket : null; + return mTunnelExecutor.executeStreamRequest( + connection.finalSocket, + connection.tunnelSocket, + originSocket, + targetUrl, + method, + headers, + connection.serverCertificates); + // For streaming requests, sockets are NOT closed here + // They will be closed when the HttpStreamResponse.close() is called + } catch (SocketException e) { + // Let socket-related IOExceptions pass through unwrapped for consistent error handling + throw e; + } catch (Exception e) { + throw new IOException("Failed to execute request through custom tunnel", e); + } + } + + private static int getTargetPort(@NonNull URL targetUrl) { + int port = targetUrl.getPort(); + if (port == -1) { + if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { + return PORT_HTTPS; + } else if (HTTP.equalsIgnoreCase(targetUrl.getProtocol())) { + return PORT_HTTP; + } + } + return port; + } + + /** + * Represents a connection through an SSL tunnel. + */ + private static class TunnelConnection { + final Socket tunnelSocket; + final Socket finalSocket; + final Certificate[] serverCertificates; + + TunnelConnection(Socket tunnelSocket, Socket finalSocket, Certificate[] serverCertificates) { + this.tunnelSocket = tunnelSocket; + this.finalSocket = finalSocket; + this.serverCertificates = serverCertificates; + } + } + + /** + * Establishes a tunnel connection to the target through the proxy. + * + * @param httpProxy The proxy configuration + * @param targetUrl The target URL to connect to + * @param sslSocketFactory SSL socket factory for connections + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this is a streaming connection + * @return A TunnelConnection object containing the established sockets + * @throws IOException if connection establishment fails + */ + private TunnelConnection establishTunnelConnection( + @NonNull HttpProxy httpProxy, + @NonNull URL targetUrl, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { + + SslProxyTunnelEstablisher tunnelEstablisher = new SslProxyTunnelEstablisher(); + Socket tunnelSocket = null; + Socket finalSocket = null; + Certificate[] serverCertificates = null; + + try { + tunnelSocket = tunnelEstablisher.establishTunnel( + httpProxy.getHost(), + httpProxy.getPort(), + targetUrl.getHost(), + getTargetPort(targetUrl), + sslSocketFactory, + proxyCredentialsProvider, + isStreaming + ); + + finalSocket = tunnelSocket; + + // If the origin is HTTPS, wrap the tunnel socket with a new SSLSocket (system CA) + if (HTTPS.equalsIgnoreCase(targetUrl.getProtocol())) { + try { + // Use the provided SSLSocketFactory, which is configured to trust the origin's CA + finalSocket = sslSocketFactory.createSocket( + tunnelSocket, + targetUrl.getHost(), + getTargetPort(targetUrl), + true // autoClose + ); + if (finalSocket instanceof SSLSocket) { + SSLSocket originSslSocket = (SSLSocket) finalSocket; + originSslSocket.setUseClientMode(true); + originSslSocket.startHandshake(); + + // Capture server certificates after successful handshake + try { + serverCertificates = originSslSocket.getSession().getPeerCertificates(); + } catch (Exception certEx) { + Logger.w("Could not capture origin server certificates: " + certEx.getMessage()); + } + } else { + throw new IOException("Failed to create SSLSocket to origin"); + } + } catch (Exception sslEx) { + Logger.e("Failed to establish SSL connection to origin: " + sslEx.getMessage()); + throw new IOException("Failed to establish SSL connection to origin server", sslEx); + } + } + + return new TunnelConnection(tunnelSocket, finalSocket, serverCertificates); + } catch (Exception e) { + // Clean up resources on error + closeSockets(finalSocket, tunnelSocket); + throw e; + } + } + + private void closeConnection(TunnelConnection connection) { + if (connection == null) { + return; + } + + closeSockets(connection.finalSocket, connection.tunnelSocket); + } + + private void closeSockets(Socket finalSocket, Socket tunnelSocket) { + // If we are tunnelling, finalSocket is the tunnel socket + if (finalSocket != null && finalSocket != tunnelSocket) { + try { + finalSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close origin SSL socket: " + e.getMessage()); + } + } + + if (tunnelSocket != null) { + try { + tunnelSocket.close(); + } catch (IOException e) { + Logger.w("Failed to close tunnel socket: " + e.getMessage()); + } + } + } +} diff --git a/src/main/java/io/split/android/client/network/ProxyConfiguration.java b/src/main/java/io/split/android/client/network/ProxyConfiguration.java new file mode 100644 index 000000000..e25314992 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyConfiguration.java @@ -0,0 +1,137 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; + +import io.split.android.client.utils.logger.Logger; + +/** + * Proxy configuration + */ +public class ProxyConfiguration { + + private final URI mUrl; + private final ProxyCredentialsProvider mCredentialsProvider; + private final InputStream mClientCert; + private final InputStream mClientPk; + private final InputStream mCaCert; + + ProxyConfiguration(@NonNull URI url, + @Nullable ProxyCredentialsProvider credentialsProvider, + @Nullable InputStream clientCert, + @Nullable InputStream clientPk, + @Nullable InputStream caCert) { + mUrl = url; + mCredentialsProvider = credentialsProvider; + mClientCert = clientCert; + mClientPk = clientPk; + mCaCert = caCert; + } + + public URI getUrl() { + return mUrl; + } + + public ProxyCredentialsProvider getCredentialsProvider() { + return mCredentialsProvider; + } + + public InputStream getClientCert() { + return mClientCert; + } + + public InputStream getClientPk() { + return mClientPk; + } + + public InputStream getCaCert() { + return mCaCert; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private URI mUrl; + private ProxyCredentialsProvider mCredentialsProvider; + private InputStream mClientCert; + private InputStream mClientPk; + private InputStream mCaCert; + + private Builder() { + + } + + /** + * Set the proxy URL + * + * @param url MUST NOT be null + * @return this builder + */ + public Builder url(@NonNull String url) { + try { + mUrl = new URI(url); + } catch (NullPointerException | URISyntaxException e) { + Logger.e("Proxy url was not a valid URL."); + } + return this; + } + + /** + * Set the credentials provider. + *

+ * Can be an implementation of {@link BearerCredentialsProvider} or {@link BasicCredentialsProvider} + * + * @param credentialsProvider A non null credentials provider + * @return this builder + */ + public Builder credentialsProvider(@NonNull ProxyCredentialsProvider credentialsProvider) { + mCredentialsProvider = credentialsProvider; + return this; + } + + /** + * Set the client certificate and private key in PKCS#8 format + * + * @param clientCert The client certificate + * @param clientPk The client private key + * @return this builder + */ + public Builder mtls(@NonNull InputStream clientCert, @NonNull InputStream clientPk) { + mClientCert = clientCert; + mClientPk = clientPk; + return this; + } + + /** + * Set the Proxy CA certificate + * + * @param caCert The CA certificate in PEM or DER format + * @return this builder + */ + public Builder caCert(@NonNull InputStream caCert) { + mCaCert = caCert; + return this; + } + + /** + * Build the proxy configuration. + * This method will return null if the proxy URL is not set. + * + * @return The proxy configuration + */ + @Nullable + public ProxyConfiguration build() { + if (mUrl == null) { + Logger.w("Proxy configuration with no URL. This will prevent SplitFactory from working."); + } + + return new ProxyConfiguration(mUrl, mCredentialsProvider, mClientCert, mClientPk, mCaCert); + } + } +} diff --git a/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java b/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java new file mode 100644 index 000000000..6533883a8 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxyCredentialsProvider.java @@ -0,0 +1,10 @@ +package io.split.android.client.network; + +/** + * Interface for providing proxy credentials. + *

+ * Implementations can be {@link BasicCredentialsProvider} or {@link BearerCredentialsProvider} + */ +public interface ProxyCredentialsProvider { + +} diff --git a/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java new file mode 100644 index 000000000..10c4d8862 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProvider.java @@ -0,0 +1,31 @@ +package io.split.android.client.network; + +import androidx.annotation.Nullable; + +import java.io.InputStream; + +import javax.net.ssl.SSLSocketFactory; + +interface ProxySslSocketFactoryProvider { + + /** + * Create an SSLSocketFactory for proxy connections using a CA certificate from an InputStream. + * The InputStream will be closed after use. + * + * @param caCertInputStream InputStream containing CA certificate (PEM or DER). + * @return SSLSocketFactory configured for the requested scenario + */ + SSLSocketFactory create(@Nullable InputStream caCertInputStream) throws Exception; + + /** + * Create an SSLSocketFactory for proxy connections using CA cert and separate client certificate and key files. + * All InputStreams will be closed after use. + * + * @param caCertInputStream InputStream containing one or more CA certificates (PEM or DER). + * @param clientCertInputStream InputStream containing client certificate (PEM or DER). + * @param clientKeyInputStream InputStream containing client private key (PEM format, PKCS#8). + + * @return SSLSocketFactory configured for mTLS proxy authentication + */ + SSLSocketFactory create(@Nullable InputStream caCertInputStream, @Nullable InputStream clientCertInputStream, @Nullable InputStream clientKeyInputStream) throws Exception; +} diff --git a/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java new file mode 100644 index 000000000..49a84c134 --- /dev/null +++ b/src/main/java/io/split/android/client/network/ProxySslSocketFactoryProviderImpl.java @@ -0,0 +1,270 @@ +package io.split.android.client.network; + +import static io.split.android.client.utils.Utils.checkNotNull; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Collection; +import java.util.Enumeration; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +import io.split.android.client.utils.logger.Logger; + +class ProxySslSocketFactoryProviderImpl implements ProxySslSocketFactoryProvider { + + private final Base64Decoder mBase64Decoder; + + ProxySslSocketFactoryProviderImpl() { + this(new DefaultBase64Decoder()); + } + + ProxySslSocketFactoryProviderImpl(@NonNull Base64Decoder base64Decoder) { + mBase64Decoder = checkNotNull(base64Decoder); + } + + @Override + public SSLSocketFactory create(@Nullable InputStream caCertInputStream) throws Exception { + // The TrustManagerFactory is necessary because of the CA cert + return createSslSocketFactory(null, createTrustManagerFactory(caCertInputStream)); + } + + @Override + public SSLSocketFactory create(@Nullable InputStream caCertInputStream, @Nullable InputStream clientCertInputStream, @Nullable InputStream clientKeyInputStream) throws Exception { + // The KeyManagerFactory is necessary because of the client certificate and key files + KeyManagerFactory keyManagerFactory = createKeyManagerFactory(clientCertInputStream, clientKeyInputStream); + + // The TrustManagerFactory is necessary because of the CA cert + TrustManagerFactory trustManagerFactory = createTrustManagerFactory(caCertInputStream); + + return createSslSocketFactory(keyManagerFactory, trustManagerFactory); + } + + @Nullable + private TrustManagerFactory createTrustManagerFactory(@Nullable InputStream caCertInputStream) throws Exception { + if (caCertInputStream == null) { + return null; + } + + try { + // Generate Certificate objects from the InputStream + CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + Collection caCertificates = certificateFactory.generateCertificates(caCertInputStream); + + KeyStore combinedTrustStore = getCombinedStore(caCertificates); + + // Initialize the TrustManagerFactory with the combined trust store + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(combinedTrustStore); + + return trustManagerFactory; + } finally { + caCertInputStream.close(); + } + } + + /** + * Create a KeyStore with both system CAs and user provided CAs + * @param caCertificates User provided CAs + * @return KeyStore + */ + @NonNull + private static KeyStore getCombinedStore(Collection caCertificates) throws NoSuchAlgorithmException, KeyStoreException, CertificateException, IOException { + // Start with the system's default trust store to include standard CA certificates + TrustManagerFactory defaultTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + defaultTrustManagerFactory.init((KeyStore) null); // Initialize with system default keystore + + // Get the default trust store + KeyStore defaultTrustStore = null; + for (TrustManager tm : defaultTrustManagerFactory.getTrustManagers()) { + if (tm instanceof X509TrustManager) { + // Create a new keystore and populate it with system CAs + defaultTrustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + defaultTrustStore.load(null, null); + + X509Certificate[] acceptedIssuers = ((X509TrustManager) tm).getAcceptedIssuers(); + for (int j = 0; j < acceptedIssuers.length; j++) { + defaultTrustStore.setCertificateEntry("systemCA" + j, acceptedIssuers[j]); + } + break; + } + } + + // Create combined trust store with both system CAs and custom proxy CAs + KeyStore combinedTrustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + combinedTrustStore.load(null, null); + + // Add system CA certificates if we found them + if (defaultTrustStore != null) { + Enumeration aliases = defaultTrustStore.aliases(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + Certificate cert = defaultTrustStore.getCertificate(alias); + if (cert != null) { + combinedTrustStore.setCertificateEntry(alias, cert); + } + } + } + + // Add custom proxy CA certificates + int i = 0; + for (Certificate ca : caCertificates) { + combinedTrustStore.setCertificateEntry("proxyCA" + (i++), ca); + } + return combinedTrustStore; + } + + /** + * Creates a KeyManagerFactory from separate client certificate and key files. + * + * @param clientCertInputStream InputStream containing client certificate (PEM or DER) + * @param clientKeyInputStream InputStream containing client private key (PEM format) + * @return KeyManagerFactory initialized with the client certificate and key + * @throws Exception if there is an error loading the certificate or key + */ + private KeyManagerFactory createKeyManagerFactory(@Nullable InputStream clientCertInputStream, + @Nullable InputStream clientKeyInputStream) throws Exception { + if (clientCertInputStream == null || clientKeyInputStream == null) { + return null; + } + + try { + // Get cert and key + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + Certificate cert = cf.generateCertificate(clientCertInputStream); + + PrivateKey privateKey = loadPrivateKeyFromPem(clientKeyInputStream); + + // Initialize a KeyStore and add the cert and key + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); // Initialize empty keystore + keyStore.setKeyEntry("client", privateKey, new char[0], new Certificate[] { cert }); + + // Initialize the KeyManagerFactory with the created KeyStore + KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, new char[0]); + + return keyManagerFactory; + } finally { + clientCertInputStream.close(); + clientKeyInputStream.close(); + } + } + + /** + * Loads a private key from a PEM-encoded input stream. + * Only supports PKCS#8 format (BEGIN PRIVATE KEY). + * + * @param keyInputStream InputStream containing the PEM-encoded private key + * @return PrivateKey object + * @throws Exception if there is an error loading the key + */ + private PrivateKey loadPrivateKeyFromPem(InputStream keyInputStream) throws Exception { + try { + // Read the key file + String keyContent = readInputStream(keyInputStream); + if (keyContent.contains("BEGIN PRIVATE KEY")) { + // PKCS#8 format - can be loaded directly + return loadPkcs8PrivateKey(keyContent); + } else { + throw new IllegalArgumentException("Unsupported private key format. Must be PEM encoded PKCS#8 format (BEGIN PRIVATE KEY)"); + } + } catch (Exception e) { + Logger.e("Error loading private key: " + e.getMessage()); + throw e; + } + } + + /** + * Loads a PKCS#8 format private key. + */ + private PrivateKey loadPkcs8PrivateKey(String keyContent) throws Exception { + // Extract the base64 encoded private key using proper PEM parsing + String privateKeyPEM = extractPemContent(keyContent); + + // Decode the Base64 encoded private key + byte[] encoded = mBase64Decoder.decode(privateKeyPEM); + + // Create a PKCS8 key spec and generate the private key + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(encoded); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + try { + return keyFactory.generatePrivate(keySpec); + } catch (InvalidKeySpecException e) { + // Try with EC algorithm if RSA fails + try { + keyFactory = KeyFactory.getInstance("EC"); + return keyFactory.generatePrivate(keySpec); + } catch (InvalidKeySpecException ecException) { + Logger.e("Error loading private key: Neither RSA nor EC algorithms could load the key"); + throw new IllegalArgumentException("Invalid PKCS#8 private key format. Key could not be loaded with RSA or EC algorithms.", e); + } + } + } + + private String readInputStream(InputStream inputStream) throws IOException { + StringBuilder stringBuilder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + stringBuilder.append(line).append("\n"); + } + } + return stringBuilder.toString(); + } + + /** + * Extracts the base64 content from a PKCS#8 String. + * + * @param pemContent The full PEM content + * @return The base64 encoded content without PEM boundaries or whitespace + * @throws IllegalArgumentException if PEM boundaries are not found or malformed + */ + private String extractPemContent(String pemContent) throws IllegalArgumentException { + String beginMarker = "-----BEGIN PRIVATE KEY-----"; + int beginIndex = pemContent.indexOf(beginMarker); + if (beginIndex == -1) { + throw new IllegalArgumentException("PEM begin marker not found: " + beginMarker); + } + String endMarker = "-----END PRIVATE KEY-----"; + int endIndex = pemContent.indexOf(endMarker, beginIndex + beginMarker.length()); + if (endIndex == -1) { + throw new IllegalArgumentException("PEM end marker not found: " + endMarker); + } + String base64Content = pemContent.substring(beginIndex + beginMarker.length(), endIndex); + return base64Content.replaceAll("\\s+", ""); + } + + private SSLSocketFactory createSslSocketFactory(@Nullable KeyManagerFactory keyManagerFactory, @Nullable TrustManagerFactory trustManagerFactory) throws Exception { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyManager[] keyManagers = keyManagerFactory != null ? keyManagerFactory.getKeyManagers() : null; + TrustManager[] trustManagers = trustManagerFactory != null ? trustManagerFactory.getTrustManagers() : null; + + sslContext.init(keyManagers, trustManagers, null); + + return sslContext.getSocketFactory(); + } +} diff --git a/src/main/java/io/split/android/client/network/RawHttpResponseParser.java b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java new file mode 100644 index 000000000..5968419cc --- /dev/null +++ b/src/main/java/io/split/android/client/network/RawHttpResponseParser.java @@ -0,0 +1,299 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.Socket; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.split.android.client.utils.logger.Logger; + +/** + * Parses raw HTTP protocol responses from socket input streams. + * Handles the HTTP protocol parsing (status line, headers, body) from socket streams. + */ +class RawHttpResponseParser { + + /** + * Parses a raw HTTP response from an input stream. + * + * @param inputStream The input stream containing the raw HTTP response + * @param serverCertificates The server certificates to include in the response + * @return HttpResponse containing the parsed status code, headers, and response data + * @throws IOException if parsing fails or the response is malformed + */ + @NonNull + HttpResponse parseHttpResponse(@NonNull InputStream inputStream, Certificate[] serverCertificates) throws IOException { + // 1. Read and parse status line + String statusLine = readLineFromStream(inputStream); + if (statusLine == null) { + throw new IOException("No HTTP response received from server"); + } + + int statusCode = parseStatusCode(statusLine); + + // 2. Read and parse response headers directly + ParsedResponseHeaders responseHeaders = parseHeadersDirectly(inputStream); + + // 3. Determine charset from Content-Type header + Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); + + // 4. Read response body using the same InputStream + String responseBody = readResponseBody(inputStream, responseHeaders.mIsChunked, bodyCharset, responseHeaders.mContentLength, responseHeaders.mConnectionClose); + + // 5. Create and return HttpResponse + if (responseBody != null && !responseBody.trim().isEmpty()) { + return new HttpResponseImpl(statusCode, responseBody, serverCertificates); + } else { + return new HttpResponseImpl(statusCode, serverCertificates); + } + } + + @NonNull + HttpStreamResponse parseHttpStreamResponse(@NonNull InputStream inputStream, + @Nullable Socket tunnelSocket, + @Nullable Socket originSocket) throws IOException { + // 1. Read and parse status line + String statusLine = readLineFromStream(inputStream); + if (statusLine == null) { + throw new IOException("No HTTP response received from server"); + } + + int statusCode = parseStatusCode(statusLine); + + // 2. Read and parse response headers directly + ParsedResponseHeaders responseHeaders = parseHeadersDirectly(inputStream); + + // 3. Determine charset from Content-Type header + Charset bodyCharset = extractCharsetFromContentType(responseHeaders.mContentType); + + return HttpStreamResponseImpl.createFromTunnelSocket(statusCode, + new BufferedReader(new InputStreamReader(inputStream, bodyCharset)), + tunnelSocket, + originSocket); + } + + @NonNull + private ParsedResponseHeaders parseHeadersDirectly(@NonNull InputStream inputStream) throws IOException { + int contentLength = -1; + boolean isChunked = false; + boolean connectionClose = false; + String contentType = null; + String headerLine; + + while ((headerLine = readLineFromStream(inputStream)) != null && !headerLine.trim().isEmpty()) { + int colonIndex = headerLine.indexOf(':'); + if (colonIndex > 0) { + String headerName = headerLine.substring(0, colonIndex).trim(); + String headerValue = headerLine.substring(colonIndex + 1).trim(); + + String lowerHeaderName = headerName.toLowerCase(Locale.US); + if ("content-length".equals(lowerHeaderName)) { + try { + contentLength = Integer.parseInt(headerValue); + } catch (NumberFormatException e) { + Logger.w("Invalid Content-Length header: " + headerLine); + } + } else if ("transfer-encoding".equals(lowerHeaderName) && headerValue.toLowerCase(Locale.US).contains("chunked")) { + isChunked = true; + } else if ("connection".equals(lowerHeaderName) && headerValue.toLowerCase(Locale.US).contains("close")) { + connectionClose = true; + } else if ("content-type".equals(lowerHeaderName)) { + contentType = headerValue; + } + } + } + return new ParsedResponseHeaders(contentLength, isChunked, connectionClose, contentType); + } + + @Nullable + private String readResponseBody(@NonNull InputStream inputStream, boolean isChunked, Charset bodyCharset, int contentLength, boolean connectionClose) throws IOException { + String responseBody = null; + if (isChunked) { + responseBody = readChunkedBodyWithCharset(inputStream, bodyCharset); + } else if (contentLength > 0) { + responseBody = readFixedLengthBodyWithCharset(inputStream, contentLength, bodyCharset); + } else if (connectionClose) { + responseBody = readUntilCloseWithCharset(inputStream, bodyCharset); + } + return responseBody; + } + + /** + * Parses the HTTP status code from the status line. + */ + private int parseStatusCode(@NonNull String statusLine) throws IOException { + // Status line format: "HTTP/1.1 200 OK" or "HTTP/1.0 404 Not Found" + String[] parts = statusLine.split(" "); + if (parts.length < 2) { + throw new IOException("Invalid HTTP status line: " + statusLine); + } + + try { + return Integer.parseInt(parts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid HTTP status code in line: " + statusLine, e); + } + } + + /** + * Extracts charset from Content-Type header, defaulting to UTF-8. + */ + private Charset extractCharsetFromContentType(String contentType) { + if (contentType == null) { + return StandardCharsets.UTF_8; + } + + // Pattern to match charset=value in Content-Type header + Pattern charsetPattern = Pattern.compile("charset\\s*=\\s*([^\\s;]+)", Pattern.CASE_INSENSITIVE); + Matcher matcher = charsetPattern.matcher(contentType); + + if (matcher.find()) { + String charsetName = matcher.group(1).replaceAll("[\"']", ""); // Remove quotes + try { + return Charset.forName(charsetName); + } catch (Exception e) { + Logger.w("Unsupported charset: " + charsetName + ", using UTF-8"); + } + } + + return StandardCharsets.UTF_8; + } + + private String readChunkedBodyWithCharset(InputStream inputStream, Charset charset) throws IOException { + ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream(); + + while (true) { + // Read chunk size line + String chunkSizeLine = readLineFromStream(inputStream); + if (chunkSizeLine == null) { + throw new IOException("Unexpected EOF while reading chunk size"); + } + + // Parse chunk size (ignore extensions after semicolon) + int semicolonIndex = chunkSizeLine.indexOf(';'); + String sizeStr = semicolonIndex >= 0 ? chunkSizeLine.substring(0, semicolonIndex).trim() : chunkSizeLine.trim(); + + int chunkSize; + try { + chunkSize = Integer.parseInt(sizeStr, 16); + } catch (NumberFormatException e) { + throw new IOException("Invalid chunk size: " + chunkSizeLine, e); + } + + if (chunkSize < 0) { + throw new IOException("Negative chunk size: " + chunkSize); + } + + // If chunk size is 0, we've reached the end + if (chunkSize == 0) { + // Read trailing headers until empty line + String trailerLine; + while ((trailerLine = readLineFromStream(inputStream)) != null && !trailerLine.trim().isEmpty()) { + // no-op + } + break; + } + + // Read chunk data (exact byte count) + byte[] chunkData = new byte[chunkSize]; + int totalRead = 0; + while (totalRead < chunkSize) { + int read = inputStream.read(chunkData, totalRead, chunkSize - totalRead); + if (read == -1) { + throw new IOException("Unexpected EOF while reading chunk data"); + } + totalRead += read; + } + + bodyBytes.write(chunkData); + + // Read trailing CRLF after chunk data + int c1 = inputStream.read(); + int c2 = inputStream.read(); + if (c1 != '\r' || c2 != '\n') { + throw new IOException("Expected CRLF after chunk data, got: " + (char) c1 + (char) c2); + } + } + + return new String(bodyBytes.toByteArray(), charset); + } + + private String readFixedLengthBodyWithCharset(InputStream inputStream, int contentLength, Charset charset) throws IOException { + byte[] bodyBytes = new byte[contentLength]; + int totalRead = 0; + + while (totalRead < contentLength) { + int read = inputStream.read(bodyBytes, totalRead, contentLength - totalRead); + if (read == -1) { + throw new IOException("Unexpected EOF while reading fixed-length body"); + } + totalRead += read; + } + + return new String(bodyBytes, charset); + } + + private String readUntilCloseWithCharset(InputStream inputStream, Charset charset) throws IOException { + ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream(); + byte[] buffer = new byte[8192]; + int bytesRead; + + while ((bytesRead = inputStream.read(buffer)) != -1) { + bodyBytes.write(buffer, 0, bytesRead); + } + + return new String(bodyBytes.toByteArray(), charset); + } + + private String readLineFromStream(InputStream inputStream) throws IOException { + ByteArrayOutputStream lineBytes = new ByteArrayOutputStream(); + int b; + boolean foundCR = false; + + while ((b = inputStream.read()) != -1) { + if (b == '\r') { + foundCR = true; + } else if (b == '\n' && foundCR) { + break; + } else if (foundCR) { + // CR not followed by LF, add the CR to output + lineBytes.write('\r'); + lineBytes.write(b); + foundCR = false; + } else { + lineBytes.write(b); + } + } + + if (b == -1 && lineBytes.size() == 0) { + return null; // EOF + } + + return new String(lineBytes.toByteArray(), StandardCharsets.UTF_8); + } + + private static class ParsedResponseHeaders { + final int mContentLength; + final boolean mIsChunked; + final boolean mConnectionClose; + final String mContentType; + + ParsedResponseHeaders(int contentLength, boolean isChunked, boolean connectionClose, String contentType) { + mContentLength = contentLength; + mIsChunked = isChunked; + mConnectionClose = connectionClose; + mContentType = contentType; + } + } +} diff --git a/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java new file mode 100644 index 000000000..d7940e1b5 --- /dev/null +++ b/src/main/java/io/split/android/client/network/SslProxyTunnelEstablisher.java @@ -0,0 +1,216 @@ +package io.split.android.client.network; + +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.HttpURLConnection; +import java.net.Socket; +import java.nio.charset.StandardCharsets; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * Establishes SSL tunnels to SSL proxies using CONNECT protocol. + */ +class SslProxyTunnelEstablisher { + + private static final String CRLF = "\r\n"; + private static final String PROXY_AUTHORIZATION_HEADER = "Proxy-Authorization"; + private final Base64Encoder mBase64Encoder; + + // Default timeout for regular connections (10 seconds) + private static final int DEFAULT_SOCKET_TIMEOUT = 20000; + + SslProxyTunnelEstablisher() { + this(new DefaultBase64Encoder()); + } + + @VisibleForTesting + SslProxyTunnelEstablisher(Base64Encoder base64Encoder) { + mBase64Encoder = base64Encoder; + } + + /** + * Establishes an SSL tunnel through the proxy using the CONNECT method. + * After successful tunnel establishment, extracts the underlying socket + * for use with origin server SSL connections. + * + * @param proxyHost The proxy server hostname + * @param proxyPort The proxy server port + * @param targetHost The target server hostname + * @param targetPort The target server port + * @param sslSocketFactory SSL socket factory for proxy authentication + * @param proxyCredentialsProvider Credentials provider for proxy authentication + * @param isStreaming Whether this connection is for streaming (uses longer timeout) + * @return Raw socket with tunnel established (connection maintained) + * @throws IOException if tunnel establishment fails + */ + @NonNull + Socket establishTunnel(@NonNull String proxyHost, + int proxyPort, + @NonNull String targetHost, + int targetPort, + @NonNull SSLSocketFactory sslSocketFactory, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider, + boolean isStreaming) throws IOException { + + Socket rawSocket = null; + SSLSocket sslSocket = null; + + try { + int timeout = DEFAULT_SOCKET_TIMEOUT; + // Step 1: Create raw TCP connection to proxy + rawSocket = new Socket(proxyHost, proxyPort); + rawSocket.setSoTimeout(timeout); + + // Create a temporary SSL socket to establish the SSL session with proper trust validation + sslSocket = (SSLSocket) sslSocketFactory.createSocket(rawSocket, proxyHost, proxyPort, true); + sslSocket.setUseClientMode(true); + if (isStreaming) { + sslSocket.setSoTimeout(0); // no timeout for streaming + } else { + sslSocket.setSoTimeout(timeout); + } + + // Perform SSL handshake using the SSL socket with custom CA certificates + sslSocket.startHandshake(); + + // Validate the proxy hostname + HostnameVerifier verifier = HttpsURLConnection.getDefaultHostnameVerifier(); + if (!verifier.verify(proxyHost, sslSocket.getSession())) { + throw new SSLHandshakeException("Proxy hostname verification failed"); + } + + // Step 3: Send CONNECT request through SSL connection + sendConnectRequest(sslSocket, targetHost, targetPort, proxyCredentialsProvider); + + // Step 4: Validate CONNECT response through SSL connection + validateConnectResponse(sslSocket); + + // Step 5: Return SSL socket for tunnel communication + return sslSocket; + + } catch (Exception e) { + // Clean up resources on error + if (sslSocket != null) { + try { + sslSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } else if (rawSocket != null) { + try { + rawSocket.close(); + } catch (IOException closeEx) { + // Ignore close exceptions + } + } + + if (e instanceof HttpRetryException) { + throw (HttpRetryException) e; + } else if (e instanceof IOException) { + throw (IOException) e; + } else { + throw new IOException("Failed to establish SSL tunnel", e); + } + } + } + + /** + * Sends CONNECT request through SSL connection to proxy. + */ + private void sendConnectRequest(@NonNull SSLSocket sslSocket, + @NonNull String targetHost, + int targetPort, + @Nullable ProxyCredentialsProvider proxyCredentialsProvider) throws IOException { + + PrintWriter writer = new PrintWriter(new OutputStreamWriter(sslSocket.getOutputStream(), StandardCharsets.UTF_8), false); + writer.write("CONNECT " + targetHost + ":" + targetPort + " HTTP/1.1" + CRLF); + writer.write("Host: " + targetHost + ":" + targetPort + CRLF); + + if (proxyCredentialsProvider != null) { + addProxyAuthHeader(proxyCredentialsProvider, writer); + } + + // Send empty line to end headers + writer.write(CRLF); + writer.flush(); + } + + private void addProxyAuthHeader(@NonNull ProxyCredentialsProvider proxyCredentialsProvider, PrintWriter writer) { + if (proxyCredentialsProvider instanceof BearerCredentialsProvider) { + // Send Proxy-Authorization header if credentials are set + String bearerToken = ((BearerCredentialsProvider) proxyCredentialsProvider).getToken(); + if (bearerToken != null && !bearerToken.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Bearer " + bearerToken + CRLF); + } + } else if (proxyCredentialsProvider instanceof BasicCredentialsProvider) { + BasicCredentialsProvider basicCredentialsProvider = (BasicCredentialsProvider) proxyCredentialsProvider; + String userName = basicCredentialsProvider.getUsername(); + String password = basicCredentialsProvider.getPassword(); + if (userName != null && !userName.trim().isEmpty() && password != null && !password.trim().isEmpty()) { + writer.write(PROXY_AUTHORIZATION_HEADER + ": Basic " + mBase64Encoder.encode(userName + ":" + password) + CRLF); + } + } + } + + /** + * Validates CONNECT response through SSL connection. + * Only reads status line and headers, leaving the stream open for tunneling. + */ + private void validateConnectResponse(@NonNull SSLSocket sslSocket) throws IOException { + + try { + BufferedReader reader = new BufferedReader(new InputStreamReader(sslSocket.getInputStream(), StandardCharsets.UTF_8)); + + String statusLine = reader.readLine(); + if (statusLine == null) { + throw new IOException("No CONNECT response received from proxy"); + } + + // Parse status code + String[] statusParts = statusLine.split(" "); + if (statusParts.length < 2) { + throw new IOException("Invalid CONNECT response status line: " + statusLine); + } + + int statusCode; + try { + statusCode = Integer.parseInt(statusParts[1]); + } catch (NumberFormatException e) { + throw new IOException("Invalid CONNECT response status code: " + statusLine, e); + } + + // Read headers until empty line (but don't process them for CONNECT) + String headerLine; + while ((headerLine = reader.readLine()) != null && !headerLine.trim().isEmpty()) { + // no-op + } + + // Check status code + if (statusCode != 200) { + if (statusCode == HttpURLConnection.HTTP_PROXY_AUTH) { + throw new HttpRetryException("CONNECT request failed with status " + statusCode + ": " + statusLine, HttpURLConnection.HTTP_PROXY_AUTH); + } + throw new IOException("CONNECT request failed with status " + statusCode + ": " + statusLine); + } + } catch (IOException e) { + if (e instanceof HttpRetryException) { + throw e; + } + + throw new IOException("Failed to validate CONNECT response from proxy: " + e.getMessage(), e); + } + } +} diff --git a/src/main/java/io/split/android/client/service/ServiceConstants.java b/src/main/java/io/split/android/client/service/ServiceConstants.java index ecfa4ad05..95c4886fd 100644 --- a/src/main/java/io/split/android/client/service/ServiceConstants.java +++ b/src/main/java/io/split/android/client/service/ServiceConstants.java @@ -35,6 +35,7 @@ public class ServiceConstants { public static final String WORKER_PARAM_CONFIGURED_FILTER_TYPE = "configuredFilterType"; public static final String WORKER_PARAM_FLAGS_SPEC = "flagsSpec"; public static final String WORKER_PARAM_CERTIFICATE_PINS = "certificatePins"; + public static final String WORKER_PARAM_USES_PROXY = "usesProxy"; public static final int LAST_SEEN_IMPRESSION_CACHE_SIZE = 2000; public static final int MY_SEGMENT_V2_DATA_SIZE = 1024 * 10;// bytes diff --git a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java index 9e8f11f21..78a8f316b 100644 --- a/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java +++ b/src/main/java/io/split/android/client/service/sseclient/sseclient/SseClientImpl.java @@ -36,6 +36,7 @@ public class SseClientImpl implements SseClient { private final StringHelper mStringHelper; private HttpStreamRequest mHttpStreamRequest = null; + private HttpStreamResponse mHttpStreamResponse = null; private static final String PUSH_NOTIFICATION_CHANNELS_PARAM = "channel"; private static final String PUSH_NOTIFICATION_TOKEN_PARAM = "accessToken"; @@ -71,8 +72,21 @@ public void disconnect() { private void close() { Logger.d("Disconnecting SSE client"); if (mStatus.getAndSet(DISCONNECTED) != DISCONNECTED) { + // Close the HttpStreamResponse first to clean up sockets + if (mHttpStreamResponse != null) { + try { + mHttpStreamResponse.close(); + Logger.v("HttpStreamResponse closed successfully"); + } catch (IOException e) { + Logger.w("Failed to close HttpStreamResponse: " + e.getMessage()); + } + mHttpStreamResponse = null; + } + + // Close the HttpStreamRequest if (mHttpStreamRequest != null) { mHttpStreamRequest.close(); + mHttpStreamRequest = null; } Logger.d("SSE client disconnected"); } @@ -94,9 +108,9 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { .addParameter(PUSH_NOTIFICATION_TOKEN_PARAM, rawToken) .build(); mHttpStreamRequest = mHttpClient.streamRequest(url); - HttpStreamResponse response = mHttpStreamRequest.execute(); - if (response.isSuccess()) { - bufferedReader = response.getBufferedReader(); + mHttpStreamResponse = mHttpStreamRequest.execute(); + if (mHttpStreamResponse.isSuccess()) { + bufferedReader = mHttpStreamResponse.getBufferedReader(); if (bufferedReader != null) { Logger.d("Streaming connection opened"); mStatus.set(CONNECTED); @@ -126,8 +140,8 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { throw (new IOException("Buffer is null")); } } else { - Logger.e("Streaming connection error. Http return code " + response.getHttpStatus()); - isErrorRetryable = !response.isClientRelatedError(); + Logger.e("Streaming connection error. Http return code " + mHttpStreamResponse.getHttpStatus()); + isErrorRetryable = !mHttpStreamResponse.isClientRelatedError(); } } catch (URISyntaxException e) { logError("An error has occurred while creating stream Url ", e); @@ -149,8 +163,7 @@ public void connect(SseJwtToken token, ConnectionListener connectionListener) { } } - private void logError(String message, Exception e) { + private static void logError(String message, Exception e) { Logger.e(message + " : " + e.getLocalizedMessage()); } - } diff --git a/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java b/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java index c65042116..ca3761bca 100644 --- a/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java +++ b/src/main/java/io/split/android/client/service/synchronizer/WorkManagerWrapper.java @@ -147,6 +147,7 @@ private Data buildInputData(Data customData) { dataBuilder.putString(ServiceConstants.WORKER_PARAM_DATABASE_NAME, mDatabaseName); dataBuilder.putString(ServiceConstants.WORKER_PARAM_API_KEY, mApiKey); dataBuilder.putBoolean(ServiceConstants.WORKER_PARAM_ENCRYPTION_ENABLED, mSplitClientConfig.encryptionEnabled()); + dataBuilder.putBoolean(ServiceConstants.WORKER_PARAM_USES_PROXY, mSplitClientConfig.proxy() != null); if (mSplitClientConfig.certificatePinningConfiguration() != null) { try { Map> pins = mSplitClientConfig.certificatePinningConfiguration().getPins(); diff --git a/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java b/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java new file mode 100644 index 000000000..ed4a2cdb2 --- /dev/null +++ b/src/main/java/io/split/android/client/service/workmanager/HttpClientProvider.java @@ -0,0 +1,133 @@ +package io.split.android.client.service.workmanager; + +import androidx.annotation.Nullable; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.android_client.BuildConfig; +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.BearerCredentialsProvider; +import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.CertificatePinningConfigurationProvider; +import io.split.android.client.network.HttpClient; +import io.split.android.client.network.HttpClientImpl; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.SplitHttpHeadersBuilder; +import io.split.android.client.storage.cipher.SplitCipherFactory; +import io.split.android.client.storage.db.SplitRoomDatabase; +import io.split.android.client.storage.db.StorageFactory; +import io.split.android.client.storage.general.GeneralInfoStorage; +import io.split.android.client.utils.HttpProxySerializer; + +class HttpClientProvider { + + public static HttpClient buildHttpClient(String apiKey, String certPinningConfig, boolean usesProxy, SplitRoomDatabase mDatabase) { + return buildHttpClient(apiKey, buildCertPinningConfig(certPinningConfig), buildProxyConfig(usesProxy, mDatabase, apiKey)); + } + + private static HttpClient buildHttpClient(String apiKey, @Nullable CertificatePinningConfiguration certificatePinningConfiguration, HttpProxy proxyConfiguration) { + HttpClientImpl.Builder builder = new HttpClientImpl.Builder(); + + if (certificatePinningConfiguration != null) { + builder.setCertificatePinningConfiguration(certificatePinningConfiguration); + } + + if (proxyConfiguration != null) { + builder.setProxy(proxyConfiguration); + } + + HttpClient httpClient = builder + .build(); + + SplitHttpHeadersBuilder headersBuilder = new SplitHttpHeadersBuilder(); + headersBuilder.setClientVersion(BuildConfig.SPLIT_VERSION_NAME); + headersBuilder.setApiToken(apiKey); + headersBuilder.addJsonTypeHeaders(); + httpClient.addHeaders(headersBuilder.build()); + + return httpClient; + } + + @Nullable + private static CertificatePinningConfiguration buildCertPinningConfig(@Nullable String pinsJson) { + if (pinsJson == null || pinsJson.trim().isEmpty()) { + return null; + } + + return CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(pinsJson); + } + + private static HttpProxy buildProxyConfig(boolean usesProxy, SplitRoomDatabase database, String apiKey) { + if (!usesProxy) { + return null; + } + + GeneralInfoStorage storage = StorageFactory.getGeneralInfoStorage(database, SplitCipherFactory.create(apiKey, true)); + HttpProxyDto proxyConfigDto = HttpProxySerializer.deserialize(storage); + + if (proxyConfigDto == null) { + return null; + } + + if (proxyConfigDto.host == null) { + return null; + } + + HttpProxy.Builder builder = HttpProxy.newBuilder(proxyConfigDto.host, proxyConfigDto.port); + + addCredentialsProvider(proxyConfigDto, builder); + addMtls(proxyConfigDto, builder); + addCaCert(proxyConfigDto, builder); + + return builder.build(); + } + + private static void addCaCert(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.caCert != null) { + InputStream caCertStream = stringToInputStream(proxyConfigDto.caCert); + builder.proxyCacert(caCertStream); + } + } + + private static void addMtls(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.clientCert != null && proxyConfigDto.clientKey != null) { + InputStream clientCertStream = stringToInputStream(proxyConfigDto.clientCert); + InputStream clientKeyStream = stringToInputStream(proxyConfigDto.clientKey); + builder.mtls(clientCertStream, clientKeyStream); + } + } + + private static void addCredentialsProvider(HttpProxyDto proxyConfigDto, HttpProxy.Builder builder) { + if (proxyConfigDto.username != null && proxyConfigDto.password != null) { + builder.credentialsProvider(new BasicCredentialsProvider() { + @Override + public String getUsername() { + return proxyConfigDto.username; + } + + @Override + public String getPassword() { + return proxyConfigDto.password; + } + }); + } else if (proxyConfigDto.bearerToken != null) { + builder.credentialsProvider(new BearerCredentialsProvider() { + + @Override + public String getToken() { + return proxyConfigDto.bearerToken; + } + }); + } + } + + private static InputStream stringToInputStream(String input) { + if (input == null) { + return null; + } + return new ByteArrayInputStream(input.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java b/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java index b24b231ce..812b8963f 100644 --- a/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java +++ b/src/main/java/io/split/android/client/service/workmanager/SplitWorker.java @@ -3,17 +3,11 @@ import android.content.Context; import androidx.annotation.NonNull; -import androidx.annotation.Nullable; import androidx.work.Data; import androidx.work.Worker; import androidx.work.WorkerParameters; -import io.split.android.android_client.BuildConfig; -import io.split.android.client.network.CertificatePinningConfiguration; -import io.split.android.client.network.CertificatePinningConfigurationProvider; import io.split.android.client.network.HttpClient; -import io.split.android.client.network.HttpClientImpl; -import io.split.android.client.network.SplitHttpHeadersBuilder; import io.split.android.client.service.ServiceConstants; import io.split.android.client.service.executor.SplitTask; import io.split.android.client.storage.db.SplitRoomDatabase; @@ -35,7 +29,9 @@ public SplitWorker(@NonNull Context context, String apiKey = inputData.getString(ServiceConstants.WORKER_PARAM_API_KEY); mEndpoint = inputData.getString(ServiceConstants.WORKER_PARAM_ENDPOINT); mDatabase = SplitRoomDatabase.getDatabase(context, databaseName); - mHttpClient = buildHttpClient(apiKey, buildCertPinningConfig(inputData.getString(ServiceConstants.WORKER_PARAM_CERTIFICATE_PINS))); + mHttpClient = HttpClientProvider.buildHttpClient(apiKey, + inputData.getString(ServiceConstants.WORKER_PARAM_CERTIFICATE_PINS), + inputData.getBoolean(ServiceConstants.WORKER_PARAM_USES_PROXY, false), mDatabase); } @NonNull @@ -60,32 +56,4 @@ public HttpClient getHttpClient() { public String getEndPoint() { return mEndpoint; } - - private static HttpClient buildHttpClient(String apiKey, @Nullable CertificatePinningConfiguration certificatePinningConfiguration) { - HttpClientImpl.Builder builder = new HttpClientImpl.Builder(); - - if (certificatePinningConfiguration != null) { - builder.setCertificatePinningConfiguration(certificatePinningConfiguration); - } - - HttpClient httpClient = builder - .build(); - - SplitHttpHeadersBuilder headersBuilder = new SplitHttpHeadersBuilder(); - headersBuilder.setClientVersion(BuildConfig.SPLIT_VERSION_NAME); - headersBuilder.setApiToken(apiKey); - headersBuilder.addJsonTypeHeaders(); - httpClient.addHeaders(headersBuilder.build()); - - return httpClient; - } - - @Nullable - private static CertificatePinningConfiguration buildCertPinningConfig(@Nullable String pinsJson) { - if (pinsJson == null || pinsJson.trim().isEmpty()) { - return null; - } - - return CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(pinsJson); - } } diff --git a/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java b/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java index d60e81967..b9eb7b5d3 100644 --- a/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java +++ b/src/main/java/io/split/android/client/service/workmanager/splits/StorageProvider.java @@ -14,11 +14,18 @@ class StorageProvider { private final SplitRoomDatabase mDatabase; private final boolean mShouldRecordTelemetry; private final SplitCipher mCipher; + // some values in general info storage require encryption always + private final SplitCipher mAlwaysEncryptedCipher; StorageProvider(SplitRoomDatabase database, String apiKey, boolean encryptionEnabled, boolean shouldRecordTelemetry) { mDatabase = database; mCipher = SplitCipherFactory.create(apiKey, encryptionEnabled); mShouldRecordTelemetry = shouldRecordTelemetry; + if (encryptionEnabled) { + mAlwaysEncryptedCipher = mCipher; + } else { + mAlwaysEncryptedCipher = SplitCipherFactory.create(apiKey, true); + } } SplitsStorage provideSplitsStorage() { @@ -40,6 +47,6 @@ RuleBasedSegmentStorageProducer provideRuleBasedSegmentStorage() { } GeneralInfoStorage provideGeneralInfoStorage() { - return StorageFactory.getGeneralInfoStorage(mDatabase); + return StorageFactory.getGeneralInfoStorage(mDatabase, mAlwaysEncryptedCipher); } } diff --git a/src/main/java/io/split/android/client/storage/db/StorageFactory.java b/src/main/java/io/split/android/client/storage/db/StorageFactory.java index 6dc6e4c4d..1d591e551 100644 --- a/src/main/java/io/split/android/client/storage/db/StorageFactory.java +++ b/src/main/java/io/split/android/client/storage/db/StorageFactory.java @@ -153,8 +153,8 @@ public static PersistentImpressionsObserverCacheStorage getImpressionsObserverCa return new SqlitePersistentImpressionsObserverCacheStorage(splitRoomDatabase.impressionsObserverCacheDao(), expirationPeriod, executorService); } - public static GeneralInfoStorage getGeneralInfoStorage(SplitRoomDatabase splitRoomDatabase) { - return new GeneralInfoStorageImpl(splitRoomDatabase.generalInfoDao()); + public static GeneralInfoStorage getGeneralInfoStorage(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher) { + return new GeneralInfoStorageImpl(splitRoomDatabase.generalInfoDao(), splitCipher); } public static PersistentRuleBasedSegmentStorage getPersistentRuleBasedSegmentStorage(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher, GeneralInfoStorage generalInfoStorage) { @@ -163,7 +163,7 @@ public static PersistentRuleBasedSegmentStorage getPersistentRuleBasedSegmentSto public static RuleBasedSegmentStorageProducer getRuleBasedSegmentStorageForWorker(SplitRoomDatabase splitRoomDatabase, SplitCipher splitCipher) { PersistentRuleBasedSegmentStorage persistentRuleBasedSegmentStorage = - new SqLitePersistentRuleBasedSegmentStorageProvider(splitCipher, splitRoomDatabase, getGeneralInfoStorage(splitRoomDatabase)).get(); + new SqLitePersistentRuleBasedSegmentStorageProvider(splitCipher, splitRoomDatabase, getGeneralInfoStorage(splitRoomDatabase, null)).get(); return new RuleBasedSegmentStorageProducerImpl(persistentRuleBasedSegmentStorage, new ConcurrentHashMap<>(), new AtomicLong(-1)); } } diff --git a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java index 87a6a55ec..b3ad1215c 100644 --- a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java +++ b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorage.java @@ -38,4 +38,9 @@ public interface GeneralInfoStorage { void setLastProxyUpdateTimestamp(long timestamp); long getLastProxyUpdateTimestamp(); + + @Nullable + String getProxyConfig(); + + void setProxyConfig(@Nullable String proxyConfig); } diff --git a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java index c351d9a48..304729d57 100644 --- a/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java +++ b/src/main/java/io/split/android/client/storage/general/GeneralInfoStorageImpl.java @@ -5,6 +5,7 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import io.split.android.client.storage.cipher.SplitCipher; import io.split.android.client.storage.db.GeneralInfoDao; import io.split.android.client.storage.db.GeneralInfoEntity; @@ -13,11 +14,14 @@ public class GeneralInfoStorageImpl implements GeneralInfoStorage { private static final String ROLLOUT_CACHE_LAST_CLEAR_TIMESTAMP = "rolloutCacheLastClearTimestamp"; private static final String RBS_CHANGE_NUMBER = "rbsChangeNumber"; private static final String LAST_PROXY_CHECK_TIMESTAMP = "lastProxyCheckTimestamp"; + private static final String PROXY_CONFIG = "proxyConfig"; private final GeneralInfoDao mGeneralInfoDao; + private final SplitCipher mAlwaysEncryptedSplitCipher; - public GeneralInfoStorageImpl(GeneralInfoDao generalInfoDao) { + public GeneralInfoStorageImpl(GeneralInfoDao generalInfoDao, @Nullable SplitCipher splitCipher) { mGeneralInfoDao = checkNotNull(generalInfoDao); + mAlwaysEncryptedSplitCipher = splitCipher; } @Override @@ -109,4 +113,27 @@ public long getLastProxyUpdateTimestamp() { GeneralInfoEntity entity = mGeneralInfoDao.getByName(LAST_PROXY_CHECK_TIMESTAMP); return entity != null ? entity.getLongValue() : 0L; } + + @Nullable + @Override + public String getProxyConfig() { + GeneralInfoEntity entity = mGeneralInfoDao.getByName(PROXY_CONFIG); + if (entity == null) { + return null; + } + + if (mAlwaysEncryptedSplitCipher != null) { + return mAlwaysEncryptedSplitCipher.decrypt(entity.getStringValue()); + } + + return entity.getStringValue(); + } + + @Override + public void setProxyConfig(@Nullable String proxyConfig) { + if (mAlwaysEncryptedSplitCipher != null) { + proxyConfig = mAlwaysEncryptedSplitCipher.encrypt(proxyConfig); + } + mGeneralInfoDao.update(new GeneralInfoEntity(PROXY_CONFIG, proxyConfig)); + } } diff --git a/src/main/java/io/split/android/client/utils/HttpProxySerializer.java b/src/main/java/io/split/android/client/utils/HttpProxySerializer.java new file mode 100644 index 000000000..889f9344a --- /dev/null +++ b/src/main/java/io/split/android/client/utils/HttpProxySerializer.java @@ -0,0 +1,42 @@ +package io.split.android.client.utils; + +import androidx.annotation.Nullable; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.storage.general.GeneralInfoStorage; + +/** + * Utility class for serializing and deserializing HttpProxy objects. + */ +public class HttpProxySerializer { + + private HttpProxySerializer() { + } + + @Nullable + public static String serialize(@Nullable HttpProxy httpProxy) { + if (httpProxy == null) { + return null; + } + HttpProxyDto dto = new HttpProxyDto(httpProxy); + return Json.toJson(dto); + } + + + @Nullable + public static HttpProxyDto deserialize(GeneralInfoStorage storage) { + if (storage == null) { + return null; + } + String json = storage.getProxyConfig(); + if (json == null || json.isEmpty()) { + return null; + } + try { + return Json.fromJson(json, HttpProxyDto.class); + } catch (Exception e) { + return null; + } + } +} diff --git a/src/main/java/io/split/android/client/validators/TreatmentManagerFactoryImpl.java b/src/main/java/io/split/android/client/validators/TreatmentManagerFactoryImpl.java index e9a24631a..287fb94b4 100644 --- a/src/main/java/io/split/android/client/validators/TreatmentManagerFactoryImpl.java +++ b/src/main/java/io/split/android/client/validators/TreatmentManagerFactoryImpl.java @@ -17,6 +17,9 @@ import io.split.android.client.storage.splits.SplitsStorage; import io.split.android.client.telemetry.storage.TelemetryStorageProducer; import io.split.android.engine.experiments.SplitParser; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; public class TreatmentManagerFactoryImpl implements TreatmentManagerFactory { @@ -32,6 +35,7 @@ public class TreatmentManagerFactoryImpl implements TreatmentManagerFactory { private final ValidationMessageLogger mValidationMessageLogger; private final SplitFilterValidator mFlagSetsValidator; private final PropertyValidator mPropertyValidator; + private final FallbackTreatmentsCalculator mFallbackCalculator; public TreatmentManagerFactoryImpl(@NonNull KeyValidator keyValidator, @NonNull SplitValidator splitValidator, @@ -41,14 +45,22 @@ public TreatmentManagerFactoryImpl(@NonNull KeyValidator keyValidator, @NonNull TelemetryStorageProducer telemetryStorageProducer, @NonNull SplitParser splitParser, @Nullable FlagSetsFilter flagSetsFilter, - @NonNull SplitsStorage splitsStorage) { + @NonNull SplitsStorage splitsStorage, + @Nullable FallbackTreatmentsConfiguration fallbackTreatments) { mKeyValidator = checkNotNull(keyValidator); mSplitValidator = checkNotNull(splitValidator); mCustomerImpressionListener = checkNotNull(customerImpressionListener); mLabelsEnabled = labelsEnabled; mAttributesMerger = checkNotNull(attributesMerger); mTelemetryStorageProducer = checkNotNull(telemetryStorageProducer); - mEvaluator = new EvaluatorImpl(splitsStorage, splitParser); + FallbackTreatmentsCalculator calculator; + if (fallbackTreatments != null) { + calculator = new FallbackTreatmentsCalculatorImpl(fallbackTreatments); + } else { + calculator = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build()); + } + mEvaluator = new EvaluatorImpl(splitsStorage, splitParser, calculator); + mFallbackCalculator = calculator; mFlagSetsFilter = flagSetsFilter; mSplitsStorage = checkNotNull(splitsStorage); mValidationMessageLogger = new ValidationMessageLoggerImpl(); @@ -74,7 +86,8 @@ public TreatmentManager getTreatmentManager(Key key, ListenableEventsManager eve mSplitsStorage, mValidationMessageLogger, mFlagSetsValidator, - mPropertyValidator + mPropertyValidator, + mFallbackCalculator ); } } diff --git a/src/main/java/io/split/android/client/validators/TreatmentManagerHelper.java b/src/main/java/io/split/android/client/validators/TreatmentManagerHelper.java index 3539a7c4b..125f594c3 100644 --- a/src/main/java/io/split/android/client/validators/TreatmentManagerHelper.java +++ b/src/main/java/io/split/android/client/validators/TreatmentManagerHelper.java @@ -7,18 +7,20 @@ import java.util.Map; import io.split.android.client.SplitResult; -import io.split.android.grammar.Treatments; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; class TreatmentManagerHelper { - static Map controlTreatmentsForSplitsWithConfig(SplitValidator splitValidator, ValidationMessageLogger validationLogger, List splits, String validationTag, TreatmentManagerImpl.ResultTransformer resultTransformer) { + static Map controlTreatmentsForSplitsWithConfig(SplitValidator splitValidator, ValidationMessageLogger validationLogger, List splits, String validationTag, TreatmentManagerImpl.ResultTransformer resultTransformer, FallbackTreatmentsCalculator mFallbackCalculator) { Map results = new HashMap<>(); for (String split : splits) { if (isInvalidSplit(splitValidator, validationTag, validationLogger, split)) { continue; } - results.put(split.trim(), resultTransformer.transform(new SplitResult(Treatments.CONTROL))); + FallbackTreatment fallback = mFallbackCalculator.resolve(split); + results.put(split.trim(), resultTransformer.transform(new SplitResult(fallback.getTreatment(), fallback.getConfig()))); } return results; diff --git a/src/main/java/io/split/android/client/validators/TreatmentManagerImpl.java b/src/main/java/io/split/android/client/validators/TreatmentManagerImpl.java index 7493e2a7c..cf09d400b 100644 --- a/src/main/java/io/split/android/client/validators/TreatmentManagerImpl.java +++ b/src/main/java/io/split/android/client/validators/TreatmentManagerImpl.java @@ -22,6 +22,8 @@ import io.split.android.client.attributes.AttributesMerger; import io.split.android.client.events.ListenableEventsManager; import io.split.android.client.events.SplitEvent; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; import io.split.android.client.impressions.DecoratedImpression; import io.split.android.client.impressions.Impression; import io.split.android.client.impressions.ImpressionListener; @@ -30,7 +32,6 @@ import io.split.android.client.telemetry.storage.TelemetryStorageProducer; import io.split.android.client.utils.Json; import io.split.android.client.utils.logger.Logger; -import io.split.android.grammar.Treatments; public class TreatmentManagerImpl implements TreatmentManager { @@ -52,6 +53,8 @@ public class TreatmentManagerImpl implements TreatmentManager { private final SplitsStorage mSplitsStorage; private final SplitFilterValidator mFlagSetsValidator; private final PropertyValidator mPropertyValidator; + @NonNull + private final FallbackTreatmentsCalculator mFallbackCalculator; public TreatmentManagerImpl(String matchingKey, String bucketingKey, @@ -68,7 +71,8 @@ public TreatmentManagerImpl(String matchingKey, @NonNull SplitsStorage splitsStorage, @NonNull ValidationMessageLogger validationLogger, @NonNull SplitFilterValidator flagSetsValidator, - @NonNull PropertyValidator propertyValidator) { + @NonNull PropertyValidator propertyValidator, + @NonNull FallbackTreatmentsCalculator fallbackCalculator) { mEvaluator = evaluator; mKeyValidator = keyValidator; mSplitValidator = splitValidator; @@ -85,6 +89,7 @@ public TreatmentManagerImpl(String matchingKey, mSplitsStorage = checkNotNull(splitsStorage); mFlagSetsValidator = checkNotNull(flagSetsValidator); mPropertyValidator = checkNotNull(propertyValidator); + mFallbackCalculator = checkNotNull(fallbackCalculator); } @Override @@ -100,14 +105,19 @@ public String getTreatment(String split, Map attributes, Evaluat Method.TREATMENT ).get(split); - return (treatment == null) ? Treatments.CONTROL : treatment; + if (treatment == null) { + FallbackTreatment fallback = mFallbackCalculator.resolve(split); + return fallback.getTreatment(); + } + return treatment; } catch (Exception ex) { // In case get fails for some reason Logger.e("Client " + Method.TREATMENT.getMethod() + " exception", ex); mTelemetryStorageProducer.recordException(Method.TREATMENT); - return Treatments.CONTROL; + FallbackTreatment fallback = mFallbackCalculator.resolve(split); + return fallback.getTreatment(); } } @@ -124,13 +134,18 @@ public SplitResult getTreatmentWithConfig(String split, Map attr Method.TREATMENT_WITH_CONFIG ).get(split); - return (splitResult == null) ? new SplitResult(Treatments.CONTROL) : splitResult; + if (splitResult == null) { + FallbackTreatment fallback = mFallbackCalculator.resolve(split); + return new SplitResult(fallback.getTreatment(), fallback.getConfig()); + } + return splitResult; } catch (Exception ex) { // In case get fails for some reason Logger.e("Client " + Method.TREATMENT_WITH_CONFIG.getMethod() + " exception", ex); mTelemetryStorageProducer.recordException(Method.TREATMENT_WITH_CONFIG); - return new SplitResult(Treatments.CONTROL); + FallbackTreatment fallback = mFallbackCalculator.resolve(split); + return new SplitResult(fallback.getTreatment(), fallback.getConfig()); } } @@ -285,7 +300,8 @@ private TreatmentResult getTreatmentWithConfigWithoutMetrics(String split, Map Map getControlTreatmentsForSplitsWithConfig(@Nullable Lis mValidationLogger, (names != null) ? names : new ArrayList<>(), validationTag, - resultTransformer); + resultTransformer, + mFallbackCalculator); } private EvaluationResult evaluateIfReady(String featureFlagName, @@ -390,7 +408,8 @@ private EvaluationResult evaluateIfReady(String featureFlagName, mValidationLogger.w("the SDK is not ready, results may be incorrect for feature flag " + featureFlagName + ". Make sure to wait for SDK readiness before using this method", validationTag); mTelemetryStorageProducer.recordNonReadyUsage(); - return new EvaluationResult(Treatments.CONTROL, TreatmentLabels.NOT_READY, null, null, false); + FallbackTreatment fallback = mFallbackCalculator.resolve(featureFlagName, TreatmentLabels.NOT_READY); + return new EvaluationResult(fallback.getTreatment(), fallback.getLabel(), null, fallback.getConfig(), false); } return mEvaluator.getTreatment(mMatchingKey, mBucketingKey, featureFlagName, attributes); } diff --git a/src/main/java/io/split/android/engine/splitter/Splitter.java b/src/main/java/io/split/android/engine/splitter/Splitter.java index 507ba9c6b..6c181af58 100644 --- a/src/main/java/io/split/android/engine/splitter/Splitter.java +++ b/src/main/java/io/split/android/engine/splitter/Splitter.java @@ -1,10 +1,10 @@ package io.split.android.engine.splitter; +import java.util.List; + import io.split.android.client.dtos.Partition; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; import io.split.android.client.utils.MurmurHash3; -import io.split.android.grammar.Treatments; - -import java.util.List; /** * These set of functions figure out which treatment a key should see. @@ -14,11 +14,11 @@ public class Splitter { private static final int ALGO_LEGACY = 1; private static final int ALGO_MURMUR = 2; - public static String getTreatment(String key, int seed, List partitions, int algo) { + public static String getTreatment(String key, int seed, List partitions, int algo, FallbackTreatmentsCalculator fallbackCalculator) { // 1. when there are no partitions, we just return control if (partitions.isEmpty()) { - return Treatments.CONTROL; + return fallbackCalculator.resolve(key).getTreatment(); } @@ -26,7 +26,8 @@ public static String getTreatment(String key, int seed, List partitio return partitions.get(0).treatment; } - return getTreatment(bucket(hash(key, seed, algo)), partitions); + String controlTreatment = fallbackCalculator.resolve(key).getTreatment(); + return getTreatment(bucket(hash(key, seed, algo)), partitions, controlTreatment); } static long hash(String key, int seed, int algo) { @@ -65,10 +66,11 @@ static int legacy_hash(String key, int seed) { /** * @param bucket - * @param partitions MUST HAVE more than one partitions. + * @param partitions MUST HAVE more than one partitions. + * @param controlTreatment * @return */ - private static String getTreatment(int bucket, List partitions) { + private static String getTreatment(int bucket, List partitions, String controlTreatment) { int bucketsCoveredThusFar = 0; @@ -80,7 +82,7 @@ private static String getTreatment(int bucket, List partitions) { } } - return Treatments.CONTROL; + return controlTreatment; } /*package private*/ diff --git a/src/main/java/io/split/android/grammar/Treatments.java b/src/main/java/io/split/android/grammar/Treatments.java index 9d352b078..66c1e2b1d 100644 --- a/src/main/java/io/split/android/grammar/Treatments.java +++ b/src/main/java/io/split/android/grammar/Treatments.java @@ -7,25 +7,7 @@ public class Treatments { public static final String CONTROL = "control"; - - /** - * OFF is a synonym for CONTROL. - */ public static final String OFF = "off"; public static final String ON = "on"; - public static boolean isControl(String treatment) { - return CONTROL.equals(treatment) || OFF.equals(treatment); - } - - public static String controlSynonym(String treatment) { - if (!isControl(treatment)) { - throw new IllegalArgumentException("Not a control treatment: " + treatment); - } - if (Treatments.OFF.equals(treatment)) { - return Treatments.CONTROL; - } - return Treatments.OFF; - } - } diff --git a/src/test/java/io/split/android/client/SplitClientConfigTest.java b/src/test/java/io/split/android/client/SplitClientConfigTest.java index 1e69b9c12..b97ea6381 100644 --- a/src/test/java/io/split/android/client/SplitClientConfigTest.java +++ b/src/test/java/io/split/android/client/SplitClientConfigTest.java @@ -3,9 +3,11 @@ import static junit.framework.Assert.assertFalse; import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertNull; +import static junit.framework.TestCase.assertSame; import static junit.framework.TestCase.assertTrue; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; import org.junit.Test; @@ -13,7 +15,11 @@ import java.util.Queue; import java.util.concurrent.TimeUnit; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.ProxyConfiguration; +import io.split.android.client.network.SplitAuthenticatedRequest; +import io.split.android.client.network.SplitAuthenticator; import io.split.android.client.utils.logger.LogPrinter; import io.split.android.client.utils.logger.Logger; import io.split.android.client.utils.logger.SplitLogLevel; @@ -256,6 +262,63 @@ public void nullRolloutCacheConfigurationSetsDefault() { assertEquals(1, logMessages.size()); } + @Test + public void fallbackTreatmentsIsNullByDefault() { + SplitClientConfig config = SplitClientConfig.builder().build(); + assertNull(config.fallbackTreatments()); + } + + @Test + public void fallbackTreatmentsAreCorrectlySet() { + FallbackTreatmentsConfiguration ftConfiguration = FallbackTreatmentsConfiguration.builder().build(); + SplitClientConfig config = SplitClientConfig.builder() + .fallbackTreatments(ftConfiguration) + .build(); + + assertSame(ftConfiguration, config.fallbackTreatments()); + } + + @Test + public void proxyHostAndProxyConfigurationSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyHost("proxyHost") + .proxyConfiguration(ProxyConfiguration.builder().url("http://proxy.url").build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence.", logMessages.poll()); + } + + @Test + public void proxyAuthenticatorAndProxyConfigurationSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyAuthenticator(new SplitAuthenticator() { + @Nullable + @Override + public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { + return null; + } + }) + .proxyConfiguration(ProxyConfiguration.builder().url("http://proxy.url").build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Both the deprecated proxy configuration methods (proxyHost, proxyAuthenticator) and the new ProxyConfiguration builder are being used. ProxyConfiguration will take precedence.", logMessages.poll()); + } + + @Test + public void proxyConfigurationWithNoUrlSetLogWarning() { + Queue logMessages = getLogMessagesQueue(); + SplitClientConfig.builder() + .logLevel(SplitLogLevel.WARNING) + .proxyConfiguration(ProxyConfiguration.builder().build()) + .build(); + assertEquals(1, logMessages.size()); + assertEquals("Proxy configuration with no URL. This will prevent SplitFactory from working.", logMessages.poll()); + } + @NonNull private static Queue getLogMessagesQueue() { Queue logMessages = new LinkedList<>(); diff --git a/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt b/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt index 04cc76c80..0a0f46ec4 100644 --- a/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt +++ b/src/test/java/io/split/android/client/SplitFactoryHelperTest.kt @@ -2,10 +2,13 @@ package io.split.android.client import android.content.Context import io.split.android.client.SplitFactoryHelper.Initializer.Listener +import io.split.android.client.api.Key import io.split.android.client.events.EventsManagerCoordinator import io.split.android.client.events.SplitInternalEvent +import io.split.android.client.exceptions.SplitInstantiationException import io.split.android.client.lifecycle.SplitLifecycleManager -import io.split.android.client.service.CleanUpDatabaseTask +import io.split.android.client.network.HttpProxy +import io.split.android.client.network.ProxyConfiguration import io.split.android.client.service.executor.SplitSingleThreadTaskExecutor import io.split.android.client.service.executor.SplitTaskExecutionInfo import io.split.android.client.service.executor.SplitTaskExecutionListener @@ -13,7 +16,12 @@ import io.split.android.client.service.executor.SplitTaskExecutor import io.split.android.client.service.executor.SplitTaskType import io.split.android.client.service.synchronizer.RolloutCacheManager import io.split.android.client.service.synchronizer.SyncManager +import io.split.android.client.service.synchronizer.WorkManagerWrapper +import io.split.android.client.storage.general.GeneralInfoStorage +import io.split.android.client.utils.HttpProxySerializer import junit.framework.TestCase.assertEquals +import junit.framework.TestCase.assertFalse +import junit.framework.TestCase.assertTrue import org.junit.After import org.junit.Before import org.junit.Test @@ -21,11 +29,12 @@ import org.mockito.Mock import org.mockito.Mockito import org.mockito.Mockito.any import org.mockito.Mockito.mock +import org.mockito.Mockito.mockStatic +import org.mockito.Mockito.never import org.mockito.Mockito.verify import org.mockito.Mockito.`when` import org.mockito.MockitoAnnotations import java.io.File -import java.lang.IllegalArgumentException import java.util.concurrent.locks.ReentrantLock class SplitFactoryHelperTest { @@ -184,4 +193,131 @@ class SplitFactoryHelperTest { verify(lifecycleManager).register(syncManager) verify(initLock).unlock() } + + @Test + fun `initializing with proxy config with null url throws`() { + var exceptionThrown = false + try { + SplitFactoryBuilder.build("sdk_key", Key("user"), SplitClientConfig.builder().proxyConfiguration( + ProxyConfiguration.builder().build()).build(), context) + } catch (splitInstantiationException: SplitInstantiationException) { + exceptionThrown = (splitInstantiationException.message ?: "").contains("When configured, proxy host cannot be null") + } + + assertTrue(exceptionThrown) + } + + @Test + fun `initializing with proxy config with valid url does not throw`() { + var exceptionThrown = false + try { + SplitFactoryBuilder.build("sdk_key", Key("user"), SplitClientConfig.builder().proxyConfiguration( + ProxyConfiguration.builder().url("http://localhost:8080").build()).build(), context) + } catch (splitInstantiationException: SplitInstantiationException) { + exceptionThrown = (splitInstantiationException.message ?: "").contains("When configured, proxy host cannot be null") + } + + assertFalse(exceptionThrown) + } + + @Test + fun `setupProxyForBackgroundSync should start thread when proxy is not null, not legacy, and background sync is enabled`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(false) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) + verify(proxyConfigSaveTask).run() + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when proxy is null`() { + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(null) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) + verify(proxyConfigSaveTask, never()).run() + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when proxy is legacy`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(true) + `when`(config.synchronizeInBackground()).thenReturn(true) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) // Give time to ensure no thread was started + } + + @Test + fun `setupProxyForBackgroundSync should not start thread when background sync is disabled`() { + val httpProxy = mock(HttpProxy::class.java) + val config = mock(SplitClientConfig::class.java) + val proxyConfigSaveTask = mock(Runnable::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + `when`(httpProxy.isLegacy()).thenReturn(false) + `when`(config.synchronizeInBackground()).thenReturn(false) + + SplitFactoryHelper.setupProxyForBackgroundSync(config, proxyConfigSaveTask) + + Thread.sleep(100) // Give time to ensure no thread was started + } + + @Test + fun `getProxyConfigSaveTask should return runnable that saves proxy config`() { + val config = mock(SplitClientConfig::class.java) + val httpProxy = mock(HttpProxy::class.java) + val workManagerWrapper = mock(WorkManagerWrapper::class.java) + val generalInfoStorage = mock(GeneralInfoStorage::class.java) + val serializedProxy = "serialized_proxy_json" + + `when`(config.proxy()).thenReturn(httpProxy) + + mockStatic(HttpProxySerializer::class.java).use { mockedSerializer -> + mockedSerializer.`when` { HttpProxySerializer.serialize(httpProxy) }.thenReturn(serializedProxy) + + val runnable = SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, generalInfoStorage) + runnable.run() + + verify(generalInfoStorage).setProxyConfig(serializedProxy) + verify(workManagerWrapper, never()).removeWork() + } + } + + @Test + fun `getProxyConfigSaveTask should handle exceptions and disable background sync`() { + val config = mock(SplitClientConfig::class.java) + val httpProxy = mock(HttpProxy::class.java) + val workManagerWrapper = mock(WorkManagerWrapper::class.java) + val generalInfoStorage = mock(GeneralInfoStorage::class.java) + + `when`(config.proxy()).thenReturn(httpProxy) + + mockStatic(HttpProxySerializer::class.java).use { mockedSerializer -> + mockedSerializer.`when` { HttpProxySerializer.serialize(httpProxy) }.thenThrow(RuntimeException("Test exception")) + + val runnable = SplitFactoryHelper.getProxyConfigSaveTask(config, workManagerWrapper, generalInfoStorage) + runnable.run() + + verify(generalInfoStorage, never()).setProxyConfig(any()) + verify(workManagerWrapper).removeWork() + } + } } diff --git a/src/test/java/io/split/android/client/TreatmentManagerEvaluationOptionsTest.java b/src/test/java/io/split/android/client/TreatmentManagerEvaluationOptionsTest.java index 91fb30c09..e8718653c 100644 --- a/src/test/java/io/split/android/client/TreatmentManagerEvaluationOptionsTest.java +++ b/src/test/java/io/split/android/client/TreatmentManagerEvaluationOptionsTest.java @@ -19,6 +19,8 @@ import io.split.android.client.attributes.AttributesManager; import io.split.android.client.attributes.AttributesMerger; import io.split.android.client.events.ListenableEventsManager; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.impressions.Impression; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.storage.splits.SplitsStorage; @@ -68,7 +70,8 @@ public void setUp() { mSplitsStorage, mValidationMessageLogger, new FlagSetsValidatorImpl(), - mPropertyValidator); + mPropertyValidator, + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); } @Test diff --git a/src/test/java/io/split/android/client/TreatmentManagerExceptionsTest.java b/src/test/java/io/split/android/client/TreatmentManagerExceptionsTest.java index e0f494bca..4f8432e18 100644 --- a/src/test/java/io/split/android/client/TreatmentManagerExceptionsTest.java +++ b/src/test/java/io/split/android/client/TreatmentManagerExceptionsTest.java @@ -26,6 +26,8 @@ import io.split.android.client.attributes.AttributesMerger; import io.split.android.client.events.ListenableEventsManager; import io.split.android.client.events.SplitEvent; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.impressions.Impression; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.storage.splits.SplitsStorage; @@ -83,7 +85,8 @@ public void setUp() { mSplitsStorage, new ValidationMessageLoggerImpl(), mFlagSetsValidator, - new PropertyValidatorImpl()); + new PropertyValidatorImpl(), + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); when(evaluator.getTreatment(anyString(), anyString(), anyString(), anyMap())).thenReturn(new EvaluationResult("test", "label")); } diff --git a/src/test/java/io/split/android/client/TreatmentManagerTelemetryTest.java b/src/test/java/io/split/android/client/TreatmentManagerTelemetryTest.java index 8d15f22af..222de7750 100644 --- a/src/test/java/io/split/android/client/TreatmentManagerTelemetryTest.java +++ b/src/test/java/io/split/android/client/TreatmentManagerTelemetryTest.java @@ -22,6 +22,8 @@ import io.split.android.client.attributes.AttributesMerger; import io.split.android.client.events.ListenableEventsManager; import io.split.android.client.events.SplitEvent; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.storage.splits.SplitsStorage; import io.split.android.client.telemetry.model.Method; @@ -76,7 +78,8 @@ public void setUp() { mFlagSetsFilter, mSplitsStorage, new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), - new PropertyValidatorImpl()); + new PropertyValidatorImpl(), + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); when(evaluator.getTreatment(anyString(), anyString(), anyString(), anyMap())).thenReturn(new EvaluationResult("test", "label")); } diff --git a/src/test/java/io/split/android/client/TreatmentManagerTest.java b/src/test/java/io/split/android/client/TreatmentManagerTest.java index 6fc11d770..ce889d69d 100644 --- a/src/test/java/io/split/android/client/TreatmentManagerTest.java +++ b/src/test/java/io/split/android/client/TreatmentManagerTest.java @@ -28,6 +28,8 @@ import io.split.android.client.dtos.Split; import io.split.android.client.events.ListenableEventsManager; import io.split.android.client.events.SplitEvent; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.impressions.DecoratedImpression; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.storage.mysegments.MySegmentsStorage; @@ -370,7 +372,8 @@ private TreatmentManager createTreatmentManager(String matchingKey, String bucke new KeyValidatorImpl(), splitValidator, mock(ImpressionListener.FederatedImpressionListener.class), config.labelsEnabled(), eventsManager, mock(AttributesManager.class), mock(AttributesMerger.class), - mock(TelemetryStorageProducer.class), mFlagSetsFilter, mSplitsStorage, validationLogger, new FlagSetsValidatorImpl(), new PropertyValidatorImpl()); + mock(TelemetryStorageProducer.class), mFlagSetsFilter, mSplitsStorage, validationLogger, new FlagSetsValidatorImpl(), new PropertyValidatorImpl(), + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); } private TreatmentManagerImpl initializeTreatmentManager() { @@ -400,7 +403,8 @@ private TreatmentManagerImpl initializeTreatmentManager(Evaluator evaluator) { telemetryStorageProducer, mFlagSetsFilter, mSplitsStorage, - new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), new PropertyValidatorImpl()); + new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), new PropertyValidatorImpl(), + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); } private Map splitsMap(List splits) { diff --git a/src/test/java/io/split/android/client/TreatmentManagerWithFlagSetsTest.java b/src/test/java/io/split/android/client/TreatmentManagerWithFlagSetsTest.java index 4588b8cb2..aa12c3d5e 100644 --- a/src/test/java/io/split/android/client/TreatmentManagerWithFlagSetsTest.java +++ b/src/test/java/io/split/android/client/TreatmentManagerWithFlagSetsTest.java @@ -26,6 +26,8 @@ import io.split.android.client.attributes.AttributesMerger; import io.split.android.client.events.ListenableEventsManager; import io.split.android.client.events.SplitEvent; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; import io.split.android.client.impressions.ImpressionListener; import io.split.android.client.storage.splits.SplitsStorage; import io.split.android.client.telemetry.model.Method; @@ -155,7 +157,8 @@ private void initializeTreatmentManager() { mAttributesMerger, mTelemetryStorageProducer, mFlagSetsFilter, - mSplitsStorage, new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), new PropertyValidatorImpl()); + mSplitsStorage, new ValidationMessageLoggerImpl(), new FlagSetsValidatorImpl(), new PropertyValidatorImpl(), + new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); } @Test diff --git a/src/test/java/io/split/android/client/fallback/FallbackTreatmentTest.java b/src/test/java/io/split/android/client/fallback/FallbackTreatmentTest.java new file mode 100644 index 000000000..3914e2c58 --- /dev/null +++ b/src/test/java/io/split/android/client/fallback/FallbackTreatmentTest.java @@ -0,0 +1,53 @@ +package io.split.android.client.fallback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Test; + +public class FallbackTreatmentTest { + + @Test + public void constructorSetsFields() { + FallbackTreatment ft = new FallbackTreatment("off", "{\"k\":true}", "my label"); + assertEquals("off", ft.getTreatment()); + assertEquals("{\"k\":true}", ft.getConfig()); + assertEquals("my label", ft.getLabel()); + } + + @Test + public void configCanBeNull() { + FallbackTreatment ft = new FallbackTreatment("off", null, "my label"); + assertEquals("off", ft.getTreatment()); + assertNull(ft.getConfig()); + assertEquals("my label", ft.getLabel()); + } + + @Test + public void labelCanBeNull() { + FallbackTreatment ft = new FallbackTreatment("off", null, null); + assertEquals("off", ft.getTreatment()); + assertNull(ft.getConfig()); + assertNull(ft.getLabel()); + } + + @Test + public void convenienceConstructorSetsNullConfigAndLabel() { + FallbackTreatment ft = new FallbackTreatment("off"); + assertEquals("off", ft.getTreatment()); + assertNull(ft.getConfig()); + assertNull(ft.getLabel()); + } + + @Test + public void equalityAndHashCodeByValue() { + FallbackTreatment a = new FallbackTreatment("off", null); + FallbackTreatment b = new FallbackTreatment("off", null); + FallbackTreatment c = new FallbackTreatment("on", null); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + assertNotEquals(a, c); + } +} diff --git a/src/test/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorTest.java b/src/test/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorTest.java new file mode 100644 index 000000000..37c5353c0 --- /dev/null +++ b/src/test/java/io/split/android/client/fallback/FallbackTreatmentsCalculatorTest.java @@ -0,0 +1,125 @@ +package io.split.android.client.fallback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import io.split.android.client.TreatmentLabels; + +public class FallbackTreatmentsCalculatorTest { + + @Test + public void flagLevelOverrideTakesPrecedence() { + FallbackTreatment global = new FallbackTreatment("off", "{\"g\":true}"); + FallbackTreatment byFlag = new FallbackTreatment("on", "{\"f\":true}"); + Map map = new HashMap<>(); + map.put("my_flag", byFlag); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .byFlag(map) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolvedExisting = calculator.resolve("my_flag"); + FallbackTreatment resolvedOther = calculator.resolve("other_flag"); + + assertNotNull(resolvedExisting); + assertEquals(byFlag, resolvedExisting); + assertNotNull(resolvedOther); + assertEquals(global, resolvedOther); + } + + @Test + public void globalFallbackIsReturnedWhenNoFlagOverride() { + FallbackTreatment global = new FallbackTreatment("off"); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .byFlag(Collections.emptyMap()) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("any_flag"); + + assertNotNull(resolved); + assertEquals(global, resolved); + } + + @Test + public void flagLevelFallbackIsReturnedWhenConfigured() { + FallbackTreatment byFlag = new FallbackTreatment("on"); + Map map = new HashMap<>(); + map.put("flagA", byFlag); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .byFlag(map) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("flagA"); + + assertNotNull(resolved); + assertEquals(byFlag, resolved); + } + + @Test + public void returnsControlWhenNoFallbackConfigured() { + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("nope"); + + assertNotNull(resolved); + assertEquals(new FallbackTreatment("control", null, null), resolved); + } + + @Test + public void nonexistentFlagFallsBackToGlobal() { + FallbackTreatment global = new FallbackTreatment("off"); + Map map = new HashMap<>(); + map.put("flagA", new FallbackTreatment("on")); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .byFlag(map) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("flagB"); + + assertNotNull(resolved); + assertEquals(global, resolved); + } + + @Test + public void labelIsPrefixed() { + FallbackTreatment global = new FallbackTreatment("off"); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("flagA", TreatmentLabels.EXCEPTION); + + assertNotNull(resolved); + assertEquals("fallback - exception", resolved.getLabel()); + } + + @Test + public void noLabelReturnsNull() { + FallbackTreatment global = new FallbackTreatment("off"); + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .build(); + + FallbackTreatmentsCalculator calculator = new FallbackTreatmentsCalculatorImpl(config); + FallbackTreatment resolved = calculator.resolve("flagA", null); + + assertNotNull(resolved); + assertNull(resolved.getLabel()); + } +} diff --git a/src/test/java/io/split/android/client/fallback/FallbackTreatmentsConfigurationTest.java b/src/test/java/io/split/android/client/fallback/FallbackTreatmentsConfigurationTest.java new file mode 100644 index 000000000..8160e396a --- /dev/null +++ b/src/test/java/io/split/android/client/fallback/FallbackTreatmentsConfigurationTest.java @@ -0,0 +1,194 @@ +package io.split.android.client.fallback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedDeque; + +import io.split.android.client.utils.logger.LogPrinterStub; +import io.split.android.client.utils.logger.Logger; +import io.split.android.client.utils.logger.SplitLogLevel; + +public class FallbackTreatmentsConfigurationTest { + + @Test + public void constructorSetsFields() { + FallbackTreatment global = new FallbackTreatment("off"); + Map map = new HashMap<>(); + map.put("flagA", new FallbackTreatment("off")); + + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .global(global) + .byFlag(map) + .build(); + + assertSame(global, cfg.getGlobal()); + assertEquals(1, cfg.getByFlag().size()); + assertEquals("off", cfg.getByFlag().get("flagA").getTreatment()); + } + + @Test + public void byFlagIsUnmodifiable() { + FallbackTreatment global = new FallbackTreatment("off"); + Map byFlag = new HashMap<>(); + byFlag.put("flagA", new FallbackTreatment("off")); + + FallbackTreatmentsConfiguration config = FallbackTreatmentsConfiguration.builder() + .global(global) + .byFlag(byFlag) + .build(); + + byFlag.put("flagB", new FallbackTreatment("on")); + + // config map must not change + assertEquals(1, config.getByFlag().size()); + + try { + config.getByFlag().put("x", new FallbackTreatment("on")); + throw new AssertionError("Map should be unmodifiable"); + } catch (UnsupportedOperationException expected) { + + } + } + + @Test + public void equalityAndHashCodeByValue() { + FallbackTreatment global = new FallbackTreatment("off"); + Map a = new HashMap<>(); + a.put("flagA", new FallbackTreatment("off")); + + Map b = new HashMap<>(); + b.put("flagA", new FallbackTreatment("off")); + + FallbackTreatmentsConfiguration configOne = FallbackTreatmentsConfiguration.builder().global(global).byFlag(a).build(); + FallbackTreatmentsConfiguration configTwo = FallbackTreatmentsConfiguration.builder().global(global).byFlag(b).build(); + FallbackTreatmentsConfiguration configThree = FallbackTreatmentsConfiguration.builder().global((String) null).byFlag(b).build(); + + assertEquals(configOne, configTwo); + assertEquals(configOne.hashCode(), configTwo.hashCode()); + assertNotEquals(configOne, configThree); + assertNotEquals(configOne.hashCode(), configThree.hashCode()); + } + + @Test + public void globalStringOverloadBuildsFallbackWithNullConfig() { + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .global("on") + .build(); + + FallbackTreatment global = cfg.getGlobal(); + assertEquals("on", global.getTreatment()); + assertNull(global.getConfig()); + } + + @Test + public void byFlagStringMapOverloadBuildsFallbacksWithNullConfig() { + Map flagTreatments = new HashMap<>(); + flagTreatments.put("flagA", "on"); + flagTreatments.put("flagB", "off"); + + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .byFlagStrings(flagTreatments) + .build(); + + assertEquals(2, cfg.getByFlag().size()); + assertEquals("on", cfg.getByFlag().get("flagA").getTreatment()); + assertNull(cfg.getByFlag().get("flagA").getConfig()); + assertEquals("off", cfg.getByFlag().get("flagB").getTreatment()); + assertNull(cfg.getByFlag().get("flagB").getConfig()); + } + + @Test + public void callingByFlagStringsAfterByFlagMergesResultsAndLogsWarning() { + LogPrinterStub printer = new LogPrinterStub(); + Logger.instance().setPrinter(printer); + Logger.instance().setLevel(SplitLogLevel.WARNING); + + Map first = new HashMap<>(); + first.put("flagA", new FallbackTreatment("off", "cfgA")); + first.put("flagB", new FallbackTreatment("on")); + + Map second = new HashMap<>(); + second.put("flagA", "on"); // should override flagA + second.put("flagC", "off"); + + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .byFlag(first) + .byFlagStrings(second) + .build(); + + assertEquals(3, cfg.getByFlag().size()); + assertEquals("on", cfg.getByFlag().get("flagA").getTreatment()); + assertNull(cfg.getByFlag().get("flagA").getConfig()); + assertEquals("on", cfg.getByFlag().get("flagB").getTreatment()); + assertEquals("off", cfg.getByFlag().get("flagC").getTreatment()); + // warning logged with expected content for overridden key + ConcurrentLinkedDeque warns = printer.getLoggedMessages().get(android.util.Log.WARN); + assertFalse("Expected at least one warning", warns.isEmpty()); + boolean containsExpected = warns.stream().anyMatch(m -> m.contains("Fallback treatments - Duplicate fallback for flag 'flagA'. Overriding existing value.")); + assertTrue("Expected warning mentioning overridden key 'flagA'", containsExpected); + } + + @Test + public void callingByFlagAfterByFlagStringsMergesResultsAndLogsWarning() { + LogPrinterStub printer = new LogPrinterStub(); + Logger.instance().setPrinter(printer); + Logger.instance().setLevel(SplitLogLevel.WARNING); + + Map first = new HashMap<>(); + first.put("flagA", "off"); + first.put("flagB", "on"); + + Map second = new HashMap<>(); + second.put("flagA", new FallbackTreatment("on", "cfgA")); // should override flagA + second.put("flagC", new FallbackTreatment("off")); + + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .byFlagStrings(first) + .byFlag(second) + .build(); + + assertEquals(3, cfg.getByFlag().size()); + assertEquals("on", cfg.getByFlag().get("flagA").getTreatment()); + assertEquals("cfgA", cfg.getByFlag().get("flagA").getConfig()); + assertEquals("on", cfg.getByFlag().get("flagB").getTreatment()); + assertNull(cfg.getByFlag().get("flagB").getConfig()); + assertEquals("off", cfg.getByFlag().get("flagC").getTreatment()); + + boolean warned = !printer.getLoggedMessages().get(android.util.Log.WARN).isEmpty(); + assertTrue("Expected a warning log when merging byFlag and byFlagStrings", warned); + } + + @Test + public void byFlagAndByFlagStrings_NoOverlap_NoWarning() { + LogPrinterStub printer = new LogPrinterStub(); + Logger.instance().setPrinter(printer); + Logger.instance().setLevel(SplitLogLevel.WARNING); + + Map first = new HashMap<>(); + first.put("flagA", new FallbackTreatment("off")); + + Map second = new HashMap<>(); + second.put("flagB", "on"); + + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .byFlag(first) + .byFlagStrings(second) + .build(); + + assertEquals(2, cfg.getByFlag().size()); + assertEquals("off", cfg.getByFlag().get("flagA").getTreatment()); + assertEquals("on", cfg.getByFlag().get("flagB").getTreatment()); + + boolean warned = !printer.getLoggedMessages().get(android.util.Log.WARN).isEmpty(); + assertFalse("Did not expect a warning", warned); + } +} diff --git a/src/test/java/io/split/android/client/fallback/FallbacksSanitizerImplTest.java b/src/test/java/io/split/android/client/fallback/FallbacksSanitizerImplTest.java new file mode 100644 index 000000000..638dab54e --- /dev/null +++ b/src/test/java/io/split/android/client/fallback/FallbacksSanitizerImplTest.java @@ -0,0 +1,173 @@ +package io.split.android.client.fallback; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; + +import io.split.android.client.utils.logger.LogPrinterStub; +import io.split.android.client.utils.logger.Logger; +import io.split.android.client.utils.logger.SplitLogLevel; + +public class FallbacksSanitizerImplTest { + + private FallbacksSanitizerImpl mSanitizer; + private LogPrinterStub mLogPrinter; + + private static final String VALID_FLAG = "my_flag"; + private static final String INVALID_FLAG_WITH_SPACE = "my flag"; + private static final String LONG_101; + static { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 101; i++) sb.append('a'); + LONG_101 = sb.toString(); + } + + @Before + public void setUp() { + mSanitizer = new FallbacksSanitizerImpl(); + mLogPrinter = new LogPrinterStub(); + Logger.instance().setLevel(SplitLogLevel.VERBOSE); + Logger.instance().setPrinter(mLogPrinter); + } + + @Test + public void dropsInvalidFlagNamesAndTreatments() { + Map byFlag = new HashMap<>(); + byFlag.put(VALID_FLAG, new FallbackTreatment("on")); + byFlag.put(INVALID_FLAG_WITH_SPACE, new FallbackTreatment("off")); + byFlag.put(LONG_101, new FallbackTreatment("off")); + byFlag.put("tooLongTreatment", new FallbackTreatment(LONG_101)); + + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(new FallbackTreatment("on")); + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(byFlag); + FallbackTreatmentsConfiguration sanitized = FallbackTreatmentsConfiguration.builder() + .global(sanitizedGlobal) + .byFlag(sanitizedByFlag) + .build(); + + Deque errors = mLogPrinter.getLoggedMessages().get(android.util.Log.ERROR); + assertTrue("Expected ERROR logs to be present", errors != null && !errors.isEmpty()); + long invalidFlagNameCount = errors.stream().filter(m -> m.contains("Invalid flag name")).count(); + assertEquals(2, invalidFlagNameCount); + assertTrue(errors.stream().anyMatch(m -> m.contains("Discarded flag 'my flag'"))); + // invalid treatment for a specific flag name and contains the full expected message + assertTrue(errors.stream().anyMatch(m -> m.contains("Discarded treatment for flag 'tooLongTreatment'"))); + assertTrue(errors.stream().anyMatch(m -> m.contains("Invalid treatment (max 100 chars and comply with ^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$)"))); + + assertEquals(1, sanitized.getByFlag().size()); + assertEquals("on", sanitized.getByFlag().get(VALID_FLAG).getTreatment()); + } + + @Test + public void dropsInvalidGlobalTreatment() { + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(new FallbackTreatment(LONG_101)); // invalid treatment length + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(null); + FallbackTreatmentsConfiguration sanitized = FallbackTreatmentsConfiguration.builder() + .global(sanitizedGlobal) + .byFlag(sanitizedByFlag) + .build(); + + // Assert error log for discarded global fallback only + Deque errors = mLogPrinter.getLoggedMessages().get(android.util.Log.ERROR); + assertTrue("Expected ERROR logs to be present", errors != null && !errors.isEmpty()); + assertTrue(errors.stream().anyMatch(m -> m.contains("Discarded global fallback"))); + + assertNull(sanitized.getGlobal()); + assertEquals(0, sanitized.getByFlag().size()); + } + + @Test + public void byFlagTreatmentIsDroppedWhenInvalidFormat() { + Map byFlag = new HashMap<>(); + byFlag.put(VALID_FLAG, new FallbackTreatment("on.off")); + byFlag.put("valid_num_dot", new FallbackTreatment("123.on")); + byFlag.put("null_treatment", new FallbackTreatment(null)); + + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(null); + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(byFlag); + FallbackTreatmentsConfiguration sanitized = FallbackTreatmentsConfiguration.builder() + .global(sanitizedGlobal) + .byFlag(sanitizedByFlag) + .build(); + + // Assert error logs for invalid treatments under flags + Deque errors = mLogPrinter.getLoggedMessages().get(android.util.Log.ERROR); + assertTrue("Expected ERROR logs to be present", errors != null && !errors.isEmpty()); + assertTrue(errors.stream().anyMatch(m -> m.contains("Discarded treatment for flag '" + VALID_FLAG + "'"))); + assertTrue(errors.stream().anyMatch(m -> m.contains("Invalid treatment (max 100 chars and comply with ^[0-9]+[.a-zA-Z0-9_-]*$|^[a-zA-Z]+[a-zA-Z0-9_-]*$)"))); + assertTrue(errors.stream().anyMatch(m -> m.contains("Discarded treatment for flag 'null_treatment'"))); + // Ensure no error for valid flag/treatment + assertTrue(errors.stream().noneMatch(m -> m.contains("Discarded treatment for flag 'valid_num_dot'"))); + + // Only the valid one should remain + assertEquals(1, sanitized.getByFlag().size()); + assertEquals("123.on", sanitized.getByFlag().get("valid_num_dot").getTreatment()); + } + + @Test + public void globalTreatmentIsDroppedWhenInvalidFormat() { + Map byFlag = new HashMap<>(); + byFlag.put(VALID_FLAG, new FallbackTreatment("on_1-2")); + byFlag.put("null_treatment", new FallbackTreatment(null)); + + // Global invalid due to regex (letters cannot be followed by '.') + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(new FallbackTreatment("on.off")); + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(byFlag); + FallbackTreatmentsConfiguration sanitized = FallbackTreatmentsConfiguration.builder() + .global(sanitizedGlobal) + .byFlag(sanitizedByFlag) + .build(); + + // Assert error logs were emitted for invalid entries + Deque errorLogs = mLogPrinter.getLoggedMessages().get(android.util.Log.ERROR); + assertTrue("Expected ERROR logs to be present", errorLogs != null && !errorLogs.isEmpty()); + boolean hasGlobalDiscard = false; + boolean hasNullFlagDiscard = false; + for (String msg : errorLogs) { + if (msg.contains("Discarded global fallback")) { + hasGlobalDiscard = true; + } + if (msg.contains("Discarded treatment for flag 'null_treatment'")) { + hasNullFlagDiscard = true; + } + } + assertTrue("Expected an error about discarded global fallback", hasGlobalDiscard); + assertTrue("Expected an error about discarded treatment for flag 'null_treatment'", hasNullFlagDiscard); + + assertNull(sanitized.getGlobal()); + // Ensure only the valid by-flag entry is preserved + assertEquals(1, sanitized.getByFlag().size()); + assertEquals("on_1-2", sanitized.getByFlag().get(VALID_FLAG).getTreatment()); + } + + @Test + public void validFormatTreatmentIsNotDropped() { + Map byFlag = new HashMap<>(); + byFlag.put("numWithDot", new FallbackTreatment("123.on")); + byFlag.put(VALID_FLAG, new FallbackTreatment("on_1-2")); + + FallbackTreatment sanitizedGlobal = mSanitizer.sanitizeGlobal(new FallbackTreatment("on")); + Map sanitizedByFlag = mSanitizer.sanitizeByFlag(byFlag); + FallbackTreatmentsConfiguration sanitized = FallbackTreatmentsConfiguration.builder() + .global(sanitizedGlobal) + .byFlag(sanitizedByFlag) + .build(); + + assertEquals(2, sanitized.getByFlag().size()); + assertTrue(sanitized.getByFlag().containsKey("numWithDot")); + assertEquals("123.on", sanitized.getByFlag().get("numWithDot").getTreatment()); + assertEquals("on_1-2", sanitized.getByFlag().get(VALID_FLAG).getTreatment()); + assertEquals("on", sanitized.getGlobal().getTreatment()); + + // No ERROR logs expected for valid-only case + Deque errors4 = mLogPrinter.getLoggedMessages().get(android.util.Log.ERROR); + assertTrue(errors4 == null || errors4.isEmpty()); + } +} diff --git a/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java new file mode 100644 index 000000000..738300ce7 --- /dev/null +++ b/src/test/java/io/split/android/client/network/DefaultBase64EncoderTest.java @@ -0,0 +1,47 @@ +package io.split.android.client.network; + +import static org.mockito.Mockito.mockStatic; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockedStatic; + +import java.nio.charset.StandardCharsets; + +import io.split.android.client.utils.Base64Util; + +public class DefaultBase64EncoderTest { + + private DefaultBase64Encoder encoder; + private MockedStatic mockedBase64Util; + + @Before + public void setUp() { + encoder = new DefaultBase64Encoder(); + mockedBase64Util = mockStatic(Base64Util.class); + } + + @After + public void tearDown() { + mockedBase64Util.close(); + } + + @Test + public void encodeStringUsesBase64Util() { + String input = "test string"; + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } + + @Test + public void encodeByteArrayUsesBase64Util() { + byte[] input = "test bytes".getBytes(StandardCharsets.UTF_8); + + encoder.encode(input); + + mockedBase64Util.verify(() -> Base64Util.encode(input)); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpClientTest.java b/src/test/java/io/split/android/client/network/HttpClientTest.java index c78b52696..2daa5063b 100644 --- a/src/test/java/io/split/android/client/network/HttpClientTest.java +++ b/src/test/java/io/split/android/client/network/HttpClientTest.java @@ -23,11 +23,14 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Type; import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -52,12 +55,12 @@ public class HttpClientTest { private MockWebServer mWebServer; private MockWebServer mProxyServer; private HttpClient client; - private UrlSanitizer mUrlSanitizer; + private UrlSanitizer mUrlSanitizerMock; @Before public void setup() throws IOException { - mUrlSanitizer = mock(UrlSanitizer.class); - when(mUrlSanitizer.getUrl(any())).thenAnswer(new Answer() { + mUrlSanitizerMock = mock(UrlSanitizer.class); + when(mUrlSanitizerMock.getUrl(any())).thenAnswer(new Answer() { @Override public URL answer(InvocationOnMock invocation) throws Throwable { URI argument = invocation.getArgument(0); @@ -219,7 +222,7 @@ public void addHeaders() throws InterruptedException, URISyntaxException, HttpEx } @Test - public void addStreamingHeaders() throws InterruptedException, URISyntaxException, HttpException { + public void addStreamingHeaders() throws InterruptedException, HttpException, IOException { client.addStreamingHeaders(Collections.singletonMap("my_header", "my_header_value")); HttpUrl url = mWebServer.url("/test1/"); @@ -277,8 +280,8 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setUrlSanitizer(mUrlSanitizerMock) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).buildLegacy()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.GET); @@ -312,7 +315,7 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .setProxyAuthenticator(new SplitAuthenticator() { @Override public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { @@ -322,7 +325,7 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest return request; } }) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).buildLegacy()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.GET); @@ -367,7 +370,7 @@ public MockResponse dispatch(RecordedRequest request) { HttpClient client = new HttpClientImpl.Builder() .setContext(mock(Context.class)) - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .setProxyAuthenticator(new SplitAuthenticator() { @Override public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest request) { @@ -377,7 +380,7 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest return request; } }) - .setProxy(new HttpProxy(mProxyServer.getHostName(), mProxyServer.getPort())) + .setProxy(HttpProxy.newBuilder(mProxyServer.getHostName(), mProxyServer.getPort()).buildLegacy()) .build(); HttpRequest request = client.request(mWebServer.url("/test1/").uri(), HttpMethod.POST, "{}"); @@ -400,6 +403,79 @@ public SplitAuthenticatedRequest authenticate(@NonNull SplitAuthenticatedRequest mProxyServer.shutdown(); } + + @Test + public void copyStreamToByteArrayWithSimpleString() { + String testString = "Test string content"; + InputStream inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)); + + byte[] result = HttpClientImpl.copyStreamToByteArray(inputStream); + + assertNotNull("Result should not be null", result); + assertEquals("Result should match original string", testString, new String(result, StandardCharsets.UTF_8)); + + byte[] buffer = new byte[100]; + try { + int bytesRead = inputStream.read(buffer); + assertEquals("Stream should be readable and contain the same content", testString, + new String(buffer, 0, bytesRead, StandardCharsets.UTF_8)); + } catch (IOException e) { + Assert.fail("Should be able to read from stream after copying: " + e.getMessage()); + } + } + + @Test + public void copyStreamToByteArrayWithEmptyStream() { + InputStream emptyStream = new ByteArrayInputStream(new byte[0]); + + byte[] result = HttpClientImpl.copyStreamToByteArray(emptyStream); + + assertNotNull("Result should not be null even for empty stream", result); + assertEquals("Result should be empty array", 0, result.length); + } + + @Test + public void copyStreamToByteArrayWithNullStream() { + byte[] result = HttpClientImpl.copyStreamToByteArray(null); + + assertNull("Result should be null for null input", result); + } + + @Test + public void copyStreamToByteArrayWithNonMarkableStream() { + InputStream nonMarkableStream = new InputStream() { + private final byte[] data = "Test data".getBytes(StandardCharsets.UTF_8); + private int position = 0; + + @Override + public int read() { + if (position < data.length) { + return data[position++] & 0xff; + } + return -1; + } + + @Override + public boolean markSupported() { + return false; + } + }; + + byte[] result = HttpClientImpl.copyStreamToByteArray(nonMarkableStream); + + assertNotNull("Result should not be null", result); + assertEquals("Result should match original content", "Test data", + new String(result, StandardCharsets.UTF_8)); + + int nextByte = -1; + try { + nextByte = nonMarkableStream.read(); + } catch (IOException e) { + Assert.fail("Reading from stream should not throw exception"); + } + assertEquals("Stream should be at EOF", -1, nextByte); + } + @After public void tearDown() throws IOException { mWebServer.shutdown(); @@ -456,7 +532,7 @@ public MockResponse dispatch(RecordedRequest request) throws InterruptedExceptio mWebServer.start(); client = new HttpClientImpl.Builder() - .setUrlSanitizer(mUrlSanitizer) + .setUrlSanitizer(mUrlSanitizerMock) .build(); } diff --git a/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java new file mode 100644 index 000000000..f6f3454dd --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpClientTunnellingProxyTest.java @@ -0,0 +1,657 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.security.KeyStore; +import java.util.Base64; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.mockwebserver.Dispatcher; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import okhttp3.tls.HeldCertificate; + +public class HttpClientTunnellingProxyTest { + + private UrlSanitizer mUrlSanitizerMock; + private Base64Decoder mBas64Decoder; + + @Before + public void setUp() { + + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); + mUrlSanitizerMock = mock(UrlSanitizer.class); + when(mUrlSanitizerMock.getUrl(any())).thenAnswer(new Answer() { + @Override + public URL answer(InvocationOnMock invocation) throws Throwable { + URI argument = invocation.getArgument(0); + + return new URL(argument.toString()); + } + }); + mBas64Decoder = new Base64Decoder() { + @Override + public byte[] decode(String base64) { + return Base64.getDecoder().decode(base64); + } + }; + } + + @Test + public void proxyCacertProxyTunnelling() throws Exception { + // 1. Create separate CA and server certs for proxy and origin + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // 2. Start HTTP origin server (not HTTPS to avoid SSL layering issues) + MockWebServer originServer = new MockWebServer(); + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from origin!"); + } + }); + // Use HTTP instead of HTTPS to test tunnel establishment without SSL layering issues + originServer.start(); + + // 3. Start SSL tunnel proxy (server-only SSL, no client cert required) + TunnelProxySslServerOnly tunnelProxy = new TunnelProxySslServerOnly(0, proxyServerCert); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Write BOTH proxy CA and origin CA certs to temp file (for combined trust store) + File caCertFile = File.createTempFile("proxy-ca", ".pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(proxyCa.certificatePem()); + writer.write(originCa.certificatePem()); + } + + // 5. Configure HttpProxy with PROXY_CACERT + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(caCertFile.toPath())) + .build(); + + // 6. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 7. Make a request to the origin server (should tunnel via SSL proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + originServer.shutdown(); + } + + @Test + public void proxyCacertProxyTunnelling_SslOverSsl() throws Exception { + // 1. Create separate CA and server certs for proxy and origin + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // 2. Start HTTPS origin server + MockWebServer originServer = new MockWebServer(); + originServer.useHttps(createSslSocketFactory(originCert), false); // Use HTTPS for SSL-over-SSL + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from https origin!"); + } + }); + originServer.start(); + + // 3. Start SSL tunnel proxy (server-only SSL, no client cert required) + TunnelProxySslServerOnly tunnelProxy = new TunnelProxySslServerOnly(0, proxyServerCert); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Write BOTH proxy CA and origin CA certs to temp file (for combined trust store) + File caCertFile = File.createTempFile("proxy-ca", ".pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(proxyCa.certificatePem()); + writer.write(originCa.certificatePem()); // Client needs to trust origin CA as well + } + + // 5. Configure HttpProxy with PROXY_CACERT + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(caCertFile.toPath())) + .build(); + + // 6. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 7. Make a request to the HTTPS origin server (should tunnel via SSL proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from https origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + originServer.shutdown(); + } + + /** + * Negative test: mTLS proxy requires client certificate, but client presents none. + * The proxy should reject the connection, and the client should throw SSLHandshakeException. + */ + @Test + public void proxyMtlsProxyTunnelling_rejectsNoClientCert() throws Exception { + // 1. Create CA, proxy/server, and origin certs + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // Write proxy server cert and key to temp files + File proxyCertFile = File.createTempFile("proxy-server", ".crt"); + File proxyKeyFile = File.createTempFile("proxy-server", ".key"); + try (FileWriter writer = new FileWriter(proxyCertFile)) { + writer.write(proxyServerCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(proxyKeyFile)) { + writer.write(proxyServerCert.privateKeyPkcs8Pem()); + } + // Write proxy CA cert (for client auth) to temp file + File proxyCaFile = File.createTempFile("proxy-ca", ".crt"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + // 2. Start HTTPS origin server + MockWebServer originServer = new MockWebServer(); + originServer.useHttps(createSslSocketFactory(originCert), false); + originServer.start(); + + // 3. Start mTLS tunnel proxy + TunnelProxySsl tunnelProxy = new TunnelProxySsl(0, proxyServerCert, proxyCa); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Configure HttpProxy WITHOUT client cert (should be rejected) + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .proxyCacert(Files.newInputStream(proxyCaFile.toPath())) // only trust, no client auth + .build(); + + // 5. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 6. Make a request to the origin server (should fail at proxy handshake) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + boolean handshakeFailed = false; + try { + req.execute(); + } catch (Exception e) { + handshakeFailed = true; + } + assertTrue("Expected SSL handshake to fail due to missing client certificate", handshakeFailed); + + tunnelProxy.stopProxy(); + tunnelProxy.join(); + originServer.shutdown(); + proxyCertFile.delete(); + proxyKeyFile.delete(); + proxyCaFile.delete(); + } + + /** + * Positive test: mTLS proxy requires client certificate, and client presents a valid certificate. + * The proxy should accept the connection, tunnel should be established, and the request should reach the origin. + */ + @Test + public void proxyMtlsProxyTunnelling() throws Exception { + // 1. Create CA, proxy/server, client, and origin certs + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + HeldCertificate clientCert = new HeldCertificate.Builder() + .commonName("Test Client") + .signedBy(proxyCa) + .build(); + HeldCertificate originCa = new HeldCertificate.Builder() + .commonName("Test Origin CA") + .certificateAuthority(0) + .build(); + HeldCertificate originCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(originCa) + .build(); + + // Write proxy server cert and key to temp files + File proxyCertFile = File.createTempFile("proxy-server", ".crt"); + File proxyKeyFile = File.createTempFile("proxy-server", ".key"); + try (FileWriter writer = new FileWriter(proxyCertFile)) { + writer.write(proxyServerCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(proxyKeyFile)) { + writer.write(proxyServerCert.privateKeyPkcs8Pem()); + } + // Write proxy CA cert (for client auth) to temp file + File proxyCaFile = File.createTempFile("proxy-ca", ".crt"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + // Write client certificate and key to separate files (PEM format) + File clientCertFile = File.createTempFile("client", ".crt"); + File clientKeyFile = File.createTempFile("client", ".key"); + try (FileWriter writer = new FileWriter(clientCertFile)) { + writer.write(clientCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(clientKeyFile)) { + writer.write(clientCert.privateKeyPkcs8Pem()); + } + + // 2. Start HTTP origin server (not HTTPS to avoid SSL layering issues) + MockWebServer originServer = new MockWebServer(); + CountDownLatch originLatch = new CountDownLatch(1); + final String[] methodAndPath = new String[2]; + originServer.setDispatcher(new Dispatcher() { + @Override + public MockResponse dispatch(RecordedRequest request) { + methodAndPath[0] = request.getMethod(); + methodAndPath[1] = request.getPath(); + originLatch.countDown(); + return new MockResponse().setBody("from origin!"); + } + }); + // Use HTTP instead of HTTPS to test tunnel establishment without SSL layering issues + originServer.start(); + + // 3. Start mTLS tunnel proxy + TunnelProxySsl tunnelProxy = new TunnelProxySsl(0, proxyServerCert, proxyCa); + tunnelProxy.start(); + while (tunnelProxy.mServerSocket == null || tunnelProxy.mServerSocket.getLocalPort() == 0) { + Thread.sleep(10); + } + int assignedProxyPort = tunnelProxy.mServerSocket.getLocalPort(); + + // 4. Configure HttpProxy with mTLS (client cert, key, and CA) + HttpProxy proxy = HttpProxy.newBuilder("localhost", assignedProxyPort) + .mtls( + Files.newInputStream(clientCertFile.toPath()), + Files.newInputStream(clientKeyFile.toPath()) + ) + .proxyCacert(Files.newInputStream(proxyCaFile.toPath())) + .build(); + + // 5. Build client (let builder/factory handle trust) + HttpClient client = new HttpClientImpl.Builder() + .setProxy(proxy) + .setBase64Decoder(mBas64Decoder) + .setUrlSanitizer(mUrlSanitizerMock) + .build(); + + // 6. Make a request to the origin server (should tunnel via proxy) + URI uri = originServer.url("/test").uri(); + HttpRequest req = client.request(uri, HttpMethod.GET); + HttpResponse resp = req.execute(); + assertNotNull(resp); + assertEquals(200, resp.getHttpStatus()); + assertEquals("from origin!", resp.getData()); + + // Assert that the tunnel was established and the origin received the request + assertTrue("TunnelProxy did not tunnel the request in time", tunnelProxy.getTunnelLatch().await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertTrue("Origin server did not receive the request in time", originLatch.await(5, java.util.concurrent.TimeUnit.SECONDS)); + assertEquals("GET", methodAndPath[0]); + assertEquals("/test", methodAndPath[1]); + + tunnelProxy.stopProxy(); + tunnelProxy.join(); + originServer.shutdown(); + proxyCertFile.delete(); + proxyKeyFile.delete(); + proxyCaFile.delete(); + } + + // Helper to create SSLSocketFactory from HeldCertificate + private static SSLSocketFactory createSslSocketFactory(HeldCertificate cert) throws Exception { + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", cert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{cert.certificate()}); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(kmf.getKeyManagers(), null, null); + + return sslContext.getSocketFactory(); + } + + /** + * TunnelProxySslServerOnly is an SSL proxy that presents a server certificate but doesn't require client certificates. + * This is used for testing proxy_cacert functionality where the client validates the proxy's certificate. + */ + static class TunnelProxySslServerOnly extends TunnelProxy { + private final HeldCertificate mServerCert; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + public TunnelProxySslServerOnly(int port, HeldCertificate serverCert) { + super(port); + this.mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + + // No client certificate validation - use default trust manager + sslContext.init(kmf.getKeyManagers(), null, null); + + javax.net.ssl.SSLServerSocketFactory factory = sslContext.getServerSocketFactory(); + mServerSocket = factory.createServerSocket(mPort); + mPort = mServerSocket.getLocalPort(); // Update mPort with the actual assigned port + + // Don't require client auth - this is server-only SSL + ((javax.net.ssl.SSLServerSocket) mServerSocket).setWantClientAuth(false); + ((javax.net.ssl.SSLServerSocket) mServerSocket).setNeedClientAuth(false); + + System.out.println("[TunnelProxySslServerOnly] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxySslServerOnly] Accepted connection from: " + client.getRemoteSocketAddress()); + new Thread(() -> handle(client)).start(); + } + } catch (IOException e) { + System.out.println("[TunnelProxySslServerOnly] Server socket closed or error: " + e); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * TunnelProxySsl is a minimal SSL/TLS proxy supporting mTLS (client authentication). + * It uses an SSLServerSocket and requires client certificates signed by the provided CA. + * Only for use in mTLS proxy integration tests. + */ + private static class TunnelProxySsl extends TunnelProxy { + private final HeldCertificate mServerCert; + private final HeldCertificate mClientCa; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + public TunnelProxySsl(int port, HeldCertificate serverCert, HeldCertificate clientCa) { + super(port); + this.mServerCert = serverCert; + this.mClientCa = clientCa; + } + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); + trustStore.load(null, null); + trustStore.setCertificateEntry("ca", mClientCa.certificate()); + javax.net.ssl.TrustManagerFactory tmf = javax.net.ssl.TrustManagerFactory.getInstance(javax.net.ssl.TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(trustStore); + sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + javax.net.ssl.SSLServerSocketFactory factory = sslContext.getServerSocketFactory(); + mServerSocket = factory.createServerSocket(mPort); + ((javax.net.ssl.SSLServerSocket) mServerSocket).setNeedClientAuth(true); + System.out.println("[TunnelProxySsl] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxySsl] Accepted connection from: " + client.getRemoteSocketAddress()); + new Thread(() -> handle(client)).start(); + } + } catch (IOException e) { + System.out.println("[TunnelProxySsl] Server socket closed or error: " + e); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + + /** + * Minimal CONNECT-capable proxy for HTTPS tunneling in tests. + * Listens on a port, accepts CONNECT requests, and pipes bytes between the client and the requested target. + * Used to simulate a real HTTPS proxy for end-to-end CA trust validation. + */ + private static class TunnelProxy extends Thread { + // Latch to signal that a CONNECT tunnel was established + private final CountDownLatch mTunnelLatch = new CountDownLatch(1); + // Port to listen on (0 = auto-assign) + protected int mPort; + // The server socket for accepting connections + public ServerSocket mServerSocket; + // Flag to control proxy shutdown + private final AtomicBoolean mRunning = new AtomicBoolean(true); + + /** + * Create a new TunnelProxy listening on the given port. + * @param port Port to listen on (0 = auto-assign) + */ + TunnelProxy(int port) { mPort = port; } + + /** + * Main accept loop. For each incoming client, start a handler thread. + */ + public void run() { + try { + mServerSocket = new ServerSocket(mPort); + System.out.println("[TunnelProxy] Listening on port: " + mServerSocket.getLocalPort()); + while (mRunning.get()) { + Socket client = mServerSocket.accept(); + System.out.println("[TunnelProxy] Accepted connection from: " + client.getRemoteSocketAddress()); + // Each client handled in its own thread + new Thread(() -> handle(client)).start(); + } + } catch (IOException ignored) { + System.out.println("[TunnelProxy] Server socket closed or error: " + ignored); + } + } + + /** + * Handles a single client connection. Waits for CONNECT, then establishes a tunnel. + */ + void handle(Socket client) { + try (BufferedReader in = new BufferedReader(new InputStreamReader(client.getInputStream())); + OutputStream out = client.getOutputStream()) { + String line = in.readLine(); + // Only handle CONNECT requests (as sent by HTTPS clients to a proxy) + if (line != null && line.startsWith("CONNECT")) { + mTunnelLatch.countDown(); + System.out.println("[TunnelProxy] Received CONNECT: " + line); + out.write("HTTP/1.1 200 Connection Established\r\n\r\n".getBytes()); + out.flush(); + String[] parts = line.split(" "); + String[] hostPort = parts[1].split(":"); + // Open a socket to the requested target (origin server) + Socket target = new Socket(hostPort[0], Integer.parseInt(hostPort[1])); + System.out.println("[TunnelProxy] Established tunnel to: " + hostPort[0] + ":" + hostPort[1]); + // Pipe bytes in both directions (client <-> target) until closed + Thread t1 = new Thread(() -> { + try { + pipe(client, target); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + Thread t2 = new Thread(() -> { + try { + pipe(target, client); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + t1.start(); t2.start(); + try { t1.join(); t2.join(); } catch (InterruptedException ignored) {} + System.out.println("[TunnelProxy] Tunnel closed for: " + hostPort[0] + ":" + hostPort[1]); + target.close(); + } + } catch (Exception ignored) { } + } + + /** + * Copies bytes from inSocket to outSocket until EOF. + * Used to relay data in both directions for the tunnel. + */ + private void pipe(Socket inSocket, Socket outSocket) throws IOException { + try (InputStream in = inSocket.getInputStream(); OutputStream out = outSocket.getOutputStream()) { + byte[] buf = new byte[1024]; + int len; + while ((len = in.read(buf)) != -1) { + out.write(buf, 0, len); + out.flush(); + } + } catch (IOException ignored) { } + } + + /** + * Stops the proxy by closing the server socket and setting the running flag to false. + */ + public void stopProxy() throws IOException { + mRunning.set(false); + if (mServerSocket != null && !mServerSocket.isClosed()) { + mServerSocket.close(); + System.out.println("[TunnelProxy] Proxy stopped."); + } + } + + public CountDownLatch getTunnelLatch() { + return mTunnelLatch; + } + } +} diff --git a/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java new file mode 100644 index 000000000..2fdc64e30 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpOverTunnelExecutorTest.java @@ -0,0 +1,212 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.when; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.net.URL; +import java.util.Collections; + +public class HttpOverTunnelExecutorTest { + + private static final String CRLF = "\r\n"; + private HttpOverTunnelExecutor mExecutor; + + @Mock + private Socket mSocket; + + private OutputStream mOutputStream; + private InputStream mInputStream; + + @Before + public void setUp() throws IOException { + MockitoAnnotations.openMocks(this); + mExecutor = new HttpOverTunnelExecutor(); + mOutputStream = new ByteArrayOutputStream(); + + String httpResponse = "HTTP/1.1 200 OK" + CRLF + "Content-Length: 0\r\n\r\n"; + mInputStream = new ByteArrayInputStream(httpResponse.getBytes()); + + when(mSocket.getOutputStream()).thenReturn(mOutputStream); + when(mSocket.getInputStream()).thenReturn(mInputStream); + } + + @Test + public void postRequestWithBodyAndHeaders() throws IOException { + URL url = new URL("https://test.com/path"); + String body = "{\"key\":\"value\"}"; + java.util.Map headers = new java.util.HashMap<>(); + headers.put("Custom-Header", "CustomValue"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.POST, headers, body, null); + + String expectedRequest = "POST /path HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Custom-Header: CustomValue" + CRLF + + "Content-Length: 15" + CRLF + + "Connection: close" + CRLF + + CRLF + + body; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithQuery() throws IOException { + URL url = new URL("http://test.com/path?q=1&v=2"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path?q=1&v=2 HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithNonDefaultPort() throws IOException { + URL url = new URL("http://test.com:8080/path"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path HTTP/1.1" + CRLF + + "Host: test.com:8080" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test + public void getRequestWithEmptyPath() throws IOException { + URL url = new URL("http://test.com"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET / HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @Test(expected = IOException.class) + public void requestThrowsIOException() throws IOException { + URL url = new URL("http://test.com/path"); + when(mSocket.getOutputStream()).thenThrow(new IOException("Socket error")); + + mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + } + + @Test + public void getRequest() throws IOException { + URL url = new URL("http://test.com/path"); + + HttpResponse response = mExecutor.executeRequest(mSocket, url, HttpMethod.GET, Collections.emptyMap(), null, null); + + String expectedRequest = "GET /path HTTP/1.1" + CRLF + + "Host: test.com" + CRLF + + "Connection: close" + CRLF + + CRLF; + + assertEquals(expectedRequest, mOutputStream.toString()); + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + } + + @After + public void tearDown() throws IOException { + mOutputStream.close(); + mInputStream.close(); + } + + @Test + public void executeStreamRequestTest() throws IOException { + // Prepare HTTP response with headers and body + String httpResponse = "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json; charset=utf-8\r\n" + + "Content-Length: 16\r\n" + + "\r\n" + + "{\"data\":\"test\"}"; + + // Set up input stream with the HTTP response + ByteArrayInputStream inputStream = new ByteArrayInputStream(httpResponse.getBytes()); + when(mSocket.getInputStream()).thenReturn(inputStream); + + // Execute the stream request + URL url = new URL("https://test.com/stream"); + HttpStreamResponse response = mExecutor.executeStreamRequest( + mSocket, // finalSocket + mSocket, // tunnelSocket (using same mock for simplicity) + null, // originSocket + url, + HttpMethod.GET, + Collections.emptyMap(), + null // serverCertificates + ); + + // Verify the request was sent correctly + String expectedRequest = "GET /stream HTTP/1.1\r\n" + + "Host: test.com\r\n" + + "Connection: close\r\n" + + "\r\n"; + assertEquals(expectedRequest, mOutputStream.toString()); + + // Verify the response properties + assertNotNull(response); + assertEquals(200, response.getHttpStatus()); + + // Verify we can read from the response stream + BufferedReader reader = response.getBufferedReader(); + assertNotNull(reader); + StringBuilder responseBody = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + responseBody.append(line); + } + assertEquals("{\"data\":\"test\"}", responseBody.toString()); + + // Close the response + response.close(); + } + + @Test(expected = IOException.class) + public void executeStreamRequestWithSocketException() throws IOException { + URL url = new URL("http://test.com/stream"); + when(mSocket.getOutputStream()).thenThrow(new IOException("Socket error")); + + mExecutor.executeStreamRequest( + mSocket, + mSocket, + null, + url, + HttpMethod.GET, + Collections.emptyMap(), + null + ); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java b/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java new file mode 100644 index 000000000..b3f08fb55 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpRequestHelperTest.java @@ -0,0 +1,112 @@ +package io.split.android.client.network; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.lang.reflect.Method; +import java.net.HttpURLConnection; +import java.net.Proxy; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocketFactory; + +public class HttpRequestHelperTest { + + @Mock + private HttpURLConnection mockConnection; + @Mock + private HttpsURLConnection mockHttpsConnection; + @Mock + private URL mockUrl; + @Mock + private SplitUrlConnectionAuthenticator mockAuthenticator; + @Mock + private SSLSocketFactory mockSslSocketFactory; + @Mock + private DevelopmentSslConfig mockDevelopmentSslConfig; + @Mock + private CertificateChecker mockCertificateChecker; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + when(mockUrl.openConnection()).thenReturn(mockConnection); + when(mockUrl.openConnection(any(Proxy.class))).thenReturn(mockConnection); + when(mockAuthenticator.authenticate(any(HttpURLConnection.class))).thenReturn(mockConnection); + } + + @Test + public void addHeaders() throws Exception { + Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + headers.put("Authorization", "Bearer token123"); + headers.put(null, "This should be ignored"); + + Method addHeadersMethod = HttpRequestHelper.class.getDeclaredMethod( + "addHeaders", HttpURLConnection.class, Map.class); + addHeadersMethod.setAccessible(true); + addHeadersMethod.invoke(null, mockConnection, headers); + + verify(mockConnection).addRequestProperty("Content-Type", "application/json"); + verify(mockConnection).addRequestProperty("Authorization", "Bearer token123"); + verify(mockConnection, never()).addRequestProperty(null, "This should be ignored"); + } + + @Test + public void applyTimeouts() { + HttpRequestHelper.applyTimeouts(5000, 3000, mockConnection); + verify(mockConnection).setReadTimeout(5000); + verify(mockConnection).setConnectTimeout(3000); + + HttpRequestHelper.applyTimeouts(0, 0, mockConnection); + verify(mockConnection, times(1)).setReadTimeout(any(Integer.class)); + verify(mockConnection, times(1)).setConnectTimeout(any(Integer.class)); + + HttpRequestHelper.applyTimeouts(-1000, -500, mockConnection); + verify(mockConnection, times(1)).setReadTimeout(any(Integer.class)); + verify(mockConnection, times(1)).setConnectTimeout(any(Integer.class)); + } + + @Test + public void applySslConfigWithDevelopmentSslConfig() { + when(mockDevelopmentSslConfig.getSslSocketFactory()).thenReturn(mockSslSocketFactory); + + HttpRequestHelper.applySslConfig(null, mockDevelopmentSslConfig, mockHttpsConnection); + + verify(mockHttpsConnection).setSSLSocketFactory(mockSslSocketFactory); + verify(mockHttpsConnection).setHostnameVerifier(any()); + } + + @Test + public void pinsAreCheckedWithCertificateChecker() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockHttpsConnection, mockCertificateChecker); + + verify(mockCertificateChecker).checkPins(mockHttpsConnection); + } + + @Test + public void pinsAreNotCheckedWithoutCertificateChecker() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockHttpsConnection, null); + + verify(mockCertificateChecker, never()).checkPins(any()); + } + + @Test + public void pinsAreNotCheckedForNonHttpsConnections() throws SSLPeerUnverifiedException { + HttpRequestHelper.checkPins(mockConnection, mockCertificateChecker); + + verify(mockCertificateChecker, never()).checkPins(any()); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java new file mode 100644 index 000000000..0bc972444 --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java @@ -0,0 +1,528 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.security.cert.Certificate; +import java.util.List; +import java.util.Map; + +public class HttpResponseConnectionAdapterTest { + + @Mock + private HttpResponse mMockResponse; + + @Mock + private Certificate mMockCertificate; + + private URL mTestUrl; + private Certificate[] mTestCertificates; + private HttpResponseConnectionAdapter mAdapter; + + @Before + public void setUp() throws MalformedURLException { + mMockCertificate = mock(Certificate.class); + mMockResponse = mock(HttpResponse.class); + mTestUrl = new URL("https://example.com/test"); + mTestCertificates = new Certificate[]{mMockCertificate}; + } + + @Test + public void responseCodeIsValueFromResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals(200, mAdapter.getResponseCode()); + } + + @Test + public void successfulResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("OK", mAdapter.getResponseMessage()); + } + + @Test + public void status400Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Bad Request", mAdapter.getResponseMessage()); + } + + @Test + public void status401Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(401); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Unauthorized", mAdapter.getResponseMessage()); + } + + @Test + public void status403Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(403); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Forbidden", mAdapter.getResponseMessage()); + } + + @Test + public void status404Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(404); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Not Found", mAdapter.getResponseMessage()); + } + + @Test + public void status500Response() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("Internal Server Error", mAdapter.getResponseMessage()); + } + + @Test + public void statusUnknownResponse() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(418); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + assertEquals("HTTP 418", mAdapter.getResponseMessage()); + } + + @Test + public void successfulInputStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + when(mMockResponse.getData()).thenReturn("test data"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream inputStream = mAdapter.getInputStream(); + assertNotNull(inputStream); + + byte[] buffer = new byte[1024]; + int bytesRead = inputStream.read(buffer); + String result = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8); + assertEquals("test data", result); + } + + @Test + public void nullDataInputStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(200); + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream inputStream = mAdapter.getInputStream(); + assertNotNull(inputStream); + + byte[] buffer = new byte[1024]; + int bytesRead = inputStream.read(buffer); + assertEquals(-1, bytesRead); + } + + @Test(expected = IOException.class) + public void inputStreamErrorStatusThrows() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + mAdapter.getInputStream(); + } + + @Test(expected = IOException.class) + public void inputStream500Throws() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + mAdapter.getInputStream(); + } + + @Test + public void status400ErrorStream() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(400); + when(mMockResponse.getData()).thenReturn("error message"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNotNull(errorStream); + + byte[] buffer = new byte[1024]; + int bytesRead = errorStream.read(buffer); + String result = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8); + assertEquals("error message", result); + } + + @Test + public void errorStreamStatus500() throws IOException { + when(mMockResponse.getHttpStatus()).thenReturn(500); + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNotNull(errorStream); + + byte[] buffer = new byte[1024]; + int bytesRead = errorStream.read(buffer); + assertEquals(-1, bytesRead); // Empty stream + } + + @Test + public void errorStreamIsNullForSuccessful() { + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + InputStream errorStream = mAdapter.getErrorStream(); + assertNull(errorStream); + } + + @Test + public void usingProxyIsAlwaysTrue() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + assertTrue(mAdapter.usingProxy()); // This is only used for Proxy + } + + @Test + public void getServerCertificatesReturnsCerts() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + Certificate[] certificates = mAdapter.getServerCertificates(); + assertSame(mTestCertificates, certificates); + assertEquals(1, certificates.length); + assertSame(mMockCertificate, certificates[0]); + } + + @Test + public void nullServerCertificates() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, null); + + Certificate[] certificates = mAdapter.getServerCertificates(); + assertNull(certificates); + } + + @Test + public void contentTypeIsJsonForJsonData() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getHeaderField("content-type"); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void nullHeaderField() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String result = mAdapter.getHeaderField(null); + assertNull(result); + } + + @Test + public void getHeaderIsCaseInsensitive() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType1 = mAdapter.getHeaderField("Content-Type"); + String contentType2 = mAdapter.getHeaderField("CONTENT-TYPE"); + String contentType3 = mAdapter.getHeaderField("content-type"); + + assertEquals("application/json; charset=utf-8", contentType1); + assertEquals("application/json; charset=utf-8", contentType2); + assertEquals("application/json; charset=utf-8", contentType3); + } + + @Test + public void generatedHeaderFieldsCanBeRetrieved() throws IOException { + when(mMockResponse.getData()).thenReturn("{\"test\": \"data\"}"); + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + Map> headers = mAdapter.getHeaderFields(); + assertNotNull(headers); + + assertTrue(headers.containsKey("content-type")); + assertEquals("application/json; charset=utf-8", headers.get("content-type").get(0)); + + assertTrue(headers.containsKey("content-length")); + assertEquals("16", headers.get("content-length").get(0)); + + assertTrue(headers.containsKey("content-encoding")); + assertEquals("utf-8", headers.get("content-encoding").get(0)); + + assertTrue(headers.containsKey("status")); + assertEquals("200 OK", headers.get("status").get(0)); + } + + @Test + public void getContentLengthWithData() { + when(mMockResponse.getData()).thenReturn("Hello World"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(11, length); // "Hello World" is 11 bytes + } + + @Test + public void getContentLengthWithNullData() { + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(0, length); + } + + @Test + public void getContentLengthEmptyData() { + when(mMockResponse.getData()).thenReturn(""); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long length = mAdapter.getContentLengthLong(); + assertEquals(0, length); + } + + @Test + public void getContentTypeJsonData() { + when(mMockResponse.getData()).thenReturn("{\"key\": \"value\"}"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void getContentTypeJsonArray() { + when(mMockResponse.getData()).thenReturn("[{\"key\": \"value\"}]"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("application/json; charset=utf-8", contentType); + } + + @Test + public void getContentTypeHtmlData() { + when(mMockResponse.getData()).thenReturn("Test"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("text/html; charset=utf-8", contentType); + } + + @Test + public void getContentTypeText() { + when(mMockResponse.getData()).thenReturn("Plain text content"); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertEquals("text/plain; charset=utf-8", contentType); + } + + @Test + public void getContentTypeNullDataHasNoContentType() { + when(mMockResponse.getData()).thenReturn(null); + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String contentType = mAdapter.getContentType(); + assertNull(contentType); + } + + @Test + public void testGetContentEncoding() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + String encoding = mAdapter.getContentEncoding(); + assertEquals("utf-8", encoding); + } + + @Test + public void testGetDate() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + long currentTime = System.currentTimeMillis(); + long date = mAdapter.getDate(); + + // Should be close to current time (within 1 second) + assertTrue(Math.abs(date - currentTime) < 1000); + } + + @Test + public void urlCanBeRetrieved() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + URL url = mAdapter.getURL(); + assertSame(mTestUrl, url); + } + + @Test(expected = IOException.class) + public void getOutputStreamThrowsWhenNotEnabled() throws IOException { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + // Should throw exception since doOutput is not enabled + mAdapter.getOutputStream(); + } + + @Test + public void setDoOutputEnablesOutput() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + // Initially doOutput should be false + assertEquals(false, mAdapter.getDoOutput()); + + // After setting doOutput to true, getDoOutput should return true + mAdapter.setDoOutput(true); + assertEquals(true, mAdapter.getDoOutput()); + } + + @Test + public void getOutputStreamAfterEnablingOutput() throws IOException { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + mAdapter.setDoOutput(true); + + assertNotNull("Output stream should not be null when doOutput is enabled", mAdapter.getOutputStream()); + } + + @Test + public void writeToOutputStream() throws IOException { + // Create a ByteArrayOutputStream to capture the written data + ByteArrayOutputStream testOutputStream = new ByteArrayOutputStream(); + + // Use the constructor that accepts a custom OutputStream + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Write test data to the output stream + String testData = "Test output data"; + mAdapter.getOutputStream().write(testData.getBytes(StandardCharsets.UTF_8)); + + // Verify that the data was written correctly + assertEquals("Written data should match the input", testData, testOutputStream.toString(StandardCharsets.UTF_8.name())); + } + + @Test + public void disconnectClosesOutputStream() throws IOException { + // Create a custom OutputStream that tracks if it's been closed + TestOutputStream testOutputStream = new TestOutputStream(); + + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Get the output stream and write some data + mAdapter.getOutputStream().write("Test".getBytes(StandardCharsets.UTF_8)); + + // Verify the stream is not closed yet + assertFalse("Output stream should not be closed before disconnect", testOutputStream.isClosed()); + + // Disconnect should close the output stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Output stream should be closed after disconnect", testOutputStream.isClosed()); + } + + @Test + public void disconnectClosesInputStream() throws IOException { + // Create a custom InputStream that tracks if it's been closed + TestInputStream testInputStream = new TestInputStream("Test response data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + // Create adapter with injected test input stream + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + testInputStream, + null); + + // Get the input stream and read some data to simulate usage + InputStream stream = mAdapter.getInputStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + // Verify the stream is not closed yet + assertFalse("Input stream should not be closed before disconnect", testInputStream.isClosed()); + + // Disconnect should close the input stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Input stream should be closed after disconnect", testInputStream.isClosed()); + } + + /** + * Custom OutputStream implementation for testing that tracks if it's been closed. + */ + private static class TestOutputStream extends ByteArrayOutputStream { + private boolean mClosed = false; + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + private static class TestInputStream extends ByteArrayInputStream { + private boolean mClosed = false; + + public TestInputStream(byte[] data) { + super(data); + } + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + @Test + public void disconnectClosesErrorStream() throws IOException { + TestInputStream testErrorStream = new TestInputStream("Error data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + when(mMockResponse.getHttpStatus()).thenReturn(404); // Error status + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + null, + testErrorStream); + + // Get the error stream and read some data to simulate usage + InputStream stream = mAdapter.getErrorStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + assertFalse("Error stream should not be closed before disconnect", testErrorStream.isClosed()); + + mAdapter.disconnect(); + + assertTrue("Error stream should be closed after disconnect", testErrorStream.isClosed()); + } +} diff --git a/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java new file mode 100644 index 000000000..aa60be76e --- /dev/null +++ b/src/test/java/io/split/android/client/network/HttpStreamResponseTest.java @@ -0,0 +1,148 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.BufferedReader; +import java.io.IOException; +import java.net.Socket; + +public class HttpStreamResponseTest { + + private static final int HTTP_STATUS_OK = 200; + private static final int HTTP_STATUS_BAD_REQUEST = 400; + + @Mock + private BufferedReader mockBufferedReader; + + @Mock + private Socket mockTunnelSocket; + + @Mock + private Socket mockOriginSocket; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void createFromTunnelSocketReturnsValidResponse() { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_OK, response.getHttpStatus()); + } + + @Test + public void createFromTunnelSocketWithNullSocketsReturnsValidResponse() { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_BAD_REQUEST, + mockBufferedReader, + null, + null + ); + + // Verify the response is created correctly + assertNotNull(response); + assertEquals(HTTP_STATUS_BAD_REQUEST, response.getHttpStatus()); + } + + @Test + public void closeSuccessfullyClosesAllResources() throws IOException { + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response + response.close(); + + // Verify all resources were closed in the correct order + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithNullSocketsOnlyClosesBufferedReader() throws IOException { + // Create response with null sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + null, + null + ); + + // Close the response + response.close(); + + // Verify only the BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + verifyNoMoreInteractions(mockTunnelSocket, mockOriginSocket); + } + + @Test + public void closeWithSameTunnelAndOriginSocketClosesSocketOnce() throws IOException { + // Create response with the same socket for tunnel and origin + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockTunnelSocket + ); + + // Close the response + response.close(); + + // Verify BufferedReader was closed + verify(mockBufferedReader, times(1)).close(); + + // Verify tunnel socket was closed only once (since it's the same as origin socket) + verify(mockTunnelSocket, times(1)).close(); + } + + @Test + public void closeWithExceptionsSucceeds() throws IOException { + // Setup mocks to throw exceptions when closed + doThrow(new IOException("BufferedReader close error")).when(mockBufferedReader).close(); + doThrow(new IOException("Origin socket close error")).when(mockOriginSocket).close(); + doThrow(new IOException("Tunnel socket close error")).when(mockTunnelSocket).close(); + + // Create response with both sockets + HttpStreamResponseImpl response = HttpStreamResponseImpl.createFromTunnelSocket( + HTTP_STATUS_OK, + mockBufferedReader, + mockTunnelSocket, + mockOriginSocket + ); + + // Close the response - should not throw exceptions + response.close(); + + // Verify all resources were attempted to be closed despite exceptions + verify(mockBufferedReader, times(1)).close(); + verify(mockOriginSocket, times(1)).close(); + verify(mockTunnelSocket, times(1)).close(); + } +} diff --git a/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java b/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java new file mode 100644 index 000000000..5730579b5 --- /dev/null +++ b/src/test/java/io/split/android/client/network/ProxyConfigurationTest.java @@ -0,0 +1,142 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; + +public class ProxyConfigurationTest { + + private static final String VALID_URL = "http://proxy.example.com:8080"; + private static final String INVALID_URL = "invalid://\\url"; + private static final String URL_WITH_PATH = "https://proxy.example.com:8080/path/to/proxy"; + private static final String URL_WITH_PATH_NORMALIZED = "https://proxy.example.com:8080/path/to/proxy"; + + @Mock + private BearerCredentialsProvider mockBearerCredentialsProvider; + + @Mock + private BasicCredentialsProvider mockBasicCredentialsProvider; + + private InputStream clientCert; + private InputStream clientPk; + private InputStream caCert; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + + clientCert = new ByteArrayInputStream("client-cert-content".getBytes()); + clientPk = new ByteArrayInputStream("client-pk-content".getBytes()); + caCert = new ByteArrayInputStream("ca-cert-content".getBytes()); + } + + @Test + public void buildWithValidUrl() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .build(); + + assertNotNull("Configuration should be created with valid URL", config); + assertEquals("URL should match", VALID_URL, config.getUrl().toString()); + } + + @Test + public void buildWithInvalidUrlHasNoNullConfig() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(INVALID_URL) + .build(); + + assertNotNull(config); + } + + @Test + public void urlIsNormalized() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(URL_WITH_PATH) + .build(); + + assertNotNull("Configuration should be created with URL containing path", config); + assertEquals("URL should be normalized", URL_WITH_PATH_NORMALIZED, config.getUrl().toString()); + } + + @Test + public void bearerCredentialsProvider() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(mockBearerCredentialsProvider) + .build(); + + assertNotNull("Configuration should be created with bearer credentials", config); + assertSame("Credentials provider should match", mockBearerCredentialsProvider, config.getCredentialsProvider()); + } + + @Test + public void basicCredentialsProvider() { + BasicCredentialsProvider provider = mockBasicCredentialsProvider; + + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(provider) + .build(); + + assertNotNull("Configuration should be created with basic credentials", config); + assertSame("Credentials provider should match", provider, config.getCredentialsProvider()); + } + + @Test + public void mtlsValues() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .mtls(clientCert, clientPk) + .build(); + + assertNotNull("Configuration should be created with mTLS", config); + assertSame("Client certificate should match", clientCert, config.getClientCert()); + assertSame("Client private key should match", clientPk, config.getClientPk()); + } + + @Test + public void cacert() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .caCert(caCert) + .build(); + + assertNotNull("Configuration should be created with CA certificate", config); + assertSame("CA certificate should match", caCert, config.getCaCert()); + } + + @Test + public void allOptions() { + ProxyConfiguration config = ProxyConfiguration.builder() + .url(VALID_URL) + .credentialsProvider(mockBearerCredentialsProvider) + .mtls(clientCert, clientPk) + .caCert(caCert) + .build(); + + assertNotNull("Configuration should be created with all options", config); + assertEquals("URL should match", VALID_URL, config.getUrl().toString()); + assertSame("Credentials provider should match", mockBearerCredentialsProvider, config.getCredentialsProvider()); + assertSame("Client certificate should match", clientCert, config.getClientCert()); + assertSame("Client private key should match", clientPk, config.getClientPk()); + assertSame("CA certificate should match", caCert, config.getCaCert()); + } + + @Test + public void buildWithoutUrlReturnsNonNullConfig() { + ProxyConfiguration config = ProxyConfiguration.builder() + .credentialsProvider(mockBearerCredentialsProvider) + .build(); + + assertNotNull("Configuration should be created with null URL", config); + } +} diff --git a/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java b/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java new file mode 100644 index 000000000..b4d88868a --- /dev/null +++ b/src/test/java/io/split/android/client/network/ProxySslSocketFactoryProviderImplTest.java @@ -0,0 +1,139 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertNotNull; + +import androidx.annotation.NonNull; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileWriter; +import java.util.Base64; + +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class ProxySslSocketFactoryProviderImplTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private final Base64Decoder mBase64Decoder = new Base64Decoder() { + @Override + public byte[] decode(String base64) { + return Base64.getDecoder().decode(base64); + } + }; + + @Test + public void creatingWithValidCaCertCreatesSocketFactory() throws Exception { + HeldCertificate ca = getCaCert(); + File caCertFile = tempFolder.newFile("held-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(ca.certificatePem()); + } + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream fis = new FileInputStream(caCertFile)) { + SSLSocketFactory socketFactory = provider.create(fis); + assertNotNull(socketFactory); + } + } + + @Test(expected = Exception.class) + public void creatingWithInvalidCaCertThrows() throws Exception { + File caCertFile = tempFolder.newFile("invalid-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write("not a cert"); + } + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream fis = new FileInputStream(caCertFile)) { + provider.create(fis); + } + } + + @Test + public void creatingWithValidMtlsParamsCreatesSocketFactory() throws Exception { + // Create CA cert and client cert & key + HeldCertificate ca = getCaCert(); + HeldCertificate clientCert = getClientCert(ca); + File caCertFile = createCaCertFile(ca); + File clientCertFile = tempFolder.newFile("client.crt"); + File clientKeyFile = tempFolder.newFile("client.key"); + + // Write client certificate and key to separate files + try (FileWriter writer = new FileWriter(clientCertFile)) { + writer.write(clientCert.certificatePem()); + } + try (FileWriter writer = new FileWriter(clientKeyFile)) { + writer.write(clientCert.privateKeyPkcs8Pem()); + } + + // Create socket factory + ProxySslSocketFactoryProviderImpl factory = new ProxySslSocketFactoryProviderImpl(mBase64Decoder); + SSLSocketFactory sslSocketFactory; + try (FileInputStream caCertStream = new FileInputStream(caCertFile); + FileInputStream clientCertStream = new FileInputStream(clientCertFile); + FileInputStream clientKeyStream = new FileInputStream(clientKeyFile)) { + sslSocketFactory = factory.create(caCertStream, clientCertStream, clientKeyStream); + } + + assertNotNull(sslSocketFactory); + } + + @Test(expected = Exception.class) + public void creatingWithInvalidMtlsParamsThrows() throws Exception { + // Create valid CA cert but invalid client cert/key files + HeldCertificate ca = getCaCert(); + File caCertFile = createCaCertFile(ca); + File invalidClientCertFile = tempFolder.newFile("invalid-client.crt"); + File invalidClientKeyFile = tempFolder.newFile("invalid-client.key"); + + // Write invalid data to cert and key files + try (FileWriter writer = new FileWriter(invalidClientCertFile)) { + writer.write("invalid certificate"); + } + try (FileWriter writer = new FileWriter(invalidClientKeyFile)) { + writer.write("invalid key"); + } + + ProxySslSocketFactoryProviderImpl provider = getProvider(); + try (FileInputStream caCertStream = new FileInputStream(caCertFile); + FileInputStream invalidClientCertStream = new FileInputStream(invalidClientCertFile); + FileInputStream invalidClientKeyStream = new FileInputStream(invalidClientKeyFile)) { + provider.create(caCertStream, invalidClientCertStream, invalidClientKeyStream); + } + } + + private File createCaCertFile(HeldCertificate ca) throws Exception { + File caCertFile = tempFolder.newFile("mtls-ca.pem"); + try (FileWriter writer = new FileWriter(caCertFile)) { + writer.write(ca.certificatePem()); + } + return caCertFile; + } + + @NonNull + private static HeldCertificate getCaCert() { + return new HeldCertificate.Builder() + .commonName("Test CA") + .certificateAuthority(0) + .build(); + } + + @NonNull + private static HeldCertificate getClientCert(HeldCertificate ca) { + return new HeldCertificate.Builder() + .commonName("Test Client") + .signedBy(ca) + .build(); + } + + @NonNull + private ProxySslSocketFactoryProviderImpl getProvider() { + return new ProxySslSocketFactoryProviderImpl(mBase64Decoder); + } +} diff --git a/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java new file mode 100644 index 000000000..203e50570 --- /dev/null +++ b/src/test/java/io/split/android/client/network/RawHttpResponseParserTest.java @@ -0,0 +1,250 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.Socket; +import java.security.cert.Certificate; +import java.util.Objects; + +public class RawHttpResponseParserTest { + + private final Certificate[] mServerCertificates = new Certificate[]{}; + + @Test + public void httpResponseWithValidResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json\r\n" + + "Content-Length: 25\r\n" + + "\r\n" + + "{\"message\":\"Hello World\"}"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertEquals("Response data should match", "{\"message\":\"Hello World\"}", response.getData()); + assertTrue("Response should be successful", response.isSuccess()); + } + + @Test + public void responseWithErrorStatusReturnsErrorResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 500 Internal Server Error\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Server Error!"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 500", 500, response.getHttpStatus()); + assertEquals("Response data should match", "Server Error!", response.getData()); + assertFalse("Response should not be successful", response.isSuccess()); + } + + @Test + public void responseWithNoContentLengthReadsUntilEnd() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Connection: close\r\n" + + "\r\n" + + "This is response data\r\n" + + "with multiple lines\r\n" + + "until connection closes"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertNotNull("Response data should not be null", response.getData()); + assertTrue("Response data should contain expected content", + response.getData().contains("This is response data")); + assertTrue("Response data should contain multiple lines", + response.getData().contains("with multiple lines")); + } + + @Test + public void responseWithNoBodyReturnsEmptyData() throws Exception { + String rawHttpResponse = + "HTTP/1.1 204 No Content\r\n" + + "Content-Length: 0\r\n" + + "\r\n"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 204", 204, response.getHttpStatus()); + assertTrue("Response data should be null or empty", + response.getData() == null || response.getData().isEmpty()); + } + + @Test + public void responseWithInvalidStatusLineThrowsException() throws Exception { + String rawHttpResponse = "INVALID STATUS LINE\r\n\r\n"; + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpResponse(inputStream, mServerCertificates); + fail("Should have thrown exception for invalid status line"); + } catch (IOException e) { + assertTrue("Exception should mention invalid status", + Objects.requireNonNull(e.getMessage()).contains("Invalid HTTP status")); + } + } + + @Test + public void responseWithEmptyStreamThrowsException() throws Exception { + InputStream inputStream = new ByteArrayInputStream(new byte[0]); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpResponse(inputStream, mServerCertificates); + fail("Should have thrown exception for empty stream"); + } catch (IOException e) { + assertTrue("Exception should mention no response", + e.getMessage().contains("No HTTP response")); + } + } + + @Test + public void responseWithChunkedEncodingHandlesCorrectly() throws Exception { + String rawHttpResponse = + // headers + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + // 1st chunk size + "15\r\n" + + // 1st chunk data + "This is chunked data!" + + "\r\n" + + + // 2nd chunk size + "0\r\n" + + "\r\n"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpResponse response = parser.parseHttpResponse(inputStream, mServerCertificates); + + assertNotNull("Response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + assertNotNull("Response data should not be null", response.getData()); + assertTrue("Response data should contain expected content", + response.getData().contains("This is chunked data!")); + } + + // Tests for parseHttpStreamResponse method + + @Test + public void parseHttpStreamResponseWithValidInputReturnsCorrectResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: application/json\r\n" + + "Content-Length: 25\r\n" + + "\r\n" + + "{\"message\":\"Hello World\"}"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + Socket mockTunnelSocket = mock(Socket.class); + Socket mockOriginSocket = mock(Socket.class); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, mockTunnelSocket, mockOriginSocket); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithNullSocketsReturnsValidResponse() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 13\r\n" + + "\r\n" + + "Hello, World!"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithDifferentContentTypeUsesCorrectCharset() throws Exception { + String rawHttpResponse = + "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/html; charset=ISO-8859-1\r\n" + + "Content-Length: 20\r\n" + + "\r\n" + + "Test Page"; + + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("ISO-8859-1")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + HttpStreamResponse response = parser.parseHttpStreamResponse(inputStream, null, null); + + assertNotNull("Stream response should not be null", response); + assertEquals("Status code should be 200", 200, response.getHttpStatus()); + } + + @Test + public void parseHttpStreamResponseWithEmptyStreamThrowsException() throws Exception { + InputStream inputStream = new ByteArrayInputStream(new byte[0]); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for empty stream"); + } catch (IOException e) { + assertTrue("Exception should mention no response", + e.getMessage().contains("No HTTP response")); + } + } + + @Test + public void parseHttpStreamResponseWithInvalidStatusLineThrowsException() throws Exception { + String rawHttpResponse = "INVALID STATUS LINE\r\n\r\n"; + InputStream inputStream = new ByteArrayInputStream(rawHttpResponse.getBytes("UTF-8")); + RawHttpResponseParser parser = new RawHttpResponseParser(); + + try { + parser.parseHttpStreamResponse(inputStream, null, null); + fail("Should have thrown exception for invalid status line"); + } catch (IOException e) { + assertTrue("Exception should mention invalid status", + Objects.requireNonNull(e.getMessage()).contains("Invalid HTTP status")); + } + } +} diff --git a/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java new file mode 100644 index 000000000..f6a1157a0 --- /dev/null +++ b/src/test/java/io/split/android/client/network/SslProxyTunnelEstablisherTest.java @@ -0,0 +1,465 @@ +package io.split.android.client.network; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintWriter; +import java.net.HttpRetryException; +import java.net.Socket; +import java.security.KeyStore; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocket; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocketFactory; + +import okhttp3.tls.HeldCertificate; + +public class SslProxyTunnelEstablisherTest { + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + private TestSslProxy testProxy; + private SSLSocketFactory clientSslSocketFactory; + + @Before + public void setUp() throws Exception { + // override the default hostname verifier for testing + HttpsURLConnection.setDefaultHostnameVerifier(new HostnameVerifier() { + @Override + public boolean verify(String hostname, SSLSession sslSession) { + return true; + } + }); + + // Create test certificates + HeldCertificate proxyCa = new HeldCertificate.Builder() + .commonName("Test Proxy CA") + .certificateAuthority(0) + .build(); + HeldCertificate proxyServerCert = new HeldCertificate.Builder() + .commonName("localhost") + .signedBy(proxyCa) + .build(); + + // Create SSL socket factory that trusts the proxy CA + File proxyCaFile = tempFolder.newFile("proxy-ca.pem"); + try (FileWriter writer = new FileWriter(proxyCaFile)) { + writer.write(proxyCa.certificatePem()); + } + + ProxySslSocketFactoryProvider factory = new ProxySslSocketFactoryProviderImpl(); + try (java.io.FileInputStream caInput = new java.io.FileInputStream(proxyCaFile)) { + clientSslSocketFactory = factory.create(caInput); + } + + // Start test SSL proxy + testProxy = new TestSslProxy(0, proxyServerCert); + testProxy.start(); + + // Wait for proxy to start + while (testProxy.getPort() == 0) { + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception { + if (testProxy != null) { + testProxy.stop(); + } + } + + @Test + public void establishTunnelWithValidSslProxySucceeds() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + String targetHost = "example.com"; + int targetPort = 443; + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + targetHost, + targetPort, + clientSslSocketFactory, + proxyCredentialsProvider, + false); + + assertNotNull("Tunnel socket should not be null", tunnelSocket); + assertTrue("Tunnel socket should be connected", tunnelSocket.isConnected()); + + // Verify CONNECT request was sent and successful + assertTrue("Proxy should have received CONNECT request", + testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + assertEquals("CONNECT example.com:443 HTTP/1.1", testProxy.getReceivedConnectLine()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNotTrustedCertificatedThrows() throws Exception { + SSLContext untrustedContext = SSLContext.getInstance("TLS"); + untrustedContext.init(null, null, null); + SSLSocketFactory untrustedSocketFactory = untrustedContext.getSocketFactory(); + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + try { + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + untrustedSocketFactory, + proxyCredentialsProvider, + false); + fail("Should have thrown exception for untrusted certificate"); + } catch (IOException e) { + assertTrue("Exception should be SSL-related", e.getMessage().contains("certification")); + } + } + + @Test + public void establishTunnelWithFailingProxyConnectionThrows() { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + BearerCredentialsProvider proxyCredentialsProvider = mock(BearerCredentialsProvider.class); + + try { + establisher.establishTunnel( + "localhost", + -1234, + "example.com", + 443, + clientSslSocketFactory, + proxyCredentialsProvider, + false); + fail("Should have thrown exception for connection failure"); + } catch (IOException e) { + // The implementation wraps the original exception with a descriptive message + assertTrue(e.getMessage().contains("Failed to establish SSL tunnel")); + } + } + + @Test + public void bearerTokenIsPassedWhenSet() throws IOException, InterruptedException { + // For Bearer token, we don't need to mock the Base64Encoder since it's not used + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return "token"; + } + }, + false); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + assertEquals("Proxy-Authorization: Bearer token", testProxy.getReceivedAuthHeader()); + } + + @Test + public void basicAuthIsPassedWhenSet() throws IOException, InterruptedException { + // Create a mock Base64Encoder + Base64Encoder mockEncoder = mock(Base64Encoder.class); + String mockEncodedCredentials = "MOCK_ENCODED_CREDENTIALS"; + when(mockEncoder.encode("username:password")).thenReturn(mockEncodedCredentials); + + // Create SslProxyTunnelEstablisher with the mock encoder + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(mockEncoder); + + establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BasicCredentialsProvider() { + @Override + public String getUsername() { + return "username"; + } + + @Override + public String getPassword() { + return "password"; + } + }, + false); + boolean await = testProxy.getAuthorizationHeaderReceived().await(5, TimeUnit.SECONDS); + assertTrue("Proxy should have received authorization header", await); + + // The expected header should contain the mock encoded credentials + String expectedHeader = "Proxy-Authorization: Basic " + mockEncodedCredentials; + assertEquals(expectedHeader, testProxy.getReceivedAuthHeader()); + } + + @Test + public void establishTunnelWithNullCredentialsProviderDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return null; + } + }, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithEmptyBearerTokenDoesNotAddAuthHeader() throws Exception { + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + Socket tunnelSocket = establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + new BearerCredentialsProvider() { + @Override + public String getToken() { + return ""; + } + }, + false); + + assertNotNull(tunnelSocket); + assertTrue(testProxy.getConnectRequestReceived().await(5, TimeUnit.SECONDS)); + + assertEquals(1, testProxy.getAuthorizationHeaderReceived().getCount()); + + tunnelSocket.close(); + } + + @Test + public void establishTunnelWithNullStatusLineThrowsIOException() { + testProxy.setConnectResponse(null); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, false)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithMalformedStatusLineThrowsIOException() { + testProxy.setConnectResponse("HTTP/1.1"); // Malformed, missing status code + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + IOException exception = assertThrows(IOException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false)); + + assertNotNull(exception); + } + + @Test + public void establishTunnelWithProxyAuthRequiredThrowsHttpRetryException() { + testProxy.setConnectResponse("HTTP/1.1 407 Proxy Authentication Required"); + SslProxyTunnelEstablisher establisher = new SslProxyTunnelEstablisher(); + + HttpRetryException exception = assertThrows(HttpRetryException.class, () -> establisher.establishTunnel( + "localhost", + testProxy.getPort(), + "example.com", + 443, + clientSslSocketFactory, + null, + false)); + + assertEquals(407, exception.responseCode()); + } + + /** + * Test SSL proxy that accepts SSL connections and handles CONNECT requests. + */ + private static class TestSslProxy extends Thread { + private final int mPort; + private final HeldCertificate mServerCert; + private SSLServerSocket mServerSocket; + private final AtomicBoolean mRunning = new AtomicBoolean(true); + private final CountDownLatch mConnectRequestReceived = new CountDownLatch(1); + private final CountDownLatch mAuthorizationHeaderReceived = new CountDownLatch(1); + private final AtomicReference mReceivedConnectLine = new AtomicReference<>(); + private final AtomicReference mReceivedAuthHeader = new AtomicReference<>(); + private final AtomicReference mConnectResponse = new AtomicReference<>("HTTP/1.1 200 Connection established"); + + public TestSslProxy(int port, HeldCertificate serverCert) { + mPort = port; + mServerCert = serverCert; + } + + @Override + public void run() { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + ks.setKeyEntry("key", mServerCert.keyPair().getPrivate(), "password".toCharArray(), + new java.security.cert.Certificate[]{mServerCert.certificate()}); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, "password".toCharArray()); + sslContext.init(kmf.getKeyManagers(), null, null); + + mServerSocket = (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(mPort); + mServerSocket.setWantClientAuth(false); + mServerSocket.setNeedClientAuth(false); + + while (mRunning.get()) { + try { + Socket client = mServerSocket.accept(); + handleClient(client); + } catch (IOException e) { + if (mRunning.get()) { + System.err.println("Error accepting client: " + e.getMessage()); + } + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to start test SSL proxy", e); + } + } + + private void handleClient(Socket client) { + try { + BufferedReader reader = new BufferedReader( + new InputStreamReader(client.getInputStream())); + PrintWriter writer = new PrintWriter(client.getOutputStream(), true); + + // Read CONNECT request + String line = reader.readLine(); + if (line != null && line.startsWith("CONNECT")) { + mReceivedConnectLine.set(line); + mConnectRequestReceived.countDown(); + + while((line = reader.readLine()) != null && !line.isEmpty()) { + if (line.contains("Authorization") && (line.contains("Bearer") || line.contains("Basic"))) { + mAuthorizationHeaderReceived.countDown(); + mReceivedAuthHeader.set(line); + } + } + + // Send configured CONNECT response + String response = mConnectResponse.get(); + if (response != null) { + writer.println(response); + writer.println(); + writer.flush(); + } + + // Keep connection open for tunnel + Thread.sleep(100); + } + } catch (Exception e) { + System.err.println("Error handling client: " + e.getMessage()); + } finally { + try { + client.close(); + } catch (IOException e) { + // Ignore + } + } + } + + public int getPort() { + return mServerSocket != null ? mServerSocket.getLocalPort() : 0; + } + + public CountDownLatch getConnectRequestReceived() { + return mConnectRequestReceived; + } + + public CountDownLatch getAuthorizationHeaderReceived() { + return mAuthorizationHeaderReceived; + } + + public String getReceivedConnectLine() { + return mReceivedConnectLine.get(); + } + + public String getReceivedAuthHeader() { + return mReceivedAuthHeader.get(); + } + + public void setConnectResponse(String connectResponse) { + mConnectResponse.set(connectResponse); + } + } +} diff --git a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java index 1ab36656e..eeb53f2e1 100644 --- a/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java +++ b/src/test/java/io/split/android/client/service/sseclient/SseClientTest.java @@ -1,5 +1,13 @@ package io.split.android.client.service.sseclient; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -29,14 +37,6 @@ import io.split.android.client.service.sseclient.sseclient.SseHandler; import io.split.sharedtest.fake.HttpStreamResponseMock; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - public class SseClientTest { @Mock @@ -66,7 +66,7 @@ public void setup() throws URISyntaxException { } @Test - public void onConnect() throws InterruptedException, HttpException { + public void onConnect() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -87,7 +87,7 @@ public void onConnect() throws InterruptedException, HttpException { } @Test - public void onConnectNotConfirmed() throws InterruptedException, HttpException { + public void onConnectNotConfirmed() throws InterruptedException, HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); TestConnListener connListener = spy(new TestConnListener(onOpenLatch)); @@ -268,7 +268,7 @@ public void onConnectionSuccess() { } @Test - public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException { + public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class); @@ -286,7 +286,7 @@ public void nonRetryableErrorWhenRequestFailsWithHttpExceptionWith9009Code() thr } @Test - public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException { + public void retryableErrorWhenRequestFailsWithHttpExceptionWithNullCode() throws HttpException, IOException { CountDownLatch onOpenLatch = new CountDownLatch(1); BufferedReader reader = Mockito.mock(BufferedReader.class); diff --git a/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java b/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java new file mode 100644 index 000000000..2f215febe --- /dev/null +++ b/src/test/java/io/split/android/client/service/workmanager/HttpClientProviderTest.java @@ -0,0 +1,177 @@ +package io.split.android.client.service.workmanager; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mockStatic; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.junit.MockitoJUnitRunner; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.CertificatePinningConfiguration; +import io.split.android.client.network.CertificatePinningConfigurationProvider; +import io.split.android.client.network.HttpClient; +import io.split.android.client.storage.cipher.SplitCipher; +import io.split.android.client.storage.cipher.SplitCipherFactory; +import io.split.android.client.storage.db.SplitRoomDatabase; +import io.split.android.client.storage.db.StorageFactory; +import io.split.android.client.storage.general.GeneralInfoStorage; +import io.split.android.client.utils.HttpProxySerializer; + +@RunWith(MockitoJUnitRunner.class) +public class HttpClientProviderTest { + + @Mock + private SplitRoomDatabase mockDatabase; + + @Mock + private GeneralInfoStorage mockGeneralInfoStorage; + + @Mock + private SplitCipher mockSplitCipher; + + @Mock + private CertificatePinningConfiguration mockCertPinningConfig; + + @Mock + private HttpProxyDto mockHttpProxyDto; + + private static final String TEST_API_KEY = "test-api-key"; + private static final String TEST_CERT_PINNING_CONFIG = "{\"pins\":[]}"; + + @Test + public void shouldBuildHttpClientWithNullCertificatePinningConfig() { + HttpClient result = buildHttpClientWithMocks(null, false, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithValidCertificatePinningConfig() { + HttpClient result = buildHttpClientWithCertPinningMocks(false); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithValidProxyConfig() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWhenProxyConfigProvidedButDtoIsNull() { + HttpClient result = buildHttpClientWithMocks(null, true, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyBasicAuth() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.username = "testuser"; + mockHttpProxyDto.password = "testpass"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyBearerToken() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.bearerToken = "test-bearer-token"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyMtlsAuth() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.clientCert = "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"; + mockHttpProxyDto.clientKey = "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWhenProxyHostIsNull() { + mockHttpProxyDto.host = null; + mockHttpProxyDto.port = 8080; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithEmptyCertificatePinningConfig() { + HttpClient result = buildHttpClientWithMocks("", false, null); + assertNotNull("HttpClient should not be null", result); + } + + @Test + public void shouldBuildHttpClientWithProxyCaCert() { + mockHttpProxyDto.host = "proxy.example.com"; + mockHttpProxyDto.port = 8080; + mockHttpProxyDto.caCert = "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----"; + + HttpClient result = buildHttpClientWithMocks(null, true, mockHttpProxyDto); + assertNotNull("HttpClient should not be null", result); + } + + private void setupCommonMocks(MockedStatic storageFactoryMock, + MockedStatic cipherFactoryMock, + MockedStatic serializerMock, + HttpProxyDto proxyDto) { + cipherFactoryMock.when(() -> SplitCipherFactory.create(TEST_API_KEY, true)) + .thenReturn(mockSplitCipher); + storageFactoryMock.when(() -> StorageFactory.getGeneralInfoStorage(mockDatabase, mockSplitCipher)) + .thenReturn(mockGeneralInfoStorage); + serializerMock.when(() -> HttpProxySerializer.deserialize(mockGeneralInfoStorage)) + .thenReturn(proxyDto); + } + + private HttpClient buildHttpClientWithMocks(String certPinningConfig, boolean usingProxy, HttpProxyDto proxyDto) { + try (MockedStatic storageFactoryMock = mockStatic(StorageFactory.class); + MockedStatic cipherFactoryMock = mockStatic(SplitCipherFactory.class); + MockedStatic serializerMock = mockStatic(HttpProxySerializer.class)) { + + setupCommonMocks(storageFactoryMock, cipherFactoryMock, serializerMock, proxyDto); + + return HttpClientProvider.buildHttpClient( + TEST_API_KEY, + certPinningConfig, + usingProxy, + mockDatabase + ); + } + } + + private HttpClient buildHttpClientWithCertPinningMocks(boolean usingProxy) { + try (MockedStatic storageFactoryMock = mockStatic(StorageFactory.class); + MockedStatic cipherFactoryMock = mockStatic(SplitCipherFactory.class); + MockedStatic serializerMock = mockStatic(HttpProxySerializer.class); + MockedStatic certProviderMock = mockStatic(CertificatePinningConfigurationProvider.class)) { + + setupCommonMocks(storageFactoryMock, cipherFactoryMock, serializerMock, null); + certProviderMock.when(() -> CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(HttpClientProviderTest.TEST_CERT_PINNING_CONFIG)) + .thenReturn(mockCertPinningConfig); + + HttpClient result = HttpClientProvider.buildHttpClient( + TEST_API_KEY, + HttpClientProviderTest.TEST_CERT_PINNING_CONFIG, + usingProxy, + mockDatabase + ); + + certProviderMock.verify(() -> CertificatePinningConfigurationProvider.getCertificatePinningConfiguration(HttpClientProviderTest.TEST_CERT_PINNING_CONFIG)); + return result; + } + } +} diff --git a/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java b/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java index 9a207f2b4..53cae12a2 100644 --- a/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java +++ b/src/test/java/io/split/android/client/storage/general/GeneralInfoStorageImplTest.java @@ -1,6 +1,9 @@ package io.split.android.client.storage.general; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -8,19 +11,45 @@ import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyCredentialsProvider; +import io.split.android.client.storage.cipher.SplitCipher; import io.split.android.client.storage.db.GeneralInfoDao; import io.split.android.client.storage.db.GeneralInfoEntity; +import io.split.android.client.utils.HttpProxySerializer; public class GeneralInfoStorageImplTest { private GeneralInfoDao mGeneralInfoDao; + private SplitCipher mAlwaysEncryptedSplitCipher; private GeneralInfoStorageImpl mGeneralInfoStorage; @Before public void setUp() { + mAlwaysEncryptedSplitCipher = mock(SplitCipher.class); + when(mAlwaysEncryptedSplitCipher.encrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return "encrypted_" + invocation.getArgument(0); + } + }); + when(mAlwaysEncryptedSplitCipher.decrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return "decrypted_" + invocation.getArgument(0); + } + }); + mGeneralInfoDao = mock(GeneralInfoDao.class); - mGeneralInfoStorage = new GeneralInfoStorageImpl(mGeneralInfoDao); + mGeneralInfoStorage = new GeneralInfoStorageImpl(mGeneralInfoDao, mAlwaysEncryptedSplitCipher); } @Test @@ -190,4 +219,125 @@ public void setRbsChangeNumberSetsValueOnDao() { verify(mGeneralInfoDao).update(argThat(entity -> entity.getName().equals("rbsChangeNumber") && entity.getLongValue() == 123L)); } + + @Test + public void getProxyConfigReturnsValueFromDao() { + when(mGeneralInfoDao.getByName("proxyConfig")) + .thenReturn(new GeneralInfoEntity("proxyConfig", "encrypted_proxyConfigValue")); + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertEquals("decrypted_encrypted_proxyConfigValue", proxyConfig); + verify(mGeneralInfoDao).getByName("proxyConfig"); + verify(mAlwaysEncryptedSplitCipher).decrypt("encrypted_proxyConfigValue"); + } + + @Test + public void getProxyConfigReturnsNullIfEntityIsNull() { + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(null); + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull(proxyConfig); + } + + @Test + public void setProxyConfigSetsValueOnDao() { + mGeneralInfoStorage.setProxyConfig("proxyConfigValue"); + + verify(mAlwaysEncryptedSplitCipher).encrypt("proxyConfigValue"); + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue().equals("encrypted_proxyConfigValue"))); + } + + @Test + public void testSerializeAndStoreHttpProxy() { + String testHost = "proxy.example.com"; + int testPort = 8080; + String testUsername = "testuser"; + String testPassword = "testpass"; + String testClientCert = "-----BEGIN CERTIFICATE-----\nMIICertificateContent\n-----END CERTIFICATE-----"; + String testClientKey = "-----BEGIN PRIVATE KEY-----\nMIIKeyContent\n-----END PRIVATE KEY-----"; + String testCaCert = "-----BEGIN CA CERTIFICATE-----\nMIICACertContent\n-----END CA CERTIFICATE-----"; + + InputStream clientCertStream = new ByteArrayInputStream(testClientCert.getBytes(StandardCharsets.UTF_8)); + InputStream clientKeyStream = new ByteArrayInputStream(testClientKey.getBytes(StandardCharsets.UTF_8)); + InputStream caCertStream = new ByteArrayInputStream(testCaCert.getBytes(StandardCharsets.UTF_8)); + + ProxyCredentialsProvider credentialsProvider = mock(ProxyCredentialsProvider.class); + + HttpProxy httpProxy = HttpProxy.newBuilder(testHost, testPort) + .basicAuth(testUsername, testPassword) + .mtls(clientCertStream, clientKeyStream) + .proxyCacert(caCertStream) + .credentialsProvider(credentialsProvider) + .build(); + + String jsonProxy = HttpProxySerializer.serialize(httpProxy); + mGeneralInfoStorage.setProxyConfig(jsonProxy); + + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue().startsWith("encrypted_"))); + } + + @Test + public void testGetProxyConfig() { + String jsonContent = "{\"host\":\"proxy.example.com\",\"port\":8080,\"username\":\"testuser\",\"password\":\"testpass\",\"client_cert\":\"cert-data\",\"client_key\":\"key-data\",\"ca_cert\":\"ca-data\",\"bearer_token\":\"token\"}"; + when(mAlwaysEncryptedSplitCipher.encrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return invocation.getArgument(0); + } + }); + when(mAlwaysEncryptedSplitCipher.decrypt(anyString())).thenAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) { + return invocation.getArgument(0); + } + }); + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(new GeneralInfoEntity("proxyConfig", jsonContent)); + + String proxyConfigJson = mGeneralInfoStorage.getProxyConfig(); + + assertNotNull("Proxy config JSON should not be null", proxyConfigJson); + + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNotNull("Deserialized DTO should not be null", dto); + assertEquals("Host should match", "proxy.example.com", dto.host); + assertEquals("Port should match", 8080, dto.port); + assertEquals("Username should match", "testuser", dto.username); + assertEquals("Password should match", "testpass", dto.password); + assertEquals("Client cert should match", "cert-data", dto.clientCert); + assertEquals("Client key should match", "key-data", dto.clientKey); + assertEquals("CA cert should match", "ca-data", dto.caCert); + assertEquals("token", dto.bearerToken); + } + + @Test + public void proxyConfigIsNullWhenStoredDataIsNull() { + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(null); + + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull("Proxy config should be null when entity is null", proxyConfig); + } + + @Test + public void proxyConfigIsNullWhenTheStoredValueIsNull() { + GeneralInfoEntity entity = new GeneralInfoEntity("proxyConfig", (String) null); + when(mGeneralInfoDao.getByName("proxyConfig")).thenReturn(entity); + + String proxyConfig = mGeneralInfoStorage.getProxyConfig(); + + assertNull("Proxy config should be null when entity value is null", proxyConfig); + } + + @Test + public void proxyConfigCanBeSetToNull() { + mGeneralInfoStorage.setProxyConfig(null); + + verify(mGeneralInfoDao).update(argThat(entity -> + entity.getName().equals("proxyConfig") && + entity.getStringValue() == null)); + } } diff --git a/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java b/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java new file mode 100644 index 000000000..3f1052be0 --- /dev/null +++ b/src/test/java/io/split/android/client/utils/HttpProxySerializerTest.java @@ -0,0 +1,115 @@ +package io.split.android.client.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import io.split.android.client.dtos.HttpProxyDto; +import io.split.android.client.network.BasicCredentialsProvider; +import io.split.android.client.network.HttpProxy; +import io.split.android.client.network.ProxyCredentialsProvider; +import io.split.android.client.storage.general.GeneralInfoStorage; + +public class HttpProxySerializerTest { + + private HttpProxy mHttpProxy; + private GeneralInfoStorage mGeneralInfoStorage; + private final String TEST_HOST = "proxy.example.com"; + private final int TEST_PORT = 8080; + private final String TEST_USERNAME = "testuser"; + private final String TEST_PASSWORD = "testpass"; + private final String TEST_CLIENT_CERT = "-----BEGIN CERTIFICATE-----\nMIICertificateContent\n-----END CERTIFICATE-----"; + private final String TEST_CLIENT_KEY = "-----BEGIN PRIVATE KEY-----\nMIIKeyContent\n-----END PRIVATE KEY-----"; + private final String TEST_CA_CERT = "-----BEGIN CA CERTIFICATE-----\nMIICACertContent\n-----END CA CERTIFICATE-----"; + + @Before + public void setUp() { + mGeneralInfoStorage = mock(GeneralInfoStorage.class); + + // Create input streams from test strings + InputStream clientCertStream = new ByteArrayInputStream(TEST_CLIENT_CERT.getBytes(StandardCharsets.UTF_8)); + InputStream clientKeyStream = new ByteArrayInputStream(TEST_CLIENT_KEY.getBytes(StandardCharsets.UTF_8)); + InputStream caCertStream = new ByteArrayInputStream(TEST_CA_CERT.getBytes(StandardCharsets.UTF_8)); + + // Mock the credentials provider + ProxyCredentialsProvider credentialsProvider = new BasicCredentialsProvider() { + @Override + public String getUsername() { + return TEST_USERNAME; + } + + @Override + public String getPassword() { + return TEST_PASSWORD; + } + }; + + // Create the HttpProxy object + mHttpProxy = HttpProxy.newBuilder(TEST_HOST, TEST_PORT) + .basicAuth(TEST_USERNAME, TEST_PASSWORD) + .mtls(clientCertStream, clientKeyStream) + .proxyCacert(caCertStream) + .credentialsProvider(credentialsProvider) + .build(); + } + + @Test + public void serializeHttpProxyWorks() { + // Serialize the HttpProxy object + String json = HttpProxySerializer.serialize(mHttpProxy); + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(json); + + // Verify the serialization result + assertNotNull("Serialized JSON should not be null", json); + + // Deserialize back to HttpProxyDto + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + + // Verify the deserialized object + assertNotNull("Deserialized DTO should not be null", dto); + assertEquals("Host should match", TEST_HOST, dto.host); + assertEquals("Port should match", TEST_PORT, dto.port); + assertEquals("Username should match", TEST_USERNAME, dto.username); + assertEquals("Password should match", TEST_PASSWORD, dto.password); + assertEquals("Client cert should match", TEST_CLIENT_CERT, dto.clientCert); + assertEquals("Client key should match", TEST_CLIENT_KEY, dto.clientKey); + assertEquals("CA cert should match", TEST_CA_CERT, dto.caCert); + assertNull("Bearer token should be null", dto.bearerToken); + } + + @Test + public void testSerializeNullHttpProxy() { + String json = HttpProxySerializer.serialize(null); + assertNull("Serializing null should return null", json); + } + + @Test + public void deserializeNullJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(null); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing null should return null", dto); + } + + @Test + public void deserializeEmptyJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn(""); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing empty string should return null", dto); + } + + @Test + public void deserializeInvalidJsonReturnsNull() { + when(mGeneralInfoStorage.getProxyConfig()).thenReturn("{ invalid json }"); + HttpProxyDto dto = HttpProxySerializer.deserialize(mGeneralInfoStorage); + assertNull("Deserializing invalid JSON should return null", dto); + } +} diff --git a/src/test/java/io/split/android/client/utils/SplitClientImplFactory.java b/src/test/java/io/split/android/client/utils/SplitClientImplFactory.java index 1c80f33c5..bc8ab7410 100644 --- a/src/test/java/io/split/android/client/utils/SplitClientImplFactory.java +++ b/src/test/java/io/split/android/client/utils/SplitClientImplFactory.java @@ -42,7 +42,7 @@ public static SplitClientImpl get(Key key, SplitsStorage splitsStorage) { TreatmentManagerFactory treatmentManagerFactory = new TreatmentManagerFactoryImpl( new KeyValidatorImpl(), new SplitValidatorImpl(), new ImpressionListener.FederatedImpressionListener(mock(DecoratedImpressionListener.class), Collections.emptyList()), false, new AttributesMergerImpl(), telemetryStorage, splitParser, - new FlagSetsFilterImpl(Collections.emptySet()), splitsStorage); + new FlagSetsFilterImpl(Collections.emptySet()), splitsStorage, null); AttributesManager attributesManager = mock(AttributesManager.class); SplitClientImpl c = new SplitClientImpl( diff --git a/src/test/java/io/split/android/client/utils/logger/LogPrinterStub.java b/src/test/java/io/split/android/client/utils/logger/LogPrinterStub.java index a63809dee..e8c53ef3d 100644 --- a/src/test/java/io/split/android/client/utils/logger/LogPrinterStub.java +++ b/src/test/java/io/split/android/client/utils/logger/LogPrinterStub.java @@ -2,43 +2,65 @@ import android.util.Log; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; public class LogPrinterStub implements LogPrinter { private final Set calls = new HashSet<>(); + private final Map> logs = new ConcurrentHashMap<>(); + + public LogPrinterStub() { + // Initialize for all Android log levels: VERBOSE(2) .. ASSERT(7) + for (int level = Log.VERBOSE; level <= Log.ASSERT; level++) { + logs.put(level, new ConcurrentLinkedDeque<>()); + } + } @Override public void v(String tag, String msg, Throwable tr) { + logs.get(Log.VERBOSE).add(msg); calls.add(Log.VERBOSE); } @Override public void d(String tag, String msg, Throwable tr) { + logs.get(Log.DEBUG).add(msg); calls.add(Log.DEBUG); } @Override public void i(String tag, String msg, Throwable tr) { + logs.get(Log.INFO).add(msg); calls.add(Log.INFO); } @Override public void w(String tag, String msg, Throwable tr) { + logs.get(Log.WARN).add(msg); calls.add(Log.WARN); } @Override public void e(String tag, String msg, Throwable tr) { + logs.get(Log.ERROR).add(msg); calls.add(Log.ERROR); } @Override public void wtf(String tag, String msg, Throwable tr) { + logs.get(Log.ASSERT).add(msg); calls.add(Log.ASSERT); } public boolean isCalled(Integer type) { return calls.contains(type); } + + public Map> getLoggedMessages() { + return new HashMap<>(logs); + } } diff --git a/src/test/java/io/split/android/client/validators/TreatmentManagerFactoryImplTest.java b/src/test/java/io/split/android/client/validators/TreatmentManagerFactoryImplTest.java new file mode 100644 index 000000000..dd2201dd4 --- /dev/null +++ b/src/test/java/io/split/android/client/validators/TreatmentManagerFactoryImplTest.java @@ -0,0 +1,46 @@ +package io.split.android.client.validators; + +import static org.junit.Assert.assertNotNull; +import static org.mockito.Mockito.mock; + +import androidx.annotation.NonNull; + +import org.junit.Test; + +import io.split.android.client.FlagSetsFilter; +import io.split.android.client.attributes.AttributesMerger; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.impressions.ImpressionListener; +import io.split.android.client.storage.splits.SplitsStorage; +import io.split.android.client.telemetry.storage.TelemetryStorage; +import io.split.android.engine.experiments.SplitParser; + +public class TreatmentManagerFactoryImplTest { + @Test + public void instantiateWithNullFallbackTreatmentsConfigDoesNotThrow() { + TreatmentManagerFactoryImpl treatmentManagerFactory = instantiate(null); + + assertNotNull(treatmentManagerFactory); + } + + @Test + public void instantiateWithNullByFactoryFallbackTreatmentsConfigDoesNotThrow() { + TreatmentManagerFactoryImpl treatmentManagerFactory = instantiate(FallbackTreatmentsConfiguration.builder().build()); + + assertNotNull(treatmentManagerFactory); + } + + @NonNull + private static TreatmentManagerFactoryImpl instantiate(FallbackTreatmentsConfiguration fallbackTreatments) { + return new TreatmentManagerFactoryImpl(mock(KeyValidator.class), + mock(SplitValidator.class), + mock(ImpressionListener.FederatedImpressionListener.class), + true, + mock(AttributesMerger.class), + mock(TelemetryStorage.class), + mock(SplitParser.class), + mock(FlagSetsFilter.class), + mock(SplitsStorage.class), + fallbackTreatments); + } +} diff --git a/src/test/java/io/split/android/client/validators/TreatmentManagerFallbackTreatmentsTest.java b/src/test/java/io/split/android/client/validators/TreatmentManagerFallbackTreatmentsTest.java new file mode 100644 index 000000000..1e7990603 --- /dev/null +++ b/src/test/java/io/split/android/client/validators/TreatmentManagerFallbackTreatmentsTest.java @@ -0,0 +1,265 @@ +package io.split.android.client.validators; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import io.split.android.client.TreatmentLabels; +import io.split.android.client.EvaluationResult; +import io.split.android.client.Evaluator; +import io.split.android.client.EvaluatorImpl; +import io.split.android.client.SplitResult; +import io.split.android.client.attributes.AttributesManager; +import io.split.android.client.attributes.AttributesMerger; +import io.split.android.client.dtos.Split; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; +import io.split.android.client.impressions.DecoratedImpression; +import io.split.android.client.impressions.Impression; +import io.split.android.client.impressions.ImpressionListener; +import io.split.android.client.telemetry.model.Method; +import io.split.android.client.telemetry.storage.TelemetryStorageProducer; +import io.split.android.client.events.ListenableEventsManager; +import io.split.android.client.events.SplitEvent; +import io.split.android.client.FlagSetsFilter; +import io.split.android.client.storage.splits.SplitsStorage; +import io.split.android.engine.experiments.SplitParser; + +public class TreatmentManagerFallbackTreatmentsTest { + + private static final String FLAG = "missing_flag"; + + @Test + public void evaluatorDefinitionNotFoundUsesFallback() { + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("FALLBACK_TREATMENT", "{\"k\":1}")) + .build(); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(cfg); + + SplitsStorage splitsStorage = Mockito.mock(SplitsStorage.class); + SplitParser splitParser = Mockito.mock(SplitParser.class); + when(splitsStorage.get(FLAG)).thenReturn(null); // definition not found + when(splitParser.parse(null, "m")).thenReturn(null); + + EvaluatorImpl evaluator = new EvaluatorImpl(splitsStorage, splitParser, calc); + + EvaluationResult res = evaluator.getTreatment("m", null, FLAG, Collections.emptyMap()); + + assertEquals("FALLBACK_TREATMENT", res.getTreatment()); + assertTrue(res.getLabel().startsWith("fallback - ")); + assertEquals("{\"k\":1}", res.getConfigurations()); + } + + @Test + public void evaluatorExceptionUsesFallback() { + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("FALLBACK_TREATMENT_2")) + .build(); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(cfg); + + SplitsStorage splitsStorage = Mockito.mock(SplitsStorage.class); + SplitParser splitParser = Mockito.mock(SplitParser.class); + Split dtoSplit = Mockito.mock(Split.class); + when(splitsStorage.get(FLAG)).thenReturn(dtoSplit); + when(splitParser.parse(Mockito.eq(dtoSplit), Mockito.eq("m"))).thenThrow(new RuntimeException("boom")); + + EvaluatorImpl evaluator = new EvaluatorImpl(splitsStorage, splitParser, calc); + + EvaluationResult res = evaluator.getTreatment("m", null, FLAG, Collections.emptyMap()); + + assertEquals("FALLBACK_TREATMENT_2", res.getTreatment()); + assertTrue(res.getLabel().startsWith("fallback - ")); + } + + @Test + public void helperControlTreatmentsPathUsesFallback() { + FallbackTreatmentsConfiguration cfg = FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("FALLBACK_HELPER", "cfg")) + .build(); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(cfg); + + SplitValidator okValidator = new SplitValidatorImpl(); + ValidationMessageLogger logger = new ValidationMessageLoggerImpl(); + + List names = Arrays.asList(" flag_a ", "flag_b"); + Map out = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig( + okValidator, + logger, + names, + "test", + TreatmentManagerImpl.ResultTransformer::identity, + calc); + + assertEquals(2, out.size()); + assertEquals("FALLBACK_HELPER", out.get("flag_a").treatment()); + assertEquals("cfg", out.get("flag_a").config()); + assertEquals("FALLBACK_HELPER", out.get("flag_b").treatment()); + } + + @Test + public void treatmentManagerGetTreatmentNullStringUsesFallback() { + String flag = "flag_for_null"; + + Evaluator evaluator = mock(Evaluator.class); + TelemetryStorageProducer telemetry = mock(TelemetryStorageProducer.class); + FallbackTreatmentsCalculator fallbackCalc = mock(FallbackTreatmentsCalculator.class); + Mocks m = Mocks.create(evaluator, telemetry, fallbackCalc); + + when(evaluator.getTreatment(eq("m"), eq("b"), eq(flag), any())) + .thenReturn(new EvaluationResult(null, "label", null, null, false)); + + when(fallbackCalc.resolve(flag)).thenReturn(new FallbackTreatment("FALLBACK_TMT")); + + String out = m.manager.getTreatment(flag, null, null, false); + + assertEquals("FALLBACK_TMT", out); + verify(fallbackCalc, times(1)).resolve(flag); + } + + @Test + public void treatmentManagerGetTreatmentExceptionRecordsTelemetryAndUsesFallback() { + String flag = "flag_for_exception"; + + Evaluator evaluator = mock(Evaluator.class); + TelemetryStorageProducer telemetry = mock(TelemetryStorageProducer.class); + FallbackTreatmentsCalculator fallbackCalc = mock(FallbackTreatmentsCalculator.class); + Mocks m = Mocks.create(evaluator, telemetry, fallbackCalc); + + when(evaluator.getTreatment(eq("m"), eq("b"), eq(flag), any())) + .thenReturn(new EvaluationResult(null, "label", null, null, false)); + + when(fallbackCalc.resolve(flag)) + .thenThrow(new RuntimeException("fail once")) + .thenReturn(new FallbackTreatment("FALLBACK_AFTER_EXCEPTION")); + + String out = m.manager.getTreatment(flag, null, null, false); + + assertEquals("FALLBACK_AFTER_EXCEPTION", out); + verify(telemetry, times(1)).recordException(Method.TREATMENT); + verify(fallbackCalc, times(2)).resolve(flag); + } + + @Test + public void treatmentManagerLabelContainsDefinitionNotFoundTriggersNotFoundPath() { + String flag = "flag_contains_def_not_found"; + + Evaluator evaluator = mock(Evaluator.class); + TelemetryStorageProducer telemetry = mock(TelemetryStorageProducer.class); + FallbackTreatmentsCalculator fallbackCalc = mock(FallbackTreatmentsCalculator.class); + Mocks m = Mocks.create(evaluator, telemetry, fallbackCalc); + + String label = "some prefix - " + TreatmentLabels.DEFINITION_NOT_FOUND + " - some suffix"; + when(evaluator.getTreatment(eq("m"), eq("b"), eq(flag), any())) + .thenReturn(new EvaluationResult("on", label, null, null, false)); + + when(m.splitValidator.splitNotFoundMessage(flag)).thenReturn("not found: " + flag); + + // Invoke getTreatmentWithConfig to go through getTreatmentWithConfigWithoutMetrics path + SplitResult result = m.manager.getTreatmentWithConfig(flag, null, null, false); + + // Ensure treatment is the one provided by evaluator and no impressions are logged + assertEquals("on", result.treatment()); + verify(m.impressions, times(0)).log(Mockito.any(DecoratedImpression.class)); + verify(m.impressions, times(0)).log(Mockito.any(Impression.class)); + + // Ensure we logged the not-found warning by requesting the message from SplitValidator + verify(m.splitValidator, times(1)).splitNotFoundMessage(flag); + } + + private static class Mocks { + final TreatmentManagerImpl manager; + final KeyValidator keyValidator; + final SplitValidator splitValidator; + final ImpressionListener.FederatedImpressionListener impressions; + final ListenableEventsManager events; + final AttributesManager attributesManager; + final AttributesMerger attributesMerger; + final FlagSetsFilter flagSetsFilter; + final SplitsStorage splitsStorage; + final SplitFilterValidator flagSetsValidator; + final PropertyValidator propertyValidator; + + private Mocks(TreatmentManagerImpl manager, + KeyValidator keyValidator, + SplitValidator splitValidator, + ImpressionListener.FederatedImpressionListener impressions, + ListenableEventsManager events, + AttributesManager attributesManager, + AttributesMerger attributesMerger, + FlagSetsFilter flagSetsFilter, + SplitsStorage splitsStorage, + SplitFilterValidator flagSetsValidator, + PropertyValidator propertyValidator) { + this.manager = manager; + this.keyValidator = keyValidator; + this.splitValidator = splitValidator; + this.impressions = impressions; + this.events = events; + this.attributesManager = attributesManager; + this.attributesMerger = attributesMerger; + this.flagSetsFilter = flagSetsFilter; + this.splitsStorage = splitsStorage; + this.flagSetsValidator = flagSetsValidator; + this.propertyValidator = propertyValidator; + } + + static Mocks create(Evaluator evaluator, + TelemetryStorageProducer telemetry, + FallbackTreatmentsCalculator fallbackCalc) { + KeyValidator keyValidator = mock(KeyValidator.class); + SplitValidator splitValidator = mock(SplitValidator.class); + ImpressionListener.FederatedImpressionListener impressions = mock(ImpressionListener.FederatedImpressionListener.class); + ListenableEventsManager events = mock(ListenableEventsManager.class); + AttributesManager attributesManager = mock(AttributesManager.class); + AttributesMerger attributesMerger = mock(AttributesMerger.class); + FlagSetsFilter flagSetsFilter = mock(FlagSetsFilter.class); + SplitsStorage splitsStorage = mock(SplitsStorage.class); + ValidationMessageLogger validationLogger = new ValidationMessageLoggerImpl(); + SplitFilterValidator flagSetsValidator = mock(SplitFilterValidator.class); + PropertyValidator propertyValidator = mock(PropertyValidator.class); + + when(events.eventAlreadyTriggered(SplitEvent.SDK_READY)).thenReturn(true); + when(attributesManager.getAllAttributes()).thenReturn(Collections.emptyMap()); + when(attributesMerger.merge(any(), any())).thenReturn(Collections.emptyMap()); + when(splitValidator.validateName(any())).thenReturn(null); + when(keyValidator.validate(any(), any())).thenReturn(null); + + TreatmentManagerImpl manager = new TreatmentManagerImpl( + "m", + "b", + evaluator, + keyValidator, + splitValidator, + impressions, + true, + events, + attributesManager, + attributesMerger, + telemetry, + flagSetsFilter, + splitsStorage, + validationLogger, + flagSetsValidator, + propertyValidator, + fallbackCalc); + + return new Mocks(manager, keyValidator, splitValidator, impressions, events, attributesManager, + attributesMerger, flagSetsFilter, splitsStorage, flagSetsValidator, propertyValidator); + } + } +} diff --git a/src/test/java/io/split/android/client/validators/TreatmentManagerHelperTest.java b/src/test/java/io/split/android/client/validators/TreatmentManagerHelperTest.java index f833f576f..16c938caf 100644 --- a/src/test/java/io/split/android/client/validators/TreatmentManagerHelperTest.java +++ b/src/test/java/io/split/android/client/validators/TreatmentManagerHelperTest.java @@ -12,6 +12,10 @@ import java.util.Map; import io.split.android.client.SplitResult; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; public class TreatmentManagerHelperTest { @@ -22,7 +26,11 @@ public void controlTreatmentsForSplitsValidatesSplitsWhenValidatorAndLoggerAreNo when(validator.validateName("split2")).thenReturn(new ValidationErrorInfo(ValidationErrorInfo.ERROR_SOME, "message")); - TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", SplitResult::treatment); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("control")) + .build()); + + TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", SplitResult::treatment, calc); verify(validator).validateName("split1"); verify(validator).validateName("split2"); @@ -36,7 +44,11 @@ public void controlTreatmentsForSplitsWithConfigValidatesSplitsWhenValidatorAndL when(validator.validateName("split2")).thenReturn(new ValidationErrorInfo(ValidationErrorInfo.ERROR_SOME, "message")); - Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", TreatmentManagerImpl.ResultTransformer::identity); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("control")) + .build()); + + Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", TreatmentManagerImpl.ResultTransformer::identity, calc); verify(validator).validateName("split1"); verify(validator).validateName("split2"); @@ -50,7 +62,11 @@ public void controlTreatmentsForSplitsWithConfigOnlyAddsValueForValidSplits() { when(validator.validateName("split2")).thenReturn(new ValidationErrorInfo(ValidationErrorInfo.ERROR_SOME, "message")); - Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", TreatmentManagerImpl.ResultTransformer::identity); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("control")) + .build()); + + Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", TreatmentManagerImpl.ResultTransformer::identity, calc); verify(validator).validateName("split1"); verify(validator).validateName("split2"); @@ -67,7 +83,11 @@ public void controlTreatmentsForSplitsOnlyAddsValuesForValidSplits() { when(validator.validateName("split2")).thenReturn(new ValidationErrorInfo(ValidationErrorInfo.ERROR_SOME, "message")); - Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", SplitResult::treatment); + FallbackTreatmentsCalculator calc = new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder() + .global(new FallbackTreatment("control")) + .build()); + + Map result = TreatmentManagerHelper.controlTreatmentsForSplitsWithConfig(validator, logger, Arrays.asList("split1", "split2"), "tag", SplitResult::treatment, calc); verify(validator).validateName("split1"); verify(validator).validateName("split2"); diff --git a/src/test/java/io/split/android/engine/splitter/HashConsistencyTest.java b/src/test/java/io/split/android/engine/splitter/HashConsistencyTest.java index 5f62ed0a7..fb0d16f5b 100644 --- a/src/test/java/io/split/android/engine/splitter/HashConsistencyTest.java +++ b/src/test/java/io/split/android/engine/splitter/HashConsistencyTest.java @@ -1,8 +1,8 @@ package io.split.android.engine.splitter; import com.google.common.hash.Hashing; + import org.junit.Assert; -import org.junit.Ignore; import org.junit.Test; import java.io.BufferedReader; @@ -13,7 +13,6 @@ import java.nio.charset.Charset; import io.split.android.client.utils.MurmurHash3; -import io.split.android.engine.splitter.Splitter; @SuppressWarnings({"UnstableApiUsage", "ConstantConditions"}) public class HashConsistencyTest { @@ -24,14 +23,6 @@ public void testLegacyHashAlphaNum() throws IOException { validateFileLegacyHash(file); } - @Test - @Ignore - public void testLegacyHashNonAlphaNum() throws IOException { - URL resource = getClass().getClassLoader().getResource("legacy-hash-sample-data-non-alpha-numeric.csv"); - File file = new File(resource.getFile()); - validateFileLegacyHash(file); - } - @Test public void testMurmur3HashAlphaNum() throws IOException { URL resource = getClass().getClassLoader().getResource("murmur3-sample-data-v2.csv"); diff --git a/src/test/java/io/split/android/engine/splitter/SplitterTest.java b/src/test/java/io/split/android/engine/splitter/SplitterTest.java index 9ab9aa090..32826214b 100644 --- a/src/test/java/io/split/android/engine/splitter/SplitterTest.java +++ b/src/test/java/io/split/android/engine/splitter/SplitterTest.java @@ -2,105 +2,31 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; - -import com.google.common.base.Joiner; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.Ignore; import org.junit.Test; +import org.mockito.Mockito; -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileReader; -import java.io.FileWriter; -import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Random; import io.split.android.client.dtos.Partition; +import io.split.android.client.fallback.FallbackTreatmentsConfiguration; +import io.split.android.client.fallback.FallbackTreatment; +import io.split.android.client.fallback.FallbackTreatmentsCalculator; +import io.split.android.client.fallback.FallbackTreatmentsCalculatorImpl; /** * Test for Splitter. */ public class SplitterTest { - @Ignore - @Test - public void generateData() { - Random r = new Random(); - int minKeyLength = 7; - - for (int j = 0; j < 100; j++) { - int seed = r.nextInt(); - for (int i = 0; i < 1000; i++) { - int keyLength = minKeyLength + r.nextInt(13); - String key = RandomStringUtils.randomAlphanumeric(keyLength); - long hash = Splitter.hash(key, seed, 1); - int bucket = Splitter.bucket(hash); - System.out.println(Joiner.on(',').join(Arrays.asList(seed, key, hash, bucket))); - } - } - - } - - @Ignore - @Test - public void generateNonAlphaNumericData() { - Random r = new Random(); - int minKeyLength = 7; - - for (int j = 0; j < 100; j++) { - int seed = r.nextInt(); - for (int i = 0; i < 1000; i++) { - int keyLength = minKeyLength + r.nextInt(13); - String key = RandomStringUtils.random(keyLength); - long hash = Splitter.hash(key, seed, 1); - int bucket = Splitter.bucket(hash); - System.out.println(Joiner.on(',').join(Arrays.asList(seed, key, hash, bucket))); - } - } - - } - - /** - * Use this utily method when algos changes are required and you need to - * generate another sample file using existing seed and key input from - * another file - * - * @throws IOException - */ - @Ignore - @Test - public void generateDataFromExistingInput() throws IOException { - File file = new File("src/test/resources", "murmur3-sample-data-non-alpha-numeric.csv"); - BufferedReader reader = new BufferedReader(new FileReader(file)); - reader.readLine(); // Header - - File target = new File("src/test/resources", "murmur3-sample-data-non-alpha-numeric-v2.csv"); - BufferedWriter writer = new BufferedWriter(new FileWriter(target)); - - // Writer header. - writer.append("# seed, key, hash, bucket"); - writer.newLine(); - - String line; - while ((line = reader.readLine()) != null) { - String[] parts = line.split(","); - Integer seed = Integer.parseInt(parts[0]); - String key = parts[1]; - long hash = Splitter.hash(key, seed, 1); - int bucket = Splitter.bucket(hash); - writer.append(Joiner.on(',').join(Arrays.asList(seed, key, hash, bucket))); - writer.newLine(); - } - writer.close(); - } - @Test public void works() { List partitions = new ArrayList<>(); @@ -115,7 +41,7 @@ public void works() { for (int i = 0; i < n; i++) { String key = RandomStringUtils.random(20); - String treatment = Splitter.getTreatment(key, 123, partitions, 1); + String treatment = Splitter.getTreatment(key, 123, partitions, 1, new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())); treatments[Integer.parseInt(treatment) - 1]++; } @@ -136,7 +62,17 @@ public void ifHundredPercentOneTreatmentWeShortcut() { List partitions = Collections.singletonList(partition); - assertThat(Splitter.getTreatment("13", 15, partitions, 1), is(equalTo("on"))); + assertThat(Splitter.getTreatment("13", 15, partitions, 1, new FallbackTreatmentsCalculatorImpl(FallbackTreatmentsConfiguration.builder().build())), is(equalTo("on"))); + } + + @Test + public void ifNoPartitionsWeReturnGetValueFromFallbackCalculator() { + FallbackTreatmentsCalculator calculator = Mockito.mock(FallbackTreatmentsCalculator.class); + + when(calculator.resolve(anyString())).thenReturn(new FallbackTreatment("on")); + + assertEquals("on", Splitter.getTreatment("13", 15, Collections.emptyList(), 1, calculator)); + verify(calculator).resolve("13"); } private Partition partition(String treatment, int size) {