diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchange.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchange.java index b5171634e2..a45810aa39 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchange.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchange.java @@ -19,9 +19,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.List; import org.springframework.cloud.gateway.server.mvc.config.GatewayMvcProperties; import org.springframework.cloud.gateway.server.mvc.handler.ProxyExchange; +import org.springframework.http.MediaType; import org.springframework.http.client.ClientHttpResponse; import org.springframework.util.Assert; import org.springframework.util.StreamUtils; @@ -42,7 +44,7 @@ protected int copyResponseBody(ClientHttpResponse clientResponse, InputStream in int transferredBytes; - if (properties.getStreamingMediaTypes().contains(clientResponse.getHeaders().getContentType())) { + if (isStreamingMediaType(properties.getStreamingMediaTypes(), clientResponse.getHeaders().getContentType())) { transferredBytes = copyResponseBodyWithFlushing(inputStream, outputStream); } else { @@ -52,6 +54,15 @@ protected int copyResponseBody(ClientHttpResponse clientResponse, InputStream in return transferredBytes; } + private static boolean isStreamingMediaType(List streamingMediaTypes, MediaType mediaType) { + for (var streamingMediaType : streamingMediaTypes) { + if (streamingMediaType.equalsTypeAndSubtype(mediaType)) { + return true; + } + } + return false; + } + private int copyResponseBodyWithFlushing(InputStream inputStream, OutputStream outputStream) throws IOException { int readBytes; var totalReadBytes = 0; diff --git a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchangeTests.java b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchangeTests.java index fe203b1518..28863df9bb 100644 --- a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchangeTests.java +++ b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/common/AbstractProxyExchangeTests.java @@ -69,6 +69,22 @@ public void copyResponseBodyForTextEventStream() throws IOException { verify(outputStream, times(4)).flush(); } + @Test + public void copyResponseBodyForTextEventStreamWithParameter() throws IOException { + MockClientHttpResponse mockResponse = new MockClientHttpResponse(new byte[0], 200); + MediaType mediaType = MediaType.parseMediaType(MediaType.TEXT_EVENT_STREAM_VALUE + ";charset=UTF-8"); + mockResponse.getHeaders().setContentType(mediaType); + + InputStream inputStream = mock(InputStream.class); + when(inputStream.read(any())).thenReturn(1).thenReturn(1).thenReturn(1).thenReturn(-1); + OutputStream outputStream = mock(OutputStream.class); + + int result = new TestProxyExchange().copyResponseBody(mockResponse, inputStream, outputStream); + + assertThat(result).isEqualTo(3); + verify(outputStream, times(4)).flush(); + } + @Test public void copyResponseBodyWithoutContentType() throws IOException { MockClientHttpResponse mockResponse = new MockClientHttpResponse(new byte[0], 200);