diff --git a/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java b/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java index 78299f942..021e0ba12 100644 --- a/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java +++ b/aws/client/aws-client-restxml/src/it/java/software/amazon/smithy/java/aws/client/restxml/RestXmlProtocolTests.java @@ -57,10 +57,6 @@ public void requestTest(DataStream expected, DataStream actual) { } @HttpClientResponseTests - @ProtocolTestFilter(skipTests = { - "InvalidGreetingError", - "ComplexError" - }) public void responseTest(Runnable test) { test.run(); } diff --git a/aws/client/aws-client-restxml/src/main/java/software/amazon/smithy/java/aws/client/restxml/RestXmlClientProtocol.java b/aws/client/aws-client-restxml/src/main/java/software/amazon/smithy/java/aws/client/restxml/RestXmlClientProtocol.java index 5527e2a27..40fda578e 100644 --- a/aws/client/aws-client-restxml/src/main/java/software/amazon/smithy/java/aws/client/restxml/RestXmlClientProtocol.java +++ b/aws/client/aws-client-restxml/src/main/java/software/amazon/smithy/java/aws/client/restxml/RestXmlClientProtocol.java @@ -5,6 +5,7 @@ package software.amazon.smithy.java.aws.client.restxml; +import java.nio.ByteBuffer; import software.amazon.smithy.aws.traits.protocols.RestXmlTrait; import software.amazon.smithy.java.aws.events.AwsEventDecoderFactory; import software.amazon.smithy.java.aws.events.AwsEventEncoderFactory; @@ -17,13 +18,17 @@ import software.amazon.smithy.java.client.http.binding.HttpBindingClientProtocol; import software.amazon.smithy.java.client.http.binding.HttpBindingErrorFactory; import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.error.ModeledException; import software.amazon.smithy.java.core.schema.InputEventStreamingApiOperation; import software.amazon.smithy.java.core.schema.OutputEventStreamingApiOperation; import software.amazon.smithy.java.core.serde.Codec; +import software.amazon.smithy.java.core.serde.TypeRegistry; import software.amazon.smithy.java.core.serde.event.EventDecoderFactory; import software.amazon.smithy.java.core.serde.event.EventEncoderFactory; import software.amazon.smithy.java.core.serde.event.EventStreamingException; +import software.amazon.smithy.java.http.api.HttpResponse; import software.amazon.smithy.java.xml.XmlCodec; +import software.amazon.smithy.java.xml.XmlUtil; import software.amazon.smithy.model.shapes.ShapeId; /** @@ -45,6 +50,7 @@ public RestXmlClientProtocol(ShapeId service) { this.errorDeserializer = HttpErrorDeserializer.builder() .codec(codec) .serviceId(service) + .errorPayloadParser(XML_ERROR_PAYLOAD_PARSER) .knownErrorFactory(new HttpBindingErrorFactory(httpBinding())) .headerErrorExtractor(new AmznErrorHeaderExtractor()) .build(); @@ -89,6 +95,25 @@ protected EventDecoderFactory getEventDecoderFactory( return AwsEventDecoderFactory.forOutputStream(outputOperation, payloadCodec(), f -> f); } + private static final HttpErrorDeserializer.ErrorPayloadParser XML_ERROR_PAYLOAD_PARSER = ( + Context context, + Codec codec, + HttpErrorDeserializer.KnownErrorFactory knownErrorFactory, + ShapeId serviceId, + TypeRegistry typeRegistry, + HttpResponse response, + ByteBuffer buffer) -> { + var deserializer = codec.createDeserializer(buffer); + String code = XmlUtil.parseErrorCodeName(deserializer); + var nameSpace = serviceId.getNamespace(); + var id = ShapeId.fromOptionalNamespace(nameSpace, code); + var builder = typeRegistry.createBuilder(id, ModeledException.class); + if (builder != null) { + return knownErrorFactory.createError(context, codec, response, builder); + } + return null; + }; + public static final class Factory implements ClientProtocolFactory { @Override public ShapeId id() { diff --git a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpErrorDeserializer.java b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpErrorDeserializer.java index e23e9a6c0..cb276721a 100644 --- a/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpErrorDeserializer.java +++ b/client/client-http/src/main/java/software/amazon/smithy/java/client/http/HttpErrorDeserializer.java @@ -111,6 +111,38 @@ default ModeledException createErrorFromDocument( } } + /** + * To create an error, the {@link ShapeId} of the error is required to retrieve the corresponding {@link ShapeBuilder}. + * Different protocols need different parsers to extract the ShapeId given their different response structures. + * If no parser specified, {@link #DEFAULT_ERROR_PAYLOAD_PARSER} will be picked. + */ + @FunctionalInterface + public interface ErrorPayloadParser { + /** + * This method should parse the response payload and extract error's ShapeId,and + * create the corresponding error with the {@link KnownErrorFactory}. + * + * @param context Context of the call. + * @param codec Codec used to deserialize payloads. + * @param knownErrorFactory The knownErrorFactory to create error. + * @param serviceId The ShapeId of the service. + * @param typeRegistry The error typeRegistry to retrieve builder for the error. + * @param response Response to parse. + * @param buffer Bytebuffer of the payload. + * + * @return the created error. + */ + CallException parsePayload( + Context context, + Codec codec, + KnownErrorFactory knownErrorFactory, + ShapeId serviceId, + TypeRegistry typeRegistry, + HttpResponse response, + ByteBuffer buffer + ) throws SerializationException, DiscriminatorException; + } + // Does not check for any error headers by default. private static final HeaderErrorExtractor DEFAULT_EXTRACTOR = new HeaderErrorExtractor() { @Override @@ -159,24 +191,52 @@ public ModeledException createErrorFromDocument( } }; + // This default parser should work for most protocols, but other protocols + // that do not support document types will need a custom parser to extract error ShapeId. + private static final ErrorPayloadParser DEFAULT_ERROR_PAYLOAD_PARSER = ( + Context context, + Codec codec, + KnownErrorFactory knownErrorFactory, + ShapeId serviceId, + TypeRegistry typeRegistry, + HttpResponse response, + ByteBuffer buffer) -> { + var document = codec.createDeserializer(buffer).readDocument(); + var id = document.discriminator(); + var builder = typeRegistry.createBuilder(id, ModeledException.class); + if (builder != null) { + return knownErrorFactory.createErrorFromDocument( + context, + codec, + response, + buffer, + document, + builder); + } + return null; + }; + private final Codec codec; private final HeaderErrorExtractor headerErrorExtractor; private final ShapeId serviceId; private final UnknownErrorFactory unknownErrorFactory; private final KnownErrorFactory knownErrorFactory; + private final ErrorPayloadParser errorPayloadParser; private HttpErrorDeserializer( Codec codec, HeaderErrorExtractor headerErrorExtractor, ShapeId serviceId, UnknownErrorFactory unknownErrorFactory, - KnownErrorFactory knownErrorFactory + KnownErrorFactory knownErrorFactory, + ErrorPayloadParser errorPayloadParser ) { this.codec = Objects.requireNonNull(codec, "Missing codec"); this.serviceId = Objects.requireNonNull(serviceId, "Missing serviceId"); this.headerErrorExtractor = headerErrorExtractor; this.unknownErrorFactory = unknownErrorFactory; this.knownErrorFactory = knownErrorFactory; + this.errorPayloadParser = errorPayloadParser; } public static Builder builder() { @@ -200,13 +260,14 @@ public CallException createError( // No error header, no __type: it's an unknown error. return createErrorFromHints(operation, response, unknownErrorFactory); } else { - // Look for __type in the payload. return makeErrorFromPayload( context, codec, knownErrorFactory, unknownErrorFactory, + errorPayloadParser, operation, + serviceId, typeRegistry, response, content); @@ -240,7 +301,9 @@ private static CallException makeErrorFromPayload( Codec codec, KnownErrorFactory knownErrorFactory, UnknownErrorFactory unknownErrorFactory, + ErrorPayloadParser errorPayloadParser, ShapeId operationId, + ShapeId serviceId, TypeRegistry typeRegistry, HttpResponse response, DataStream content @@ -251,17 +314,16 @@ private static CallException makeErrorFromPayload( ByteBuffer buffer = content.asByteBuffer(); if (buffer.remaining() > 0) { - var document = codec.createDeserializer(buffer).readDocument(); - var id = document.discriminator(); - var builder = typeRegistry.createBuilder(id, ModeledException.class); - if (builder != null) { - return knownErrorFactory.createErrorFromDocument( - context, - codec, - response, - buffer, - document, - builder); + var error = errorPayloadParser.parsePayload( + context, + codec, + knownErrorFactory, + serviceId, + typeRegistry, + response, + buffer); + if (error != null) { + return error; } } } catch (SerializationException | DiscriminatorException ignored) { @@ -291,6 +353,7 @@ public static final class Builder { private ShapeId serviceId; private UnknownErrorFactory unknownErrorFactory = DEFAULT_UNKNOWN_FACTORY; private KnownErrorFactory knownErrorFactory = DEFAULT_KNOWN_FACTORY; + private ErrorPayloadParser errorPayloadParser = DEFAULT_ERROR_PAYLOAD_PARSER; private Builder() {} @@ -300,7 +363,8 @@ public HttpErrorDeserializer build() { headerErrorExtractor, serviceId, unknownErrorFactory, - knownErrorFactory); + knownErrorFactory, + errorPayloadParser); } /** @@ -362,5 +426,19 @@ public Builder unknownErrorFactory(UnknownErrorFactory unknownErrorFactory) { this.unknownErrorFactory = Objects.requireNonNull(unknownErrorFactory, "unknownErrorFactory is null"); return this; } + + /** + * The parser to parse the shapeId from the payload. + * + *

The default parser implementation will parse the payload into a {@link Document} and + * use {@link Document#discriminator()} to extract its {@code __type} field as ShapeId + * + * @param errorPayloadParser Parser used to parse the payload. + * @return the builder. + */ + public Builder errorPayloadParser(ErrorPayloadParser errorPayloadParser) { + this.errorPayloadParser = Objects.requireNonNull(errorPayloadParser, "ErrorPayloadParser is null"); + return this; + } } } diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java index 20f5363bc..6b0ed4639 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java @@ -66,6 +66,16 @@ public void close() { // The first deserialization of XML expects a containing XML element for the shape. // The inner deserializer deserializes members and doesn't have this expectation. + // If the name is ErrorResponse or Error, it means we are deserializing an ErrorResponse like: + // + // + // Sender + // InvalidInput + // Invalid input + // + // + // We should skip all the way to the Code element to deserialize the rest of the fields and + // return early to skip the name comparison. private void enter(Schema schema) { try { if (!isTopLevel) { @@ -80,6 +90,9 @@ private void enter(Schema schema) { expected = trait.getValue(); } else if (schema.isMember()) { expected = schema.memberTarget().id().getName(); + } else if (name != null && (name.equals("ErrorResponse") || name.equals("Error"))) { + skipToCodeElement(name); + return; } else { expected = schema.id().getName(); } @@ -100,6 +113,51 @@ private void exit() { } } + private void skipToCodeElement(String name) throws XMLStreamException { + if (name.equals("ErrorResponse")) { + reader.nextMemberElement(); // Move to Error element + } + String element; + while ((element = reader.nextMemberElement()) != null) { + if (element.equals("Code")) { + reader.closeElement(); + return; + } + reader.closeElement(); + } + } + + String parseErrorCodeName() { + try { + var element = reader.nextMemberElement(); + if (element == null || (!element.equals("ErrorResponse") && !element.equals("Error"))) { + throw new SerializationException( + "Expected element or for restXml error response"); + } + if (element.equals("ErrorResponse")) { + element = reader.nextMemberElement(); + if (element == null || !element.equals("Error")) { + throw new SerializationException("Expected element inside "); + } + } + String childElement; + while ((childElement = reader.nextMemberElement()) != null) { + if (childElement.equals("Code")) { + if (reader.getText() == null) { + throw new SerializationException("Expected shape name inside "); + } + var code = reader.getText(); + reader.closeElement(); + return code; + } + reader.closeElement(); + } + throw new SerializationException("Expected element inside "); + } catch (XMLStreamException e) { + throw new SerializationException(e); + } + } + @Override public boolean readBoolean(Schema schema) { enter(schema); diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java new file mode 100644 index 000000000..6384d6c95 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import software.amazon.smithy.java.core.serde.ShapeDeserializer; + +/** + * Utility class for XML codec. + */ +public final class XmlUtil { + /** + * Retrieve the Code element value from the error response. + * + * @param deserializer the deserializer for the error response + * @return String value of the Code element if found + */ + public static String parseErrorCodeName(ShapeDeserializer deserializer) { + try (var xmlDeserializer = (XmlDeserializer) deserializer) { + return xmlDeserializer.parseErrorCodeName(); + } + } +}