diff --git a/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java b/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java index 241a851fd3..44362c3055 100644 --- a/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java +++ b/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java @@ -119,6 +119,14 @@ public void hello(HelloRequest request, StreamObserver responseOb HelloResponse response = HelloResponse.newBuilder().setGreeting(greeting).build(); responseObserver.onNext(response); + + if ("failWithRuntimeExceptionAfterData!".equals(request.getFirstName())) { + StatusRuntimeException exception = Status.RESOURCE_EXHAUSTED.withDescription("Too long firstNames?") + .asRuntimeException(); + responseObserver.onError(exception); + return; + } + responseObserver.onCompleted(); } diff --git a/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java b/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java index e5a6bbcb43..5d388918c6 100644 --- a/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java +++ b/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java @@ -34,6 +34,7 @@ import org.springframework.boot.test.web.server.LocalServerPort; import static io.grpc.Status.FAILED_PRECONDITION; +import static io.grpc.Status.RESOURCE_EXHAUSTED; import static io.grpc.netty.NegotiationType.TLS; import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment; @@ -75,15 +76,35 @@ private ManagedChannel createSecuredChannel(int port) throws SSLException { @Test public void gRPCUnaryCallShouldHandleRuntimeException() throws SSLException { ManagedChannel channel = createSecuredChannel(gatewayPort); + boolean thrown = false; try { HelloServiceGrpc.newBlockingStub(channel) .hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeException!").build()); } catch (StatusRuntimeException e) { - Assertions.assertThat(FAILED_PRECONDITION.getCode()).isEqualTo(e.getStatus().getCode()); - Assertions.assertThat("Invalid firstName").isEqualTo(e.getStatus().getDescription()); + thrown = true; + Assertions.assertThat(e.getStatus().getCode()).isEqualTo(FAILED_PRECONDITION.getCode()); + Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Invalid firstName"); } + Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue(); + } + + @Test + public void gRPCUnaryCallShouldHandleRuntimeException2() throws SSLException { + ManagedChannel channel = createSecuredChannel(gatewayPort); + boolean thrown = false; + try { + HelloServiceGrpc.newBlockingStub(channel) + .hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeExceptionAfterData!").build()) + .getGreeting(); + } + catch (StatusRuntimeException e) { + thrown = true; + Assertions.assertThat(e.getStatus().getCode()).isEqualTo(RESOURCE_EXHAUSTED.getCode()); + Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Too long firstNames?"); + } + Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue(); } private TrustManager[] createTrustAllTrustManager() { diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java index 25a45d29ae..0899907076 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java @@ -16,6 +16,7 @@ package org.springframework.cloud.gateway.filter.headers; +import reactor.netty.http.client.HttpClientResponse; import reactor.netty.http.server.HttpServerResponse; import org.springframework.core.Ordered; @@ -26,6 +27,8 @@ import org.springframework.util.StringUtils; import org.springframework.web.server.ServerWebExchange; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR; + /** * @author Alberto C. RĂ­os */ @@ -37,45 +40,62 @@ public class GRPCResponseHeadersFilter implements HttpHeadersFilter, Ordered { @Override public HttpHeaders filter(HttpHeaders headers, ServerWebExchange exchange) { - ServerHttpResponse response = exchange.getResponse(); - HttpHeaders responseHeaders = response.getHeaders(); if (isGRPC(exchange)) { - String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER; - String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER); - if (originalTrailerHeaderValue != null) { - trailerHeaderValue += "," + originalTrailerHeaderValue; - } - responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue); + ServerHttpResponse response = exchange.getResponse(); + HttpHeaders responseHeaders = response.getHeaders(); - while (response instanceof ServerHttpResponseDecorator) { - response = ((ServerHttpResponseDecorator) response).getDelegate(); + if (headers.containsKey(GRPC_STATUS_HEADER)) { + if (!"0".equals(headers.getFirst(GRPC_STATUS_HEADER))) { + response.setComplete(); // avoid empty DATA frame + } } - if (response instanceof AbstractServerHttpResponse) { - String grpcStatus = getGrpcStatus(headers); - String grpcMessage = getGrpcMessage(headers); - ((HttpServerResponse) ((AbstractServerHttpResponse) response).getNativeResponse()).trailerHeaders(h -> { - h.set(GRPC_STATUS_HEADER, grpcStatus); - h.set(GRPC_MESSAGE_HEADER, grpcMessage); + + HttpClientResponse nettyInResponse = exchange.getAttribute(CLIENT_RESPONSE_ATTR); + if (nettyInResponse != null) { + nettyInResponse.trailerHeaders().subscribe(entries -> { + if (entries.contains(GRPC_STATUS_HEADER)) { + addTrailingHeader(entries, response, responseHeaders); + } }); } - } + return headers; } - private boolean isGRPC(ServerWebExchange exchange) { - String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); - return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc"); + private void addTrailingHeader(io.netty.handler.codec.http.HttpHeaders sourceHeaders, ServerHttpResponse response, + HttpHeaders responseHeaders) { + String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER; + String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER); + if (originalTrailerHeaderValue != null) { + trailerHeaderValue += "," + originalTrailerHeaderValue; + } + responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue); + + HttpServerResponse nettyOutResponse = getNettyResponse(response); + if (nettyOutResponse != null) { + String grpcStatus = sourceHeaders.get(GRPC_STATUS_HEADER, "0"); + String grpcMessage = sourceHeaders.get(GRPC_MESSAGE_HEADER, ""); + nettyOutResponse.trailerHeaders(h -> { + h.set(GRPC_STATUS_HEADER, grpcStatus); + h.set(GRPC_MESSAGE_HEADER, grpcMessage); + }); + } } - private String getGrpcStatus(HttpHeaders headers) { - final String grpcStatusValue = headers.getFirst(GRPC_STATUS_HEADER); - return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "0"; + private HttpServerResponse getNettyResponse(ServerHttpResponse response) { + while (response instanceof ServerHttpResponseDecorator) { + response = ((ServerHttpResponseDecorator) response).getDelegate(); + } + if (response instanceof AbstractServerHttpResponse) { + return ((AbstractServerHttpResponse) response).getNativeResponse(); + } + return null; } - private String getGrpcMessage(HttpHeaders headers) { - final String grpcStatusValue = headers.getFirst(GRPC_MESSAGE_HEADER); - return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : ""; + private boolean isGRPC(ServerWebExchange exchange) { + String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE); + return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc"); } @Override