Skip to content

Commit

Permalink
Add newAttachMetadataServerInterceptor() MetadataUtil (#11458)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdcormie committed Aug 14, 2024
1 parent 6a9bc3b commit 6dbd1b9
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
64 changes: 64 additions & 0 deletions stub/src/main/java/io/grpc/stub/MetadataUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -143,4 +148,63 @@ public void onClose(Status status, Metadata trailers) {
}
}
}

/**
* Returns a ServerInterceptor that adds the specified Metadata to every response stream, one way
* or another.
*
* <p>If, absent this interceptor, a stream would have headers, 'extras' will be added to those
* headers. Otherwise, 'extras' will be sent as trailers. This pattern is useful when you have
* some fixed information, server identity say, that should be included no matter how the call
* turns out. The fallback to trailers avoids artificially committing clients to error responses
* that could otherwise be retried (see https://grpc.io/docs/guides/retry/ for more).
*
* <p>For correct operation, be sure to arrange for this interceptor to run *before* any others
* that might add headers.
*
* @param extras the Metadata to be added to each stream. Caller gives up ownership.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11462")
public static ServerInterceptor newAttachMetadataServerInterceptor(Metadata extras) {
return new MetadataAttachingServerInterceptor(extras);
}

private static final class MetadataAttachingServerInterceptor implements ServerInterceptor {

private final Metadata extras;

MetadataAttachingServerInterceptor(Metadata extras) {
this.extras = extras;
}

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
return next.startCall(new MetadataAttachingServerCall<>(call), headers);
}

final class MetadataAttachingServerCall<ReqT, RespT>
extends SimpleForwardingServerCall<ReqT, RespT> {
boolean headersSent;

MetadataAttachingServerCall(ServerCall<ReqT, RespT> delegate) {
super(delegate);
}

@Override
public void sendHeaders(Metadata headers) {
headers.merge(extras);
headersSent = true;
super.sendHeaders(headers);
}

@Override
public void close(Status status, Metadata trailers) {
if (!headersSent) {
trailers.merge(extras);
}
super.close(status, trailers);
}
}
}
}
175 changes: 175 additions & 0 deletions stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.stub;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.stub.MetadataUtils.newAttachMetadataServerInterceptor;
import static io.grpc.stub.MetadataUtils.newCaptureMetadataInterceptor;
import static org.junit.Assert.fail;

import com.google.common.collect.ImmutableList;
import io.grpc.CallOptions;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptors;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.StringMarshaller;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import java.io.IOException;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class MetadataUtilsTest {

@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();

private static final String SERVER_NAME = "test";
private static final Metadata.Key<String> FOO_KEY =
Metadata.Key.of("foo-key", Metadata.ASCII_STRING_MARSHALLER);

private final MethodDescriptor<String, String> echoMethod =
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
.setFullMethodName("test/echo")
.setType(MethodDescriptor.MethodType.UNARY)
.build();

private final ServerCallHandler<String, String> echoCallHandler =
ServerCalls.asyncUnaryCall(
(req, respObserver) -> {
respObserver.onNext(req);
respObserver.onCompleted();
});

MethodDescriptor<String, String> echoServerStreamingMethod =
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
.setFullMethodName("test/echoStream")
.setType(MethodDescriptor.MethodType.SERVER_STREAMING)
.build();

private final AtomicReference<Metadata> trailersCapture = new AtomicReference<>();
private final AtomicReference<Metadata> headersCapture = new AtomicReference<>();

@Test
public void shouldAttachHeadersToResponse() throws IOException {
Metadata extras = new Metadata();
extras.put(FOO_KEY, "foo-value");

ServerServiceDefinition serviceDef =
ServerInterceptors.intercept(
ServerServiceDefinition.builder("test").addMethod(echoMethod, echoCallHandler).build(),
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));

grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());
ManagedChannel channel =
grpcCleanup.register(
newInProcessChannelBuilder()
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
.build());

String response =
ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello");
assertThat(response).isEqualTo("hello");
assertThat(trailersCapture.get() == null || !trailersCapture.get().containsKey(FOO_KEY))
.isTrue();
assertThat(headersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
}

@Test
public void shouldAttachTrailersWhenNoResponse() throws IOException {
Metadata extras = new Metadata();
extras.put(FOO_KEY, "foo-value");

ServerServiceDefinition serviceDef =
ServerInterceptors.intercept(
ServerServiceDefinition.builder("test")
.addMethod(
ServerMethodDefinition.create(
echoServerStreamingMethod,
ServerCalls.asyncUnaryCall(
(req, respObserver) -> respObserver.onCompleted())))
.build(),
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));
grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());

ManagedChannel channel =
grpcCleanup.register(
newInProcessChannelBuilder()
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
.build());

Iterator<String> response =
ClientCalls.blockingServerStreamingCall(
channel, echoServerStreamingMethod, CallOptions.DEFAULT, "hello");
assertThat(response.hasNext()).isFalse();
assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue();
assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
}

@Test
public void shouldAttachTrailersToErrorResponse() throws IOException {
Metadata extras = new Metadata();
extras.put(FOO_KEY, "foo-value");

ServerServiceDefinition serviceDef =
ServerInterceptors.intercept(
ServerServiceDefinition.builder("test")
.addMethod(
echoMethod,
ServerCalls.asyncUnaryCall(
(req, respObserver) ->
respObserver.onError(Status.INVALID_ARGUMENT.asRuntimeException())))
.build(),
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));
grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());

ManagedChannel channel =
grpcCleanup.register(
newInProcessChannelBuilder()
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
.build());
try {
ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello");
fail();
} catch (StatusRuntimeException e) {
assertThat(e.getStatus()).isNotNull();
assertThat(e.getStatus().getCode()).isEqualTo(Code.INVALID_ARGUMENT);
}
assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue();
assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
}

private static InProcessServerBuilder newInProcessServerBuilder() {
return InProcessServerBuilder.forName(SERVER_NAME).directExecutor();
}

private static InProcessChannelBuilder newInProcessChannelBuilder() {
return InProcessChannelBuilder.forName(SERVER_NAME).directExecutor();
}
}

0 comments on commit 6dbd1b9

Please sign in to comment.