diff --git a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java index f1ebea5c7..367af1b02 100644 --- a/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java +++ b/src/main/java/io/split/android/client/network/HttpResponseConnectionAdapter.java @@ -1,8 +1,11 @@ package io.split.android.client.network; import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import androidx.annotation.VisibleForTesting; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -22,13 +25,17 @@ /** * Adapter that wraps an HttpResponse as an HttpURLConnection. *

- * This is only used to adapt the response from the CONNECT method. + * This is only used to adapt the response from request through the TLS tunnel. */ class HttpResponseConnectionAdapter extends HttpsURLConnection { private final HttpResponse mResponse; private final URL mUrl; private final Certificate[] mServerCertificates; + private OutputStream mOutputStream; + private InputStream mInputStream; + private InputStream mErrorStream; + private boolean mDoOutput = false; /** * Creates an adapter that wraps an HttpResponse as an HttpURLConnection. @@ -38,12 +45,33 @@ class HttpResponseConnectionAdapter extends HttpsURLConnection { * @param serverCertificates The server certificates from the SSL connection */ HttpResponseConnectionAdapter(@NonNull URL url, - @NonNull HttpResponse response, - Certificate[] serverCertificates) { + @NonNull HttpResponse response, + Certificate[] serverCertificates) { + this(url, response, serverCertificates, new ByteArrayOutputStream()); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream) { + this(url, response, serverCertificates, outputStream, null, null); + } + + @VisibleForTesting + HttpResponseConnectionAdapter(@NonNull URL url, + @NonNull HttpResponse response, + Certificate[] serverCertificates, + @NonNull OutputStream outputStream, + @Nullable InputStream inputStream, + @Nullable InputStream errorStream) { super(url); mUrl = url; mResponse = response; mServerCertificates = serverCertificates; + mOutputStream = outputStream; + mInputStream = inputStream; + mErrorStream = errorStream; } @Override @@ -77,21 +105,27 @@ public InputStream getInputStream() throws IOException { if (mResponse.getHttpStatus() >= 400) { throw new IOException("HTTP " + mResponse.getHttpStatus()); } - String data = mResponse.getData(); - if (data == null) { - data = ""; + if (mInputStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mInputStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); } - return new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + return mInputStream; } @Override public InputStream getErrorStream() { if (mResponse.getHttpStatus() >= 400) { - String data = mResponse.getData(); - if (data == null) { - data = ""; + if (mErrorStream == null) { + String data = mResponse.getData(); + if (data == null) { + data = ""; + } + mErrorStream = new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); } - return new ByteArrayInputStream(data.getBytes(StandardCharsets.UTF_8)); + return mErrorStream; } return null; } @@ -108,6 +142,32 @@ public boolean usingProxy() { @Override public void disconnect() { + // Close output stream if it exists + try { + if (mOutputStream != null) { + mOutputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close input stream if it exists + try { + if (mInputStream != null) { + mInputStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } + + // Close error stream if it exists + try { + if (mErrorStream != null) { + mErrorStream.close(); + } + } catch (IOException e) { + // Ignore exception during disconnect + } } // Required abstract method implementations for HTTPS connection @@ -148,11 +208,12 @@ public boolean getInstanceFollowRedirects() { @Override public void setDoOutput(boolean doOutput) { + mDoOutput = doOutput; } @Override public boolean getDoOutput() { - return false; + return mDoOutput; } @Override @@ -350,7 +411,10 @@ public Permission getPermission() throws IOException { @Override public OutputStream getOutputStream() throws IOException { - throw new IOException("Output not supported for SSL proxy responses"); + if (!mDoOutput) { + throw new IOException("Output not enabled for this connection. Call setDoOutput(true) first."); + } + return mOutputStream; } @Override diff --git a/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java index 5baa35e52..0bc972444 100644 --- a/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java +++ b/src/test/java/io/split/android/client/network/HttpResponseConnectionAdapterTest.java @@ -1,6 +1,7 @@ package io.split.android.client.network; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; @@ -12,6 +13,8 @@ import org.junit.Test; import org.mockito.Mock; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.net.MalformedURLException; @@ -367,8 +370,159 @@ public void urlCanBeRetrieved() { } @Test(expected = IOException.class) - public void getOutputStreamThrows() throws IOException { + public void getOutputStreamThrowsWhenNotEnabled() throws IOException { mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + // Should throw exception since doOutput is not enabled mAdapter.getOutputStream(); } + + @Test + public void setDoOutputEnablesOutput() { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + + // Initially doOutput should be false + assertEquals(false, mAdapter.getDoOutput()); + + // After setting doOutput to true, getDoOutput should return true + mAdapter.setDoOutput(true); + assertEquals(true, mAdapter.getDoOutput()); + } + + @Test + public void getOutputStreamAfterEnablingOutput() throws IOException { + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates); + mAdapter.setDoOutput(true); + + assertNotNull("Output stream should not be null when doOutput is enabled", mAdapter.getOutputStream()); + } + + @Test + public void writeToOutputStream() throws IOException { + // Create a ByteArrayOutputStream to capture the written data + ByteArrayOutputStream testOutputStream = new ByteArrayOutputStream(); + + // Use the constructor that accepts a custom OutputStream + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Write test data to the output stream + String testData = "Test output data"; + mAdapter.getOutputStream().write(testData.getBytes(StandardCharsets.UTF_8)); + + // Verify that the data was written correctly + assertEquals("Written data should match the input", testData, testOutputStream.toString(StandardCharsets.UTF_8.name())); + } + + @Test + public void disconnectClosesOutputStream() throws IOException { + // Create a custom OutputStream that tracks if it's been closed + TestOutputStream testOutputStream = new TestOutputStream(); + + mAdapter = new HttpResponseConnectionAdapter(mTestUrl, mMockResponse, mTestCertificates, testOutputStream); + mAdapter.setDoOutput(true); + + // Get the output stream and write some data + mAdapter.getOutputStream().write("Test".getBytes(StandardCharsets.UTF_8)); + + // Verify the stream is not closed yet + assertFalse("Output stream should not be closed before disconnect", testOutputStream.isClosed()); + + // Disconnect should close the output stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Output stream should be closed after disconnect", testOutputStream.isClosed()); + } + + @Test + public void disconnectClosesInputStream() throws IOException { + // Create a custom InputStream that tracks if it's been closed + TestInputStream testInputStream = new TestInputStream("Test response data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + // Create adapter with injected test input stream + when(mMockResponse.getHttpStatus()).thenReturn(200); + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + testInputStream, + null); + + // Get the input stream and read some data to simulate usage + InputStream stream = mAdapter.getInputStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + // Verify the stream is not closed yet + assertFalse("Input stream should not be closed before disconnect", testInputStream.isClosed()); + + // Disconnect should close the input stream + mAdapter.disconnect(); + + // Verify the stream was closed + assertTrue("Input stream should be closed after disconnect", testInputStream.isClosed()); + } + + /** + * Custom OutputStream implementation for testing that tracks if it's been closed. + */ + private static class TestOutputStream extends ByteArrayOutputStream { + private boolean mClosed = false; + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + private static class TestInputStream extends ByteArrayInputStream { + private boolean mClosed = false; + + public TestInputStream(byte[] data) { + super(data); + } + + @Override + public void close() throws IOException { + super.close(); + mClosed = true; + } + + public boolean isClosed() { + return mClosed; + } + } + + @Test + public void disconnectClosesErrorStream() throws IOException { + TestInputStream testErrorStream = new TestInputStream("Error data".getBytes(StandardCharsets.UTF_8)); + TestOutputStream testOutputStream = new TestOutputStream(); + + when(mMockResponse.getHttpStatus()).thenReturn(404); // Error status + mAdapter = new HttpResponseConnectionAdapter( + mTestUrl, + mMockResponse, + mTestCertificates, + testOutputStream, + null, + testErrorStream); + + // Get the error stream and read some data to simulate usage + InputStream stream = mAdapter.getErrorStream(); + byte[] buffer = new byte[10]; + stream.read(buffer); + + assertFalse("Error stream should not be closed before disconnect", testErrorStream.isClosed()); + + mAdapter.disconnect(); + + assertTrue("Error stream should be closed after disconnect", testErrorStream.isClosed()); + } }