diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java new file mode 100644 index 00000000000..530badb631b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java @@ -0,0 +1,91 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Optional; + +/** + * Represents the outcome of an authorization check, detailing whether the request is allowed or + * denied and including any associated headers or status information. + */ +@AutoValue +public abstract class AuthzResponse { + + /** Defines the authorization decision. */ + public enum Decision { + /** The request is permitted. */ + ALLOW, + /** The request is rejected. */ + DENY, + } + + /** Creates a builder for an ALLOW response, initializing with the specified headers. */ + public static Builder allow(Metadata headers) { + return new AutoValue_AuthzResponse.Builder().setDecision(Decision.ALLOW) + .setResponseHeaderMutations(ResponseHeaderMutations.create(ImmutableList.of())) + .setHeaders(headers); + } + + /** Creates a builder for a DENY response, initializing with the specified status. */ + public static Builder deny(Status status) { + return new AutoValue_AuthzResponse.Builder().setDecision(Decision.DENY) + .setResponseHeaderMutations(ResponseHeaderMutations.create(ImmutableList.of())) + .setStatus(status); + } + + /** Returns the authorization decision. */ + public abstract Decision decision(); + + /** + * For DENY decisions, this provides the status to be returned to the calling client. It is empty + * for ALLOW decisions. + */ + public abstract Optional status(); + + /** + * For ALLOW decisions, this provides the headers to be appended to the request headers for + * upstream. It is empty for DENY decisions. + */ + public abstract Optional headers(); + + /** + * Returns mutations to be applied to the response headers. This is used for both ALLOW and DENY + * decisions. + */ + public abstract ResponseHeaderMutations responseHeaderMutations(); + + /** Builder for creating {@link AuthzResponse} instances. */ + @AutoValue.Builder + public abstract static class Builder { + + abstract Builder setDecision(Decision decision); + + abstract Builder setStatus(Status status); + + abstract Builder setHeaders(Metadata headers); + + public abstract Builder setResponseHeaderMutations( + ResponseHeaderMutations responseHeaderMutations); + + public abstract AuthzResponse build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java new file mode 100644 index 00000000000..55234cd50dc --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java @@ -0,0 +1,316 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.xds.internal.Matchers; +import java.io.UnsupportedEncodingException; +import java.net.InetSocketAddress; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; + +/** + * Interface for building external authorization check requests. + */ +public interface CheckRequestBuilder { + + /** + * A factory for creating {@link CheckRequestBuilder} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new instance of the CheckRequestBuilder. + * + * @param config The external authorization configuration. + * @param certificateProvider The provider for certificate information. + * @return A new CheckRequestBuilder instance. + */ + CheckRequestBuilder create(ExtAuthzConfig config, + ExtAuthzCertificateProvider certificateProvider); + } + + /** The default factory for creating {@link CheckRequestBuilder} instances. */ + Factory INSTANCE = CheckRequestBuilderImpl::new; + + /** + * Builds a CheckRequest for a server-side call. + * + * @param serverCall The server call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(ServerCall serverCall, Metadata headers, Timestamp requestTime); + + /** + * Builds a CheckRequest for a client-side call. + * + * @param methodDescriptor The method descriptor of the call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime); + + /** + * Implementation of the CheckRequestBuilder interface. + */ + final class CheckRequestBuilderImpl implements CheckRequestBuilder { + private static final Logger logger = Logger.getLogger(CheckRequestBuilderImpl.class.getName()); + + private static final String METHOD = "POST"; + private static final String PROTOCOL = "HTTP/2"; + private static final long SIZE = -1; + + private final ExtAuthzConfig config; + private final ExtAuthzCertificateProvider certificateProvider; + + CheckRequestBuilderImpl(ExtAuthzConfig config, + ExtAuthzCertificateProvider certificateProvider) { + this.config = config; + this.certificateProvider = certificateProvider; + } + + @Override + public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime) { + return build(CheckRequestParams.builder().setMethodDescriptor(methodDescriptor) + .setHeaders(headers).setRequestTime(requestTime).build()); + } + + @Override + public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, + Timestamp requestTime) { + CheckRequestParams.Builder paramsBuilder = + CheckRequestParams.builder().setMethodDescriptor(serverCall.getMethodDescriptor()) + .setHeaders(headers).setRequestTime(requestTime); + java.net.SocketAddress localAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + if (localAddress != null) { + paramsBuilder.setLocalAddress(localAddress); + } + java.net.SocketAddress remoteAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + if (remoteAddress != null) { + paramsBuilder.setRemoteAddress(remoteAddress); + } + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + if (sslSession != null) { + paramsBuilder.setSslSession(sslSession); + } + return build(paramsBuilder.build()); + } + + private CheckRequest build(CheckRequestParams params) { + AttributeContext.Builder attrBuilder = AttributeContext.newBuilder(); + if (params.remoteAddress().isPresent()) { + attrBuilder.setSource(buildSource(params.remoteAddress().get(), params.sslSession())); + } + if (params.localAddress().isPresent()) { + attrBuilder + .setDestination(buildDestination(params.localAddress().get(), params.sslSession())); + } + attrBuilder.setRequest(buildAttributeRequest(params.headers(), + params.methodDescriptor().getFullMethodName(), params.requestTime())); + return CheckRequest.newBuilder().setAttributes(attrBuilder).build(); + } + + private AttributeContext.Peer buildSource(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + try { + Certificate[] certs = sslSession.get().getPeerCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + X509Certificate cert = (X509Certificate) certs[0]; + peerBuilder.setPrincipal(certificateProvider.getPrincipal(cert)); + if (config.includePeerCertificate()) { + try { + peerBuilder.setCertificate(certificateProvider.getUrlPemEncodedCertificate(cert)); + } catch (UnsupportedEncodingException | CertificateEncodingException e) { + logger.log(Level.WARNING, + "Error encoding peer certificate. " + + "This is not expected, but if it happens, the certificate should not " + + "be set according to the spec.", + e); + } + } + } + } catch (SSLPeerUnverifiedException e) { + logger.log(Level.FINE, + "Peer is not authenticated. " + + "This is expected, principal and certificate should not be set " + + "according to the spec.", + e); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildDestination(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + Certificate[] certs = sslSession.get().getLocalCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + peerBuilder.setPrincipal(certificateProvider.getPrincipal((X509Certificate) certs[0])); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { + AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + peerBuilder.setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(inetSocketAddress.getAddress().getHostAddress()) + .setPortValue(inetSocketAddress.getPort())) + .build()); + } + return peerBuilder.build(); + } + + private AttributeContext.Request buildAttributeRequest(Metadata headers, String fullMethodName, + Timestamp requestTime) { + AttributeContext.Request.Builder reqBuilder = AttributeContext.Request.newBuilder(); + reqBuilder.setTime(requestTime); + AttributeContext.HttpRequest.Builder httpReqBuilder = + AttributeContext.HttpRequest.newBuilder(); + httpReqBuilder.setPath(fullMethodName); + httpReqBuilder.setMethod(METHOD); + httpReqBuilder.setProtocol(PROTOCOL); + httpReqBuilder.setSize(SIZE); + for (String key : headers.keys()) { + if (!isAllowed(key)) { + continue; + } + Optional value; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + value = getBinaryHeaderValue(headers, key); + } else { + value = getAsciiHeaderValue(headers, key); + } + value.ifPresent( + headerValue -> httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), headerValue)); + } + reqBuilder.setHttp(httpReqBuilder); + return reqBuilder.build(); + } + + private Optional getBinaryHeaderValue(Metadata headers, String key) { + Iterable binaryValues = + headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + if (binaryValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + List base64Values = new ArrayList<>(); + for (byte[] value : binaryValues) { + base64Values.add(BaseEncoding.base64().encode(value)); + } + return Optional.of(String.join(",", base64Values)); + } + + private Optional getAsciiHeaderValue(Metadata headers, String key) { + Iterable stringValues = + headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + if (stringValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + return Optional.of(String.join(",", stringValues)); + } + + private boolean isAllowed(String header) { + for (Matchers.StringMatcher matcher : config.disallowedHeaders()) { + if (matcher.matches(header)) { + return false; + } + } + if (config.allowedHeaders().isEmpty()) { + return true; + } + for (Matchers.StringMatcher matcher : config.allowedHeaders()) { + if (matcher.matches(header)) { + return true; + } + } + return false; + } + + @AutoValue + abstract static class CheckRequestParams { + abstract Metadata headers(); + + abstract MethodDescriptor methodDescriptor(); + + abstract Timestamp requestTime(); + + abstract Optional localAddress(); + + abstract Optional remoteAddress(); + + abstract Optional sslSession(); + + static Builder builder() { + Builder builder = + new AutoValue_CheckRequestBuilder_CheckRequestBuilderImpl_CheckRequestParams.Builder(); + return builder; + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setHeaders(Metadata headers); + + abstract Builder setMethodDescriptor(MethodDescriptor method); + + abstract Builder setRequestTime(Timestamp time); + + abstract Builder setLocalAddress(java.net.SocketAddress localAddress); + + abstract Builder setRemoteAddress(java.net.SocketAddress remoteAddress); + + abstract Builder setSslSession(SSLSession sslSession); + + abstract CheckRequestParams build(); + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java new file mode 100644 index 00000000000..6f03bcd1302 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java @@ -0,0 +1,148 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.service.auth.v3.DeniedHttpResponse; +import io.envoyproxy.envoy.service.auth.v3.OkHttpResponse; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.internal.headermutations.HeaderMutationDisallowedException; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; + +/** + * Handles the response from the external authorization service, processing it to determine the + * authorization decision and applying any necessary header mutations. + */ +public interface CheckResponseHandler { + + /** + * A factory for creating {@link CheckResponseHandler} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new ResponseHandler. + * + * @param headerMutator Utility to apply header mutations. + * @param headerMutationFilter Filter to apply to header mutations. + * @param config The external authorization configuration. + */ + CheckResponseHandler create(HeaderMutator headerMutator, + HeaderMutationFilter headerMutationFilter, ExtAuthzConfig config); + } + + /** + * The default factory for creating {@link CheckResponseHandler} instances. + */ + Factory INSTANCE = ResponseHandlerImpl::new; + + /** + * Processes the CheckResponse from the external authorization service. + * + * @param response The response from the authorization service. + * @param headers The request headers, which may be mutated as part of handling the response. + * @return An {@link AuthzResponse} indicating the outcome of the authorization check. + */ + AuthzResponse handleResponse(final CheckResponse response, Metadata headers); + + /** Default implementation of {@link CheckResponseHandler}. */ + static final class ResponseHandlerImpl implements CheckResponseHandler { + private final HeaderMutator headerMutator; + private final HeaderMutationFilter headerMutationFilter; + private final ExtAuthzConfig config; + + ResponseHandlerImpl(HeaderMutator headerMutator, // NOPMD + HeaderMutationFilter headerMutationFilter, ExtAuthzConfig config) { + this.headerMutator = headerMutator; + this.headerMutationFilter = headerMutationFilter; + this.config = config; + } + + @Override + public AuthzResponse handleResponse(final CheckResponse response, Metadata headers) { + try { + if (response.getStatus().getCode() == Status.Code.OK.value()) { + return handleOkResponse(response, headers); + } else { + return handleNotOkResponse(response); + } + } catch (HeaderMutationDisallowedException e) { + return AuthzResponse.deny(e.getStatus()).build(); + } + } + + private AuthzResponse handleOkResponse(final CheckResponse response, Metadata headers) + throws HeaderMutationDisallowedException { + if (!response.hasOkResponse()) { + return AuthzResponse.allow(headers).build(); + } + OkHttpResponse okResponse = response.getOkResponse(); + HeaderMutations requestedMutations = buildHeaderMutationsFromOkResponse(okResponse); + HeaderMutations allowedMutations = headerMutationFilter.filter(requestedMutations); + + applyMutations(allowedMutations, headers); + return AuthzResponse.allow(headers) + .setResponseHeaderMutations(allowedMutations.responseMutations()).build(); + } + + private HeaderMutations buildHeaderMutationsFromOkResponse(OkHttpResponse okResponse) { + return HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create( + ImmutableList.copyOf(okResponse.getHeadersList()), + ImmutableList.copyOf(okResponse.getHeadersToRemoveList())), + HeaderMutations.ResponseHeaderMutations + .create(ImmutableList.copyOf(okResponse.getResponseHeadersToAddList()))); + } + + private AuthzResponse handleNotOkResponse(CheckResponse response) + throws HeaderMutationDisallowedException { + Status statusToReturn = config.statusOnError(); + if (!response.hasDeniedResponse()) { + return AuthzResponse.deny(statusToReturn).build(); + } + DeniedHttpResponse deniedResponse = response.getDeniedResponse(); + HeaderMutations requestedMutations = buildHeaderMutationsFromDeniedResponse(deniedResponse); + HeaderMutations allowedMutations = headerMutationFilter.filter(requestedMutations); + + Status status = statusToReturn; + if (deniedResponse.hasStatus()) { + status = GrpcUtil.httpStatusToGrpcStatus(deniedResponse.getStatus().getCodeValue()) + .withDescription(deniedResponse.getBody()); + } + return AuthzResponse.deny(status) + .setResponseHeaderMutations(allowedMutations.responseMutations()).build(); + } + + private HeaderMutations buildHeaderMutationsFromDeniedResponse( + DeniedHttpResponse deniedResponse) { + return HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of()), + HeaderMutations.ResponseHeaderMutations + .create(ImmutableList.copyOf(deniedResponse.getHeadersList()))); + } + + + private void applyMutations(final HeaderMutations mutations, Metadata headers) { + headerMutator.applyRequestMutations(mutations.requestMutations(), headers); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java new file mode 100644 index 00000000000..b4ec8dd8303 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.common.io.BaseEncoding; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * An interface for providing certificate-related information. + */ +public interface ExtAuthzCertificateProvider { + /** + * Creates a new instance of the CertificateProvider. + * + * @return A new CertificateProvider instance. + */ + static ExtAuthzCertificateProvider create() { + return new DefaultCertificateProvider(); + } + + /** + * Gets the principal from a certificate. It returns the cert's first IP Address SAN if set, + * otherwise the cert's first DNS SAN if set, otherwise the subject field of the certificate in + * RFC 2253 format. + * + * @param cert The certificate. + * @return The principal. + */ + String getPrincipal(X509Certificate cert); + + /** + * Gets the URL PEM encoded certificate. It Pem encodes first and then urlencodes. + * + * @param cert The certificate. + * @return The URL PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + * @throws UnsupportedEncodingException If an error occurs while encoding the URL. + */ + String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException; + + /** + * Default implementation of the CertificateProvider interface. + */ + final class DefaultCertificateProvider implements ExtAuthzCertificateProvider { + private static final Logger logger = + Logger.getLogger(DefaultCertificateProvider.class.getName()); + // From RFC 5280, section 4.2.1.6, Subject Alternative Name + // dNSName (2) + // iPAddress (7) + private static final int SAN_TYPE_DNS_NAME = 2; + private static final int SAN_TYPE_IP_ADDRESS = 7; + + @Override + public String getPrincipal(X509Certificate cert) { + try { + Collection> sans = cert.getSubjectAlternativeNames(); + if (sans != null) { + // Look for IP Address SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_IP_ADDRESS) { + return (String) san.get(1); + } + } + // If no IP Address SAN, look for DNS SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_DNS_NAME) { + return (String) san.get(1); + } + } + } + } catch (java.security.cert.CertificateParsingException e) { + logger.log(Level.WARNING, "Error parsing certificate SANs. " + "This is not expected," + + "falling back to the subject according to the spec.", e); + } + return cert.getSubjectX500Principal().getName(); + } + + @Override + public String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException { + String pemCert = CertPemConverter.toPem(cert); + return URLEncoder.encode(pemCert, StandardCharsets.UTF_8.toString()); + } + } + + /** + * A utility class for PEM encoding. + */ + final class CertPemConverter { + + private static final String X509_PEM_HEADER = "-----BEGIN CERTIFICATE-----\n"; + private static final String X509_PEM_FOOTER = "\n-----END CERTIFICATE-----\n"; + + private CertPemConverter() {} + + /** + * Converts a certificate to a PEM string. + * + * @param cert The certificate to convert. + * @return The PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + */ + public static String toPem(X509Certificate cert) throws CertificateEncodingException { + return X509_PEM_HEADER + BaseEncoding.base64().encode(cert.getEncoded()) + X509_PEM_FOOTER; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java new file mode 100644 index 00000000000..e826f501d9c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java @@ -0,0 +1,250 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Optional; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +/** + * Represents the configuration for the external authorization (ext_authz) filter. This class + * encapsulates the settings defined in the + * {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto, providing a + * structured, immutable representation for use within gRPC. It includes configurations for the gRPC + * service used for authorization, header mutation rules, and other filter behaviors. + */ +@AutoValue +public abstract class ExtAuthzConfig { + + /** Creates a new builder for creating {@link ExtAuthzConfig} instances. */ + public static Builder builder() { + return new AutoValue_ExtAuthzConfig.Builder().allowedHeaders(ImmutableList.of()) + .disallowedHeaders(ImmutableList.of()).statusOnError(Status.PERMISSION_DENIED) + .filterEnabled(Matchers.FractionMatcher.create(100, 100)); + } + + /** + * Parses the {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto to + * create an {@link ExtAuthzConfig} instance. + * + * @param extAuthzProto The ext_authz proto to parse. + * @return An {@link ExtAuthzConfig} instance. + * @throws ExtAuthzParseException if the proto is invalid or contains unsupported features. + */ + public static ExtAuthzConfig fromProto(ExtAuthz extAuthzProto) throws ExtAuthzParseException { + if (!extAuthzProto.hasGrpcService()) { + throw new ExtAuthzParseException( + "unsupported ExtAuthz service type: only grpc_service is " + "supported"); + } + GrpcServiceConfig grpcServiceConfig; + try { + grpcServiceConfig = GrpcServiceConfig.fromProto(extAuthzProto.getGrpcService()); + } catch (GrpcServiceParseException e) { + throw new ExtAuthzParseException("Failed to parse GrpcService config: " + e.getMessage(), e); + } + Builder builder = builder().grpcService(grpcServiceConfig) + .failureModeAllow(extAuthzProto.getFailureModeAllow()) + .failureModeAllowHeaderAdd(extAuthzProto.getFailureModeAllowHeaderAdd()) + .includePeerCertificate(extAuthzProto.getIncludePeerCertificate()) + .denyAtDisable(extAuthzProto.getDenyAtDisable().getDefaultValue().getValue()); + + if (extAuthzProto.hasFilterEnabled()) { + builder.filterEnabled(parsePercent(extAuthzProto.getFilterEnabled().getDefaultValue())); + } + + if (extAuthzProto.hasStatusOnError()) { + builder.statusOnError( + GrpcUtil.httpStatusToGrpcStatus(extAuthzProto.getStatusOnError().getCodeValue())); + } + + if (extAuthzProto.hasAllowedHeaders()) { + builder.allowedHeaders(extAuthzProto.getAllowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDisallowedHeaders()) { + builder.disallowedHeaders(extAuthzProto.getDisallowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDecoderHeaderMutationRules()) { + builder.decoderHeaderMutationRules( + parseHeaderMutationRules(extAuthzProto.getDecoderHeaderMutationRules())); + } + + return builder.build(); + } + + /** + * The gRPC service configuration for the external authorization service. This is a required + * field. + * + * @see ExtAuthz#getGrpcService() + */ + public abstract GrpcServiceConfig grpcService(); + + /** + * Changes the filter's behavior on errors from the authorization service. If {@code true}, the + * filter will accept the request even if the authorization service fails or returns an error. + * + * @see ExtAuthz#getFailureModeAllow() + */ + public abstract boolean failureModeAllow(); + + /** + * Determines if the {@code x-envoy-auth-failure-mode-allowed} header is added to the request when + * {@link #failureModeAllow()} is true. + * + * @see ExtAuthz#getFailureModeAllowHeaderAdd() + */ + public abstract boolean failureModeAllowHeaderAdd(); + + /** + * Specifies if the peer certificate is sent to the external authorization service. + * + * @see ExtAuthz#getIncludePeerCertificate() + */ + public abstract boolean includePeerCertificate(); + + /** + * The gRPC status returned to the client when the authorization server returns an error or is + * unreachable. Defaults to {@code PERMISSION_DENIED}. + * + * @see io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz#getStatusOnError() + */ + public abstract Status statusOnError(); + + /** + * Specifies whether to deny requests when the filter is disabled. Defaults to {@code false}. + * + * @see ExtAuthz#getDenyAtDisable() + */ + public abstract boolean denyAtDisable(); + + /** + * The fraction of requests that will be checked by the authorization service. Defaults to all + * requests. + * + * @see ExtAuthz#getFilterEnabled() + */ + public abstract Matchers.FractionMatcher filterEnabled(); + + /** + * Specifies which request headers are sent to the authorization service. If not set, all headers + * are sent. + * + * @see ExtAuthz#getAllowedHeaders() + */ + public abstract ImmutableList allowedHeaders(); + + /** + * Specifies which request headers are not sent to the authorization service. This overrides + * {@link #allowedHeaders()}. + * + * @see ExtAuthz#getDisallowedHeaders() + */ + public abstract ImmutableList disallowedHeaders(); + + /** + * Rules for what modifications an ext_authz server may make to request headers. + * + * @see ExtAuthz#getDecoderHeaderMutationRules() + */ + public abstract Optional decoderHeaderMutationRules(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder grpcService(GrpcServiceConfig grpcService); + + public abstract Builder failureModeAllow(boolean failureModeAllow); + + public abstract Builder failureModeAllowHeaderAdd(boolean failureModeAllowHeaderAdd); + + public abstract Builder includePeerCertificate(boolean includePeerCertificate); + + public abstract Builder statusOnError(Status statusOnError); + + public abstract Builder denyAtDisable(boolean denyAtDisable); + + public abstract Builder filterEnabled(Matchers.FractionMatcher filterEnabled); + + public abstract Builder allowedHeaders(Iterable allowedHeaders); + + public abstract Builder disallowedHeaders(Iterable disallowedHeaders); + + public abstract Builder decoderHeaderMutationRules(HeaderMutationRulesConfig rules); + + public abstract ExtAuthzConfig build(); + } + + + private static Matchers.FractionMatcher parsePercent( + io.envoyproxy.envoy.type.v3.FractionalPercent proto) throws ExtAuthzParseException { + int denominator; + switch (proto.getDenominator()) { + case HUNDRED: + denominator = 100; + break; + case TEN_THOUSAND: + denominator = 10_000; + break; + case MILLION: + denominator = 1_000_000; + break; + case UNRECOGNIZED: + default: + throw new ExtAuthzParseException("Unknown denominator type: " + proto.getDenominator()); + } + return Matchers.FractionMatcher.create(proto.getNumerator(), denominator); + } + + private static HeaderMutationRulesConfig parseHeaderMutationRules(HeaderMutationRules proto) + throws ExtAuthzParseException { + HeaderMutationRulesConfig.Builder builder = HeaderMutationRulesConfig.builder(); + builder.disallowAll(proto.getDisallowAll().getValue()); + builder.disallowIsError(proto.getDisallowIsError().getValue()); + if (proto.hasAllowExpression()) { + builder.allowExpression( + parseRegex(proto.getAllowExpression().getRegex(), "allow_expression")); + } + if (proto.hasDisallowExpression()) { + builder.disallowExpression( + parseRegex(proto.getDisallowExpression().getRegex(), "disallow_expression")); + } + return builder.build(); + } + + private static Pattern parseRegex(String regex, String fieldName) throws ExtAuthzParseException { + try { + return Pattern.compile(regex); + } catch (PatternSyntaxException e) { + throw new ExtAuthzParseException( + "Invalid regex pattern for " + fieldName + ": " + e.getMessage(), e); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java new file mode 100644 index 00000000000..78edea5c305 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +/** + * A custom exception for signaling errors during the parsing of external authorization + * (ext_authz) configurations. + */ +public class ExtAuthzParseException extends Exception { + + private static final long serialVersionUID = 0L; + + public ExtAuthzParseException(String message) { + super(message); + } + + public ExtAuthzParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java new file mode 100644 index 00000000000..da9be978f87 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java @@ -0,0 +1,308 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.XdsChannelCredentials; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Optional; + + +/** + * A Java representation of the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto, + * designed for parsing and internal use within gRPC. This class encapsulates the configuration for + * a gRPC service, including target URI, credentials, and other settings. The parsing logic adheres + * to the specifications outlined in + * A102: xDS GrpcService Support. This class is immutable and uses the AutoValue library for its + * implementation. + */ +@AutoValue +public abstract class GrpcServiceConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig.Builder(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto to create a + * {@link GrpcServiceConfig} instance. This method adheres to gRFC A102, which specifies that only + * the {@code google_grpc} target specifier is supported. Other fields like {@code timeout} and + * {@code initial_metadata} are also parsed as per the gRFC. + * + * @param grpcServiceProto The proto to parse. + * @return A {@link GrpcServiceConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid or uses unsupported features. + */ + public static GrpcServiceConfig fromProto(GrpcService grpcServiceProto) + throws GrpcServiceParseException { + if (!grpcServiceProto.hasGoogleGrpc()) { + throw new GrpcServiceParseException( + "Unsupported: GrpcService must have GoogleGrpc, got: " + grpcServiceProto); + } + GoogleGrpcConfig googleGrpcConfig = + GoogleGrpcConfig.fromProto(grpcServiceProto.getGoogleGrpc()); + + Builder builder = GrpcServiceConfig.builder().googleGrpc(googleGrpcConfig); + + if (!grpcServiceProto.getInitialMetadataList().isEmpty()) { + Metadata initialMetadata = new Metadata(); + for (io.envoyproxy.envoy.config.core.v3.HeaderValue header : grpcServiceProto + .getInitialMetadataList()) { + String key = header.getKey(); + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + initialMetadata.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), + BaseEncoding.base64().decode(header.getValue())); + } else { + initialMetadata.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), + header.getValue()); + } + } + builder.initialMetadata(initialMetadata); + } + + if (grpcServiceProto.hasTimeout()) { + com.google.protobuf.Duration timeout = grpcServiceProto.getTimeout(); + builder.timeout(Duration.ofSeconds(timeout.getSeconds(), timeout.getNanos())); + } + return builder.build(); + } + + public abstract GoogleGrpcConfig googleGrpc(); + + public abstract Optional timeout(); + + public abstract Optional initialMetadata(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder googleGrpc(GoogleGrpcConfig googleGrpc); + + public abstract Builder timeout(Duration timeout); + + public abstract Builder initialMetadata(Metadata initialMetadata); + + public abstract GrpcServiceConfig build(); + } + + /** + * Represents the configuration for a Google gRPC service, as defined in the + * {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto. This class + * encapsulates settings specific to Google's gRPC implementation, such as target URI and + * credentials. The parsing of this configuration is guided by gRFC A102, which specifies how gRPC + * clients should interpret the GrpcService proto. + */ + @AutoValue + public abstract static class GoogleGrpcConfig { + + private static final String TLS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "tls.v3.TlsCredentials"; + private static final String LOCAL_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "local.v3.LocalCredentials"; + private static final String XDS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "xds.v3.XdsCredentials"; + private static final String INSECURE_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "insecure.v3.InsecureCredentials"; + private static final String GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "google_default.v3.GoogleDefaultCredentials"; + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig_GoogleGrpcConfig.Builder(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto to create + * a {@link GoogleGrpcConfig} instance. + * + * @param googleGrpcProto The proto to parse. + * @return A {@link GoogleGrpcConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid. + */ + public static GoogleGrpcConfig fromProto(GrpcService.GoogleGrpc googleGrpcProto) + throws GrpcServiceParseException { + + HashedChannelCredentials channelCreds = + extractChannelCredentials(googleGrpcProto.getChannelCredentialsPluginList()); + + CallCredentials callCreds = + extractCallCredentials(googleGrpcProto.getCallCredentialsPluginList()); + + return GoogleGrpcConfig.builder().target(googleGrpcProto.getTargetUri()) + .hashedChannelCredentials(channelCreds).callCredentials(callCreds).build(); + } + + public abstract String target(); + + public abstract HashedChannelCredentials hashedChannelCredentials(); + + public abstract CallCredentials callCredentials(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder target(String target); + + public abstract Builder hashedChannelCredentials(HashedChannelCredentials channelCredentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract GoogleGrpcConfig build(); + } + + private static T getFirstSupported(List configs, Parser parser, + String configName) throws GrpcServiceParseException { + List errors = new ArrayList<>(); + for (U config : configs) { + try { + return parser.parse(config); + } catch (GrpcServiceParseException e) { + errors.add(e.getMessage()); + } + } + throw new GrpcServiceParseException( + "No valid supported " + configName + " found. Errors: " + errors); + } + + private static HashedChannelCredentials channelCredsFromProto(Any cred) + throws GrpcServiceParseException { + String typeUrl = cred.getTypeUrl(); + try { + switch (typeUrl) { + case GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL: + return HashedChannelCredentials.of(GoogleDefaultChannelCredentials.create(), + cred.hashCode()); + case INSECURE_CREDENTIALS_TYPE_URL: + return HashedChannelCredentials.of(InsecureChannelCredentials.create(), + cred.hashCode()); + case XDS_CREDENTIALS_TYPE_URL: + XdsCredentials xdsConfig = cred.unpack(XdsCredentials.class); + HashedChannelCredentials fallbackCreds = + channelCredsFromProto(xdsConfig.getFallbackCredentials()); + return HashedChannelCredentials.of( + XdsChannelCredentials.create(fallbackCreds.channelCredentials()), cred.hashCode()); + case LOCAL_CREDENTIALS_TYPE_URL: + // TODO(sauravzg) : What's the java alternative to LocalCredentials. + throw new GrpcServiceParseException("LocalCredentials are not yet supported."); + case TLS_CREDENTIALS_TYPE_URL: + // TODO(sauravzg) : How to instantiate a TlsChannelCredentials from TlsCredentials + // proto? + throw new GrpcServiceParseException("TlsCredentials are not yet supported."); + default: + throw new GrpcServiceParseException("Unsupported channel credentials type: " + typeUrl); + } + } catch (InvalidProtocolBufferException e) { + // TODO(sauravzg): Add unit tests when we have a solution for TLS creds. + // This code is as of writing unreachable because all channel credential message + // types except TLS are empty messages. + throw new GrpcServiceParseException( + "Failed to parse channel credentials: " + e.getMessage()); + } + } + + private static CallCredentials callCredsFromProto(Any cred) throws GrpcServiceParseException { + try { + AccessTokenCredentials accessToken = cred.unpack(AccessTokenCredentials.class); + // TODO(sauravzg): Verify if the current behavior is per spec.The `AccessTokenCredentials` + // config doesn't have any timeout/refresh, so set the token to never expire. + return MoreCallCredentials.from(OAuth2Credentials + .create(new AccessToken(accessToken.getToken(), new Date(Long.MAX_VALUE)))); + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException( + "Unsupported call credentials type: " + cred.getTypeUrl()); + } + } + + private static HashedChannelCredentials extractChannelCredentials( + List channelCredentialPlugins) throws GrpcServiceParseException { + return getFirstSupported(channelCredentialPlugins, GoogleGrpcConfig::channelCredsFromProto, + "channel_credentials"); + } + + private static CallCredentials extractCallCredentials(List callCredentialPlugins) + throws GrpcServiceParseException { + return getFirstSupported(callCredentialPlugins, GoogleGrpcConfig::callCredsFromProto, + "call_credentials"); + } + } + + /** + * A container for {@link ChannelCredentials} and a hash for the purpose of caching. + */ + @AutoValue + public abstract static class HashedChannelCredentials { + /** + * Creates a new {@link HashedChannelCredentials} instance. + * + * @param creds The channel credentials. + * @param hash The hash of the credentials. + * @return A new {@link HashedChannelCredentials} instance. + */ + public static HashedChannelCredentials of(ChannelCredentials creds, int hash) { + return new AutoValue_GrpcServiceConfig_HashedChannelCredentials(creds, hash); + } + + /** + * Returns the channel credentials. + */ + public abstract ChannelCredentials channelCredentials(); + + /** + * Returns the hash of the credentials. + */ + public abstract int hash(); + } + + /** + * Defines a generic interface for parsing a configuration of type {@code U} into a result of type + * {@code T}. This functional interface is used to abstract the parsing logic for different parts + * of the GrpcService configuration. + * + * @param The type of the object that will be returned after parsing. + * @param The type of the configuration object that will be parsed. + */ + private interface Parser { + + /** + * Parses the given configuration. + * + * @param config The configuration object to parse. + * @return The parsed object of type {@code T}. + * @throws GrpcServiceParseException if an error occurs during parsing. + */ + T parse(U config) throws GrpcServiceParseException; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java new file mode 100644 index 00000000000..0d02989eaa3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java @@ -0,0 +1,26 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import io.grpc.ManagedChannel; + +/** + * A factory for creating {@link ManagedChannel}s from a {@link GrpcServiceConfig}. + */ +public interface GrpcServiceConfigChannelFactory { + ManagedChannel createChannel(GrpcServiceConfig config); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java new file mode 100644 index 00000000000..319ad3d07e3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +/** + * Exception thrown when there is an error parsing the gRPC service config. + */ +public class GrpcServiceParseException extends Exception { + + private static final long serialVersionUID = 1L; + + public GrpcServiceParseException(String message) { + super(message); + } + + public GrpcServiceParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java new file mode 100644 index 00000000000..d6325d43be4 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import io.grpc.Grpc; +import io.grpc.ManagedChannel; + +/** + * An insecure implementation of {@link GrpcServiceConfigChannelFactory} that creates a plaintext + * channel. This is a stub implementation for channel creation until the GrpcService trusted server + * implementation is completely implemented. + */ +public final class InsecureGrpcChannelFactory implements GrpcServiceConfigChannelFactory { + + private static final InsecureGrpcChannelFactory INSTANCE = new InsecureGrpcChannelFactory(); + + private InsecureGrpcChannelFactory() {} + + public static InsecureGrpcChannelFactory getInstance() { + return INSTANCE; + } + + @Override + public ManagedChannel createChannel(GrpcServiceConfig config) { + GrpcServiceConfig.GoogleGrpcConfig googleGrpc = config.googleGrpc(); + return Grpc.newChannelBuilder(googleGrpc.target(), + googleGrpc.hashedChannelCredentials().channelCredentials()).build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java new file mode 100644 index 00000000000..b8d4eb582fb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import io.grpc.Status; +import io.grpc.StatusException; + +/** + * Exception thrown when a header mutation is disallowed. + */ +public final class HeaderMutationDisallowedException extends StatusException { + + private static final long serialVersionUID = 1L; + + public HeaderMutationDisallowedException(String message) { + super(Status.INTERNAL.withDescription(message)); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java new file mode 100644 index 00000000000..0452354d823 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Collection; +import java.util.Locale; +import java.util.Optional; +import java.util.function.Predicate; + +/** + * The HeaderMutationFilter class is responsible for filtering header mutations based on a given set + * of rules. + */ +public interface HeaderMutationFilter { + + /** + * A factory for creating {@link HeaderMutationFilter} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new instance of {@code HeaderMutationFilter}. + * + * @param mutationRules The rules for header mutations. If an empty {@code Optional} is + * provided, all header mutations are allowed by default, except for certain system + * headers. If a {@link HeaderMutationRulesConfig} is provided, mutations will be + * filtered based on the specified rules. + */ + HeaderMutationFilter create(Optional mutationRules); + } + + /** + * The default factory for creating {@link HeaderMutationFilter} instances. + */ + Factory INSTANCE = HeaderMutationFilterImpl::new; + + /** + * Filters the given header mutations based on the configured rules and returns the allowed + * mutations. + * + * @param mutations The header mutations to filter + * @return The allowed header mutations. + * @throws HeaderMutationDisallowedException if a disallowed mutation is encountered and the rules + * specify that this should be an error. + */ + HeaderMutations filter(HeaderMutations mutations) throws HeaderMutationDisallowedException; + + /** Default implementation of {@link HeaderMutationFilter}. */ + final class HeaderMutationFilterImpl implements HeaderMutationFilter { + private final Optional mutationRules; + + /** + * Set of HTTP/2 pseudo-headers and the host header that are critical for routing and protocol + * correctness. These headers cannot be mutated by user configuration. + */ + private static final ImmutableSet IMMUTABLE_HEADERS = + ImmutableSet.of("host", ":authority", ":scheme", ":method"); + + private HeaderMutationFilterImpl(Optional mutationRules) { // NOPMD + this.mutationRules = mutationRules; + } + + @Override + public HeaderMutations filter(HeaderMutations mutations) + throws HeaderMutationDisallowedException { + ImmutableList allowedRequestHeaders = + filterCollection(mutations.requestMutations().headers(), + header -> isHeaderMutationAllowed(header.getHeader().getKey()) + && !appendsSystemHeader(header)); + ImmutableList allowedRequestHeadersToRemove = + filterCollection(mutations.requestMutations().headersToRemove(), + header -> isHeaderMutationAllowed(header) && isHeaderRemovalAllowed(header)); + ImmutableList allowedResponseHeaders = + filterCollection(mutations.responseMutations().headers(), + header -> isHeaderMutationAllowed(header.getHeader().getKey()) + && !appendsSystemHeader(header)); + return HeaderMutations.create( + RequestHeaderMutations.create(allowedRequestHeaders, allowedRequestHeadersToRemove), + ResponseHeaderMutations.create(allowedResponseHeaders)); + } + + /** + * A generic helper to filter a collection based on a predicate. + * + * @param items The collection of items to filter. + * @param isAllowedPredicate The predicate to apply to each item. + * @param The type of items in the collection. + * @return An immutable list of allowed items. + * @throws HeaderMutationDisallowedException if an item is disallowed and disallowIsError is + * true. + */ + private ImmutableList filterCollection(Collection items, + Predicate isAllowedPredicate) throws HeaderMutationDisallowedException { + ImmutableList.Builder allowed = ImmutableList.builder(); + for (T item : items) { + if (isAllowedPredicate.test(item)) { + allowed.add(item); + } else if (disallowIsError()) { + throw new HeaderMutationDisallowedException( + "Header mutation disallowed for header: " + item); + } + } + return allowed.build(); + } + + private boolean isHeaderRemovalAllowed(String headerKey) { + return !isSystemHeaderKey(headerKey); + } + + private boolean appendsSystemHeader(HeaderValueOption headerValueOption) { + String key = headerValueOption.getHeader().getKey(); + boolean isAppend = headerValueOption + .getAppendAction() == HeaderValueOption.HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD; + return isAppend && isSystemHeaderKey(key); + } + + private boolean isSystemHeaderKey(String key) { + return key.startsWith(":") || key.toLowerCase(Locale.ROOT).equals("host"); + } + + private boolean isHeaderMutationAllowed(String headerName) { + String lowerCaseHeaderName = headerName.toLowerCase(Locale.ROOT); + if (IMMUTABLE_HEADERS.contains(lowerCaseHeaderName)) { + return false; + } + return mutationRules.map(rules -> isHeaderMutationAllowed(lowerCaseHeaderName, rules)) + .orElse(true); + } + + private boolean isHeaderMutationAllowed(String lowerCaseHeaderName, + HeaderMutationRulesConfig rules) { + // TODO(sauravzg): The priority is slightly unclear in the spec. + // Both `disallowAll` and `disallow_expression` take precedence over `all other + // settings`. + // `allow_expression` takes precedence over everything except `disallow_expression`. + // This is a conflict between ordering for `allow_expression` and `disallowAll`. + // Choosing to proceed with current envoy implementation which favors `allow_expression` over + // `disallowAll`. + if (rules.disallowExpression().isPresent() + && rules.disallowExpression().get().matcher(lowerCaseHeaderName).matches()) { + return false; + } + if (rules.allowExpression().isPresent()) { + return rules.allowExpression().get().matcher(lowerCaseHeaderName).matches(); + } + return !rules.disallowAll(); + } + + private boolean disallowIsError() { + return mutationRules.map(HeaderMutationRulesConfig::disallowIsError).orElse(false); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java new file mode 100644 index 00000000000..fd8048fdbd2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * Represents the configuration for header mutation rules, as defined in the + * {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules} proto. + */ +@AutoValue +public abstract class HeaderMutationRulesConfig { + /** Creates a new builder for creating {@link HeaderMutationRulesConfig} instances. */ + public static Builder builder() { + return new AutoValue_HeaderMutationRulesConfig.Builder().disallowAll(false) + .disallowIsError(false); + } + + /** + * If set, allows any header that matches this regular expression. + * + * @see HeaderMutationRules#getAllowExpression() + */ + public abstract Optional allowExpression(); + + /** + * If set, disallows any header that matches this regular expression. + * + * @see HeaderMutationRules#getDisallowExpression() + */ + public abstract Optional disallowExpression(); + + /** + * If true, disallows all header mutations. + * + * @see HeaderMutationRules#getDisallowAll() + */ + public abstract boolean disallowAll(); + + /** + * If true, disallows any header mutation that would result in an invalid header value. + * + * @see HeaderMutationRules#getDisallowIsError() + */ + public abstract boolean disallowIsError(); + + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder allowExpression(Pattern matcher); + + public abstract Builder disallowExpression(Pattern matcher); + + public abstract Builder disallowAll(boolean disallowAll); + + public abstract Builder disallowIsError(boolean disallowIsError); + + public abstract HeaderMutationRulesConfig build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java new file mode 100644 index 00000000000..e0cb3daede3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; + +/** A collection of header mutations for both request and response headers. */ +@AutoValue +public abstract class HeaderMutations { + + public static HeaderMutations create(RequestHeaderMutations requestMutations, + ResponseHeaderMutations responseMutations) { + return new AutoValue_HeaderMutations(requestMutations, responseMutations); + } + + public abstract RequestHeaderMutations requestMutations(); + + public abstract ResponseHeaderMutations responseMutations(); + + /** Represents mutations for request headers. */ + @AutoValue + public abstract static class RequestHeaderMutations { + public static RequestHeaderMutations create(ImmutableList headers, + ImmutableList headersToRemove) { + return new AutoValue_HeaderMutations_RequestHeaderMutations(headers, headersToRemove); + } + + public abstract ImmutableList headers(); + + public abstract ImmutableList headersToRemove(); + } + + /** Represents mutations for response headers. */ + @AutoValue + public abstract static class ResponseHeaderMutations { + public static ResponseHeaderMutations create(ImmutableList headers) { + return new AutoValue_HeaderMutations_ResponseHeaderMutations(headers); + } + + public abstract ImmutableList headers(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java new file mode 100644 index 00000000000..de5b946bbc7 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.io.BaseEncoding; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.Metadata; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.nio.charset.StandardCharsets; +import java.util.logging.Logger; + +/** + * The HeaderMutator class is an implementation of the HeaderMutator interface. It provides methods + * to apply header mutations to a given set of headers based on a given set of rules. + */ +public interface HeaderMutator { + /** + * Creates a new instance of {@code HeaderMutator}. + */ + static HeaderMutator create() { + return new HeaderMutatorImpl(); + } + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + void applyRequestMutations(RequestHeaderMutations mutations, Metadata headers); + + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + void applyResponseMutations(ResponseHeaderMutations mutations, Metadata headers); + + /** Default implementation of {@link HeaderMutator}. */ + final class HeaderMutatorImpl implements HeaderMutator { + + private static final Logger logger = Logger.getLogger(HeaderMutatorImpl.class.getName()); + + @Override + public void applyRequestMutations(final RequestHeaderMutations mutations, Metadata headers) { + // TODO(sauravzg): The specification is not clear on order of header removals and additions. + // in case of conflicts. Copying the order from Envoy here, which does removals at the end. + applyHeaderUpdates(mutations.headers(), headers); + for (String headerToRemove : mutations.headersToRemove()) { + headers.discardAll(Metadata.Key.of(headerToRemove, Metadata.ASCII_STRING_MARSHALLER)); + } + } + + @Override + public void applyResponseMutations(final ResponseHeaderMutations mutations, Metadata headers) { + applyHeaderUpdates(mutations.headers(), headers); + } + + private void applyHeaderUpdates(final Iterable headerOptions, + Metadata headers) { + for (HeaderValueOption headerOption : headerOptions) { + HeaderValue headerValue = headerOption.getHeader(); + updateHeader(headerValue, headerOption.getAppendAction(), headers); + } + } + + private void updateHeader(final HeaderValue header, final HeaderAppendAction action, + Metadata mutableHeaders) { + if (header.getKey().endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + updateHeader(action, Metadata.Key.of(header.getKey(), Metadata.BINARY_BYTE_MARSHALLER), + getBinaryHeaderValue(header), mutableHeaders); + } else { + updateHeader(action, Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER), + getAsciiValue(header), mutableHeaders); + } + } + + private void updateHeader(final HeaderAppendAction action, final Metadata.Key key, + final T value, Metadata mutableHeaders) { + switch (action) { + case APPEND_IF_EXISTS_OR_ADD: + mutableHeaders.put(key, value); + break; + case ADD_IF_ABSENT: + if (!mutableHeaders.containsKey(key)) { + mutableHeaders.put(key, value); + } + break; + case OVERWRITE_IF_EXISTS_OR_ADD: + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + break; + case OVERWRITE_IF_EXISTS: + if (mutableHeaders.containsKey(key)) { + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + } + break; + case UNRECOGNIZED: + // Ignore invalid value + logger.warning("Unrecognized HeaderAppendAction: " + action); + break; + default: + // Should be unreachable unless there's a proto schema mismatch. + logger.warning("Unknown HeaderAppendAction: " + action); + } + } + + private byte[] getBinaryHeaderValue(HeaderValue header) { + return BaseEncoding.base64().decode(getAsciiValue(header)); + } + + private String getAsciiValue(HeaderValue header) { + // TODO(sauravzg): GRPC only supports base64 encoded binary headers, so we decode bytes to + // String using `StandardCharsets.US_ASCII`. + // Envoy's spec `raw_value` specification can contain non UTF-8 bytes, so this may potentially + // cause an exception or corruption. + if (!header.getRawValue().isEmpty()) { + return header.getRawValue().toString(StandardCharsets.US_ASCII); + } + return header.getValue(); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java new file mode 100644 index 00000000000..e81e356fe75 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.extauthz.AuthzResponse.Decision; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AuthzResponseTest { + @Test + public void testAllow() { + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + AuthzResponse response = AuthzResponse.allow(headers).build(); + assertThat(response.decision()).isEqualTo(Decision.ALLOW); + assertThat(response.headers()).hasValue(headers); + assertThat(response.status()).isEmpty(); + assertThat(response.responseHeaderMutations().headers()).isEmpty(); + } + + @Test + public void testAllowWithHeaderMutations() { + Metadata headers = new Metadata(); + ResponseHeaderMutations mutations = + ResponseHeaderMutations.create(ImmutableList.of(HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("key").setValue("value")).build())); + AuthzResponse response = + AuthzResponse.allow(headers).setResponseHeaderMutations(mutations).build(); + assertThat(response.decision()).isEqualTo(Decision.ALLOW); + assertThat(response.responseHeaderMutations()).isEqualTo(mutations); + } + + @Test + public void testDeny() { + Status status = Status.PERMISSION_DENIED.withDescription("reason"); + AuthzResponse response = AuthzResponse.deny(status).build(); + assertThat(response.decision()).isEqualTo(Decision.DENY); + assertThat(response.status()).hasValue(status); + assertThat(response.headers()).isEmpty(); + assertThat(response.responseHeaderMutations().headers()).isEmpty(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java new file mode 100644 index 00000000000..1faa0062a04 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java @@ -0,0 +1,350 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.testing.TestMethodDescriptors; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class CheckRequestBuilderTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private ServerCall serverCall; + @Mock + private SSLSession sslSession; + @Mock + private ExtAuthzCertificateProvider certificateProvider; + + private CheckRequestBuilder checkRequestBuilder; + private MethodDescriptor methodDescriptor; + private Timestamp requestTime; + + @Before + public void setUp() throws ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + methodDescriptor = TestMethodDescriptors.voidMethod(); + requestTime = Timestamp.newBuilder().setSeconds(12345).setNanos(67890).build(); + } + + @Test + public void buildRequest_forServer_happyPath() throws Exception { + // Setup for addresses + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + + // Setup for SSL and certificates + X509Certificate peerCert = mock(X509Certificate.class); + X509Certificate localCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + Certificate[] localCerts = new Certificate[] {localCert}; + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(sslSession.getLocalCertificates()).thenReturn(localCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + when(certificateProvider.getPrincipal(localCert)).thenReturn("local-principal"); + when(certificateProvider.getUrlPemEncodedCertificate(peerCert)).thenReturn("encoded-peer-cert"); + + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + headers.put(Metadata.Key.of("overridden-header", Metadata.ASCII_STRING_MARSHALLER), "v3"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder to allow specific headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + + // Setup server call attributes + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) + .isEqualTo("192.168.1.1"); + assertThat(attrContext.getSource().getPrincipal()).isEqualTo("peer-principal"); + assertThat(attrContext.getSource().getCertificate()).isEqualTo("encoded-peer-cert"); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) + .isEqualTo("10.0.0.2"); + assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); + assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); + } + + @Test + public void buildRequest_forServer_noTransportAttrs() { + when(serverCall.getAttributes()).thenReturn(Attributes.EMPTY); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Metadata headers = new Metadata(); + + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + assertThat(request.getAttributes().getRequest().getTime()).isEqualTo(requestTime); + assertThat(request.getAttributes().getRequest().getHttp().getPath()) + .isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); + assertThat(request.getAttributes().getRequest().getHttp().getProtocol()).isEqualTo("HTTP/2"); + assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); + assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); + assertThat(request.getAttributes().hasSource()).isFalse(); + assertThat(request.getAttributes().hasDestination()).isFalse(); + } + + + @Test + public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Exception { + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("some-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder with empty allowed headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(methodDescriptor, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isFalse(); + assertThat(attrContext.hasDestination()).isFalse(); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getPath()).isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); + assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + } + + @Test + public void buildRequest_forServer_noSslSession() { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isTrue(); + Address sourceAddress = attrContext.getSource().getAddress(); + assertThat(sourceAddress.getSocketAddress().getAddress()).isEqualTo("192.168.1.1"); + assertThat(sourceAddress.getSocketAddress().getPortValue()).isEqualTo(12345); + assertThat(attrContext.getSource().getPrincipal()).isEmpty(); + + assertThat(attrContext.hasDestination()).isTrue(); + Address destAddress = attrContext.getDestination().getAddress(); + assertThat(destAddress.getSocketAddress().getAddress()).isEqualTo("10.0.0.2"); + assertThat(destAddress.getSocketAddress().getPortValue()).isEqualTo(443); + assertThat(attrContext.getDestination().getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_sslPeerUnverified() throws Exception { + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + when(sslSession.getPeerCertificates()).thenThrow(new SSLPeerUnverifiedException("unverified")); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_includePeerCertFalse() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), false); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + X509Certificate peerCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEqualTo("peer-principal"); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nullOrEmptyCertificates() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Test with null certificates + when(sslSession.getPeerCertificates()).thenReturn(null); + when(sslSession.getLocalCertificates()).thenReturn(null); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + + // Test with empty certificates + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[0]); + when(sslSession.getLocalCertificates()).thenReturn(new Certificate[0]); + request = checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonX509Certificate() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Certificate nonX509Cert = mock(Certificate.class); + Certificate[] certs = new Certificate[] {nonX509Cert}; + + when(sslSession.getPeerCertificates()).thenReturn(certs); + when(sslSession.getLocalCertificates()).thenReturn(certs); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonInetSocketAddress() { + SocketAddress remoteAddress = mock(SocketAddress.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build()); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), true); + } + + private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, + ListStringMatcher disallowed, boolean includePeerCertificate) throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz.Builder builder = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) + .setDisallowedHeaders(disallowed); + return ExtAuthzConfig.fromProto(builder.build()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java new file mode 100644 index 00000000000..31b14a312c4 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.rpc.Code; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.service.auth.v3.DeniedHttpResponse; +import io.envoyproxy.envoy.service.auth.v3.OkHttpResponse; +import io.envoyproxy.envoy.type.v3.HttpStatus; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.extauthz.AuthzResponse.Decision; +import io.grpc.xds.internal.headermutations.HeaderMutationDisallowedException; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class CheckResponseHandlerTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock + private HeaderMutator headerMutator; + @Mock + private HeaderMutationFilter headerMutationFilter; + + private CheckResponseHandler responseHandler; + + @Before + public void setUp() throws Exception { + responseHandler = + CheckResponseHandler.INSTANCE.create(headerMutator, headerMutationFilter, + buildExtAuthzConfig()); + when(headerMutationFilter.filter(any(HeaderMutations.class))) + .thenAnswer(invocation -> invocation.getArgument(0)); + } + + @Test + public void handleResponse_ok() { + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.ALLOW); + assertThat(authzResponse.headers()).hasValue(headers); + } + + @Test + public void handleResponse_okWithMutations() { + HeaderValueOption option = HeaderValueOption.newBuilder().build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()) + .setOkResponse(OkHttpResponse.newBuilder().addHeaders(option) + .addHeadersToRemove("remove-key").addResponseHeadersToAdd(option).build()) + .build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.ALLOW); + assertThat(authzResponse.headers()).hasValue(headers); + HeaderMutations expectedMutations = HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create(ImmutableList.of(option), + ImmutableList.of("remove-key")), + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of(option))); + verify(headerMutator).applyRequestMutations(expectedMutations.requestMutations(), headers); + assertThat(authzResponse.responseHeaderMutations()) + .isEqualTo(expectedMutations.responseMutations()); + } + + @Test + public void handleResponse_notOk() { + CheckResponse checkResponse = CheckResponse.newBuilder().setStatus(com.google.rpc.Status + .newBuilder().setCode(Code.PERMISSION_DENIED_VALUE).setMessage("denied").build()).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().isPresent()).isTrue(); + assertThat(authzResponse.status().get().getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("HTTP status code 403"); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_deniedResponseWithoutStatusOverride() { + HeaderValueOption option = HeaderValueOption.newBuilder().build(); + DeniedHttpResponse deniedHttpResponse = + DeniedHttpResponse.newBuilder().addHeaders(option).build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.ABORTED_VALUE).build()) + .setDeniedResponse(deniedHttpResponse).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().get().getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("HTTP status code 403"); + HeaderMutations.ResponseHeaderMutations expectedMutations = + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of(option)); + assertThat(authzResponse.responseHeaderMutations()).isEqualTo(expectedMutations); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_deniedResponseWithStatusOverride() { + DeniedHttpResponse deniedHttpResponse = + DeniedHttpResponse.newBuilder().setStatus(HttpStatus.newBuilder().setCodeValue(401).build()) + .setBody("custom body").build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.ABORTED_VALUE).build()) + .setDeniedResponse(deniedHttpResponse).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().isPresent()).isTrue(); + Status status = authzResponse.status().get(); + assertThat(status.getCode()).isEqualTo(Status.Code.UNAUTHENTICATED); + assertThat(status.getDescription()).isEqualTo("custom body"); + HeaderMutations.ResponseHeaderMutations expectedMutations = + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of()); + assertThat(authzResponse.responseHeaderMutations()).isEqualTo(expectedMutations); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_okWithDisallowedMutation() throws HeaderMutationDisallowedException { + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()) + .setOkResponse(OkHttpResponse.newBuilder().build()).build(); + Metadata headers = new Metadata(); + HeaderMutationDisallowedException exception = + new HeaderMutationDisallowedException("disallowed"); + when(headerMutationFilter.filter(any(HeaderMutations.class))).thenThrow(exception); + + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().get().getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("disallowed"); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .build(); + return ExtAuthzConfig.fromProto(extAuthz); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java new file mode 100644 index 00000000000..fdeff595d56 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import javax.security.auth.x500.X500Principal; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + + +@RunWith(JUnit4.class) +public class ExtAuthzCertificateProviderTest { + private final ExtAuthzCertificateProvider provider = ExtAuthzCertificateProvider.create(); + + @Test + public void getPrincipal_ipAddressSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Arrays.asList(7, "192.168.1.1"); // SAN_TYPE_IP_ADDRESS + Collection> sans = Arrays.asList(ipSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("192.168.1.1"); + } + + @Test + public void getPrincipal_dnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List san = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Collections.singletonList(san); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_noSan_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Collections.emptyList()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_nullSans_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(null); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_ipSanWrongSize_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Collections.singletonList(7); // SAN_TYPE_IP_ADDRESS, wrong size + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_ipSanWrongType_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // SAN_TYPE_IP_ADDRESS, wrong type + List ipSan = Arrays.asList("not-an-integer", "192.168.1.1"); + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_dnsSanWrongType_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // Wrong SAN type for DNS check + List otherSan = Arrays.asList(6, "foo.test.google.fr"); // SAN_TYPE_URI + Collection> sans = Collections.singletonList(otherSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + when(mockCert.getSubjectX500Principal()).thenReturn(new X500Principal("CN=test")); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=test"); + } + + @Test + public void getPrincipal_sanParsingException_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenThrow(new CertificateParsingException()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getUrlPemEncodedCertificate() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + byte[] certData = "cert-data".getBytes(StandardCharsets.UTF_8); + when(mockCert.getEncoded()).thenReturn(certData); + + String pem = "-----BEGIN CERTIFICATE-----\n" + "Y2VydC1kYXRh" // base64 of "cert-data" + + "\n-----END CERTIFICATE-----\n"; + String urlEncodedPem = URLEncoder.encode(pem, StandardCharsets.UTF_8.toString()); + assertThat(provider.getUrlPemEncodedCertificate(mockCert)).isEqualTo(urlEncodedPem); + } + + @Test + public void getUrlPemEncodedCertificate_encodingException() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getEncoded()).thenThrow(new CertificateEncodingException("test")); + assertThrows(CertificateEncodingException.class, + () -> provider.getUrlPemEncodedCertificate(mockCert)); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java new file mode 100644 index 00000000000..9b9a55b4079 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.grpc.Status; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExtAuthzConfigTest { + + private static final Any GOOGLE_DEFAULT_CHANNEL_CREDS = + Any.pack(GoogleDefaultCredentials.newBuilder().build()); + private static final Any FAKE_ACCESS_TOKEN_CALL_CREDS = + Any.pack(AccessTokenCredentials.newBuilder().build()); + + private ExtAuthz.Builder extAuthzBuilder; + + @Before + public void setUp() { + extAuthzBuilder = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin(GOOGLE_DEFAULT_CHANNEL_CREDS) + .addCallCredentialsPlugin(FAKE_ACCESS_TOKEN_CALL_CREDS).build()) + .build()); + } + + @Test + public void fromProto_missingGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder().build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat() + .isEqualTo("unsupported ExtAuthz service type: only grpc_service is supported"); + } + } + + @Test + public void fromProto_invalidGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder().build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Failed to parse GrpcService config:"); + } + } + + @Test + public void fromProto_invalidAllowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for allow_expression:"); + } + } + + @Test + public void fromProto_invalidDisallowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for disallow_expression:"); + } + } + + @Test + public void fromProto_success() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setGrpcService(extAuthzBuilder.getGrpcServiceBuilder() + .setTimeout(com.google.protobuf.Duration.newBuilder().setSeconds(5).build()) + .addInitialMetadata(HeaderValue.newBuilder().setKey("key").setValue("value").build()) + .build()) + .setFailureModeAllow(true).setFailureModeAllowHeaderAdd(true) + .setIncludePeerCertificate(true) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(true)).build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(50) + .setDenominator(DenominatorType.TEN_THOUSAND).build()) + .build()) + .setAllowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()).build()) + .setDisallowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setPrefix("disallowed-").build()).build()) + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .setDisallowAll(BoolValue.of(true)).setDisallowIsError(BoolValue.of(true)).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.grpcService().googleGrpc().target()).isEqualTo("test-cluster"); + assertThat(config.grpcService().timeout().get().getSeconds()).isEqualTo(5); + assertThat(config.grpcService().initialMetadata().isPresent()).isTrue(); + assertThat(config.failureModeAllow()).isTrue(); + assertThat(config.failureModeAllowHeaderAdd()).isTrue(); + assertThat(config.includePeerCertificate()).isTrue(); + assertThat(config.statusOnError().getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(config.statusOnError().getDescription()).isEqualTo("HTTP status code 403"); + assertThat(config.denyAtDisable()).isTrue(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(50, 10_000)); + assertThat(config.allowedHeaders()).hasSize(1); + assertThat(config.allowedHeaders().get(0).matches("allowed-header")).isTrue(); + assertThat(config.disallowedHeaders()).hasSize(1); + assertThat(config.disallowedHeaders().get(0).matches("disallowed-foo")).isTrue(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + assertThat(rules.disallowAll()).isTrue(); + assertThat(rules.disallowIsError()).isTrue(); + } + + @Test + public void fromProto_saneDefaults() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder.build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.failureModeAllow()).isFalse(); + assertThat(config.failureModeAllowHeaderAdd()).isFalse(); + assertThat(config.includePeerCertificate()).isFalse(); + assertThat(config.statusOnError()).isEqualTo(Status.PERMISSION_DENIED); + assertThat(config.denyAtDisable()).isFalse(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(100, 100)); + assertThat(config.allowedHeaders()).isEmpty(); + assertThat(config.disallowedHeaders()).isEmpty(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isFalse(); + } + + @Test + public void fromProto_headerMutationRules_allowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().isPresent()).isFalse(); + } + + @Test + public void fromProto_headerMutationRules_disallowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().isPresent()).isFalse(); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + } + + @Test + public void fromProto_filterEnabled_hundred() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent + .newBuilder().setNumerator(25).setDenominator(DenominatorType.HUNDRED).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(25, 100)); + } + + @Test + public void fromProto_filterEnabled_million() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled( + RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent.newBuilder() + .setNumerator(123456).setDenominator(DenominatorType.MILLION).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.filterEnabled()) + .isEqualTo(Matchers.FractionMatcher.create(123456, 1_000_000)); + } + + @Test + public void fromProto_filterEnabled_unrecognizedDenominator() { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue( + FractionalPercent.newBuilder().setNumerator(1).setDenominatorValue(4).build()) + .build()) + .build(); + + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().isEqualTo("Unknown denominator type: UNRECOGNIZED"); + } + } +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java new file mode 100644 index 00000000000..7a506220973 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Any; +import com.google.protobuf.Duration; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3.LocalCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import java.nio.charset.StandardCharsets; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcServiceConfigTest { + + @Test + public void fromProto_success() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + HeaderValue asciiHeader = + HeaderValue.newBuilder().setKey("test_key").setValue("test_value").build(); + HeaderValue binaryHeader = HeaderValue.newBuilder().setKey("test_key-bin") + .setValue( + BaseEncoding.base64().encode("test_value_binary".getBytes(StandardCharsets.UTF_8))) + .build(); + Duration timeout = Duration.newBuilder().setSeconds(10).build(); + GrpcService grpcService = + GrpcService.newBuilder().setGoogleGrpc(googleGrpc).addInitialMetadata(asciiHeader) + .addInitialMetadata(binaryHeader).setTimeout(timeout).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + // Assert target URI + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + + // Assert channel credentials + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(insecureCreds.hashCode()); + + // Assert call credentials + assertThat(config.googleGrpc().callCredentials().getClass().getName()) + .isEqualTo("io.grpc.auth.GoogleAuthLibraryCallCredentials"); + + // Assert initial metadata + assertThat(config.initialMetadata().isPresent()).isTrue(); + assertThat(config.initialMetadata().get() + .get(Metadata.Key.of("test_key", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("test_value"); + assertThat(config.initialMetadata().get() + .get(Metadata.Key.of("test_key-bin", Metadata.BINARY_BYTE_MARSHALLER))) + .isEqualTo("test_value_binary".getBytes(StandardCharsets.UTF_8)); + + // Assert timeout + assertThat(config.timeout().isPresent()).isTrue(); + assertThat(config.timeout().get()).isEqualTo(java.time.Duration.ofSeconds(10)); + } + + @Test + public void fromProto_minimalSuccess_defaults() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + assertThat(config.initialMetadata().isPresent()).isFalse(); + assertThat(config.timeout().isPresent()).isFalse(); + } + + @Test + public void fromProto_missingGoogleGrpc() { + GrpcService grpcService = GrpcService.newBuilder().build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .startsWith("Unsupported: GrpcService must have GoogleGrpc, got: "); + } + + @Test + public void fromProto_emptyCallCredentials() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported call_credentials found. Errors: []"); + } + + @Test + public void fromProto_emptyChannelCredentials() { + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported channel_credentials found. Errors: []"); + } + + @Test + public void fromProto_googleDefaultCredentials() throws GrpcServiceParseException { + Any googleDefaultCreds = Any.pack(GoogleDefaultCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(googleDefaultCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(io.grpc.CompositeChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(googleDefaultCreds.hashCode()); + } + + @Test + public void fromProto_localCredentials() throws GrpcServiceParseException { + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(localCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("LocalCredentials are not yet supported."); + } + + @Test + public void fromProto_xdsCredentials_withInsecureFallback() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + XdsCredentials xdsCreds = + XdsCredentials.newBuilder().setFallbackCredentials(insecureCreds).build(); + Any xdsCredsAny = Any.pack(xdsCreds); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(io.grpc.ChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(xdsCredsAny.hashCode()); + } + + @Test + public void fromProto_tlsCredentials_notSupported() { + Any tlsCreds = Any + .pack(io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3.TlsCredentials + .getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(tlsCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("TlsCredentials are not yet supported."); + } + + @Test + public void fromProto_invalidChannelCredentialsProto() { + // Pack a Duration proto, but try to unpack it as GoogleDefaultCredentials + Any invalidCreds = Any.pack(com.google.protobuf.Duration.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .contains("No valid supported channel_credentials found. Errors: [Unsupported channel " + + "credentials type: type.googleapis.com/google.protobuf.Duration"); + } + + @Test + public void fromProto_invalidCallCredentialsProto() { + // Pack a Duration proto, but try to unpack it as AccessTokenCredentials + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(Duration.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("Unsupported call credentials type:"); + } +} + diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java new file mode 100644 index 00000000000..8d7347f56c6 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static org.junit.Assert.assertNotNull; + +import io.grpc.CallCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig.GoogleGrpcConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig.HashedChannelCredentials; +import java.util.concurrent.Executor; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InsecureGrpcChannelFactory}. */ +@RunWith(JUnit4.class) +public class InsecureGrpcChannelFactoryTest { + + private static final class NoOpCallCredentials extends CallCredentials { + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, + MetadataApplier applier) { + applier.apply(new Metadata()); + } + } + + @Test + public void testCreateChannel() { + InsecureGrpcChannelFactory factory = InsecureGrpcChannelFactory.getInstance(); + GrpcServiceConfig config = GrpcServiceConfig.builder() + .googleGrpc(GoogleGrpcConfig.builder().target("localhost:8080") + .hashedChannelCredentials( + HashedChannelCredentials.of(InsecureChannelCredentials.create(), 0)) + .callCredentials(new NoOpCallCredentials()).build()) + .build(); + ManagedChannel channel = factory.createChannel(config); + assertNotNull(channel); + channel.shutdownNow(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java new file mode 100644 index 00000000000..e73460924c7 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Optional; +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationFilterTest { + + private static HeaderValueOption header(String key, String value) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).build(); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).setAppendAction(action) + .build(); + } + + @Test + public void filter_removesImmutableHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("add-key", "add-value"), header(":authority", "new-authority"), + header("host", "new-host"), header(":scheme", "https"), header(":method", "PUT")), + ImmutableList.of("remove-key", "host", ":authority", ":scheme", ":method")), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value"), + header(":scheme", "https")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("add-key", "add-value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("remove-key"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("resp-add-key", "resp-add-value")); + } + + @Test + public void filter_cannotAppendToSystemHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = + HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of( + header("add-key", "add-value", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(":authority", "new-authority", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header("host", "new-host", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(":path", "/new-path", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)), + ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList + .of(header("host", "new-host", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly( + header("add-key", "add-value", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)); + assertThat(filtered.responseMutations().headers()).isEmpty(); + } + + @Test + public void filter_cannotRemoveSystemHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), + ImmutableList.of("remove-key", "host", ":foo", ":bar")), + ResponseHeaderMutations.create(ImmutableList.of())); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("remove-key"); + } + + @Test + public void filter_canOverrideSystemHeadersNotInImmutableHeaders() + throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("user-agent", "new-agent"), + header(":path", "/new/path", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(":grpc-trace-bin", "binary-value", HeaderAppendAction.ADD_IF_ABSENT)), + ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList + .of(header(":alt-svc", "h3=:443", HeaderAppendAction.OVERWRITE_IF_EXISTS)))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly( + header("user-agent", "new-agent"), + header(":path", "/new/path", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(":grpc-trace-bin", "binary-value", HeaderAppendAction.ADD_IF_ABSENT)); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header(":alt-svc", "h3=:443", HeaderAppendAction.OVERWRITE_IF_EXISTS)); + } + + @Test + public void filter_disallowAll_disablesAllModifications() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(header("add-key", "add-value")), + ImmutableList.of("remove-key")), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).isEmpty(); + assertThat(filtered.requestMutations().headersToRemove()).isEmpty(); + assertThat(filtered.responseMutations().headers()).isEmpty(); + } + + @Test + public void filter_disallowExpression_filtersRelevantExpressions() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-public", "value"), header("x-private-key", "value")), + ImmutableList.of("x-public-remove", "x-private-remove")), + ResponseHeaderMutations.create( + ImmutableList.of(header("x-public-resp", "value"), header("x-private-resp", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly(header("x-public", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-public-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-public-resp", "value")); + } + + @Test + public void filter_allowExpression_onlyAllowsRelevantExpressions() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .allowExpression(Pattern.compile("^x-allowed-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = + HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), + header("not-allowed-key", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")), + ResponseHeaderMutations.create(ImmutableList.of(header("x-allowed-resp-key", "value"), + header("not-allowed-resp-key", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("x-allowed-key", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-allowed-resp-key", "value")); + } + + @Test + public void filter_allowExpression_overridesDisallowAll() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), header("not-allowed", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")), + ResponseHeaderMutations.create(ImmutableList.of(header("x-allowed-resp-key", "value"), + header("not-allowed-resp-key", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("x-allowed-key", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-allowed-resp-key", "value")); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowed() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create(RequestHeaderMutations + .create(ImmutableList.of(header("add-key", "add-value")), ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList.of())); + filter.filter(mutations); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowedRemove() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of("remove-key")), + ResponseHeaderMutations.create(ImmutableList.of())); + filter.filter(mutations); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowedResponseHeader() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value")))); + filter.filter(mutations); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java new file mode 100644 index 00000000000..e2bda9cb836 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesConfigTest { + @Test + public void testBuilderDefaultValues() { + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder().build(); + assertFalse(config.disallowAll()); + assertFalse(config.disallowIsError()); + assertThat(config.allowExpression()).isEmpty(); + assertThat(config.disallowExpression()).isEmpty(); + } + + @Test + public void testBuilder_setDisallowAll() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowAll(true).build(); + assertTrue(config.disallowAll()); + } + + @Test + public void testBuilder_setDisallowIsError() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + assertTrue(config.disallowIsError()); + } + + @Test + public void testBuilder_setAllowExpression() { + Pattern pattern = Pattern.compile("allow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().allowExpression(pattern).build(); + assertThat(config.allowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setDisallowExpression() { + Pattern pattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowExpression(pattern).build(); + assertThat(config.disallowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setAll() { + Pattern allowPattern = Pattern.compile("allow.*"); + Pattern disallowPattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .disallowIsError(true) + .allowExpression(allowPattern) + .disallowExpression(disallowPattern) + .build(); + assertTrue(config.disallowAll()); + assertTrue(config.disallowIsError()); + assertThat(config.allowExpression()).hasValue(allowPattern); + assertThat(config.disallowExpression()).hasValue(disallowPattern); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java new file mode 100644 index 00000000000..f1dc0561692 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import org.junit.Test; + +public class HeaderMutationsTest { + @Test + public void testCreate() { + HeaderValueOption reqHeader = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("req-key").setValue("req-value").build()) + .build(); + RequestHeaderMutations requestMutations = RequestHeaderMutations + .create(ImmutableList.of(reqHeader), ImmutableList.of("remove-req-key")); + assertThat(requestMutations.headers()).containsExactly(reqHeader); + assertThat(requestMutations.headersToRemove()).containsExactly("remove-req-key"); + + HeaderValueOption respHeader = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("resp-key").setValue("resp-value").build()) + .build(); + ResponseHeaderMutations responseMutations = + ResponseHeaderMutations.create(ImmutableList.of(respHeader)); + assertThat(responseMutations.headers()).containsExactly(respHeader); + + HeaderMutations mutations = HeaderMutations.create(requestMutations, responseMutations); + assertThat(mutations.requestMutations()).isEqualTo(requestMutations); + assertThat(mutations.responseMutations()).isEqualTo(responseMutations); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java new file mode 100644 index 00000000000..df6ce383d8c --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java @@ -0,0 +1,311 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.io.BaseEncoding; +import com.google.common.testing.TestLogHandler; +import com.google.protobuf.ByteString; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.Metadata; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutatorTest { + + private static final Metadata.Key ASCII_KEY = + Metadata.Key.of("some-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key BINARY_KEY = + Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + private static final Metadata.Key APPEND_KEY = + Metadata.Key.of("append-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADD_KEY = + Metadata.Key.of("add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_KEY = + Metadata.Key.of("overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVE_KEY = + Metadata.Key.of("remove-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_ADD_KEY = + Metadata.Key.of("new-add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_OVERWRITE_KEY = + Metadata.Key.of("new-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_KEY = + Metadata.Key.of("overwrite-if-exists-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_ABSENT_KEY = + Metadata.Key.of("overwrite-if-exists-absent-key", Metadata.ASCII_STRING_MARSHALLER); + + private final HeaderMutator headerMutator = HeaderMutator.create(); + + private static final TestLogHandler logHandler = new TestLogHandler(); + private static final Logger logger = + Logger.getLogger(HeaderMutator.HeaderMutatorImpl.class.getName()); + + @Before + public void setUp() { + logHandler.clear(); + logger.addHandler(logHandler); + logger.setLevel(Level.WARNING); + } + + @After + public void tearDown() { + logger.removeHandler(logHandler); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).setAppendAction(action) + .build(); + } + + @Test + public void applyRequestMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + headers.put(REMOVE_KEY, "remove-value-original"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "original-value"); + + RequestHeaderMutations mutations = RequestHeaderMutations.create(ImmutableList.of( + // Append to existing header + header(APPEND_KEY.name(), "append-value-2", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + // Try to add to an existing header (should be no-op) + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + // Add a new header + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + // Overwrite an existing header + header(OVERWRITE_KEY.name(), "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + // Overwrite a new header + header(NEW_OVERWRITE_KEY.name(), "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + // Overwrite an existing header if it exists + header(OVERWRITE_IF_EXISTS_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS), + // Try to overwrite a header that does not exist + header(OVERWRITE_IF_EXISTS_ABSENT_KEY.name(), "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyRequestMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("new-value"); + assertThat(headers.containsKey(OVERWRITE_IF_EXISTS_ABSENT_KEY)).isFalse(); + } + + @Test + public void applyRequestMutations_InvalidAppendAction_isIgnored() { + Metadata headers = new Metadata(); + headers.put(ASCII_KEY, "value1"); + headerMutator + .applyRequestMutations( + RequestHeaderMutations + .create( + ImmutableList.of( + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-1).build(), + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-5).build()), + ImmutableList.of()), + headers); + assertThat(headers.getAll(ASCII_KEY)).containsExactly("value1"); + } + + @Test + public void applyRequestMutations_removalHasPriority() { + Metadata headers = new Metadata(); + headers.put(REMOVE_KEY, "value"); + RequestHeaderMutations mutations = RequestHeaderMutations.create( + ImmutableList.of( + header(REMOVE_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyRequestMutations(mutations, headers); + + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + } + + @Test + public void applyRequestMutations_binary_withBase64RawValue() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setRawValue( + ByteString.copyFrom(BaseEncoding.base64().encode(value), StandardCharsets.US_ASCII))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyRequestMutations_binary_withBase64Value() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + String base64Value = BaseEncoding.base64().encode(value); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setValue(base64Value)) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyRequestMutations_ascii_withRawValue() { + Metadata headers = new Metadata(); + byte[] value = "raw-value".getBytes(StandardCharsets.US_ASCII); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setRawValue(ByteString.copyFrom(value))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(Metadata.Key.of(ASCII_KEY.name(), Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("raw-value"); + } + + @Test + public void applyResponseMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + + ResponseHeaderMutations mutations = ResponseHeaderMutations.create(ImmutableList.of( + header(APPEND_KEY.name(), "append-value-2", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header(OVERWRITE_KEY.name(), "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(NEW_OVERWRITE_KEY.name(), "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD))); + + headerMutator.applyResponseMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + } + + + @Test + public void applyResponseMutations_InvalidAppendAction_isIgnored() { + Metadata headers = new Metadata(); + headers.put(ASCII_KEY, "value1"); + headerMutator + .applyResponseMutations( + ResponseHeaderMutations + .create( + ImmutableList.of( + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-1).build(), + HeaderValueOption + .newBuilder().setHeader(HeaderValue.newBuilder() + .setKey(BINARY_KEY.name()).setValue("value2")) + .setAppendActionValue(-5).build())), + headers); + assertThat(headers.getAll(ASCII_KEY)).containsExactly("value1"); + } + + @Test + public void applyResponseMutations_binary_withBase64RawValue() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setRawValue( + ByteString.copyFrom(BaseEncoding.base64().encode(value), StandardCharsets.US_ASCII))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_binary_withBase64Value() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + String base64Value = BaseEncoding.base64().encode(value); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setValue(base64Value)) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_ascii_withRawValue() { + Metadata headers = new Metadata(); + byte[] value = "raw-value".getBytes(StandardCharsets.US_ASCII); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setRawValue(ByteString.copyFrom(value))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(Metadata.Key.of(ASCII_KEY.name(), Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("raw-value"); + } + + @Test + public void applyRequestMutations_unrecognizedAction_logsWarning() { + Metadata headers = new Metadata(); + RequestHeaderMutations mutations = + RequestHeaderMutations.create(ImmutableList.of(HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("key").setValue("value")) + .setAppendActionValue(-1).build()), ImmutableList.of()); + headerMutator.applyRequestMutations(mutations, headers); + + List records = logHandler.getStoredLogRecords(); + assertThat(records).hasSize(1); + assertThat(records.get(0).getMessage()) + .contains("Unrecognized HeaderAppendAction: UNRECOGNIZED"); + } +}