Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,6 @@ public void requestTest(DataStream expected, DataStream actual) {
}

@HttpClientResponseTests
@ProtocolTestFilter(skipTests = {
"InvalidGreetingError",
"ComplexError"
})
public void responseTest(Runnable test) {
test.run();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.java.aws.client.restxml;

import java.nio.ByteBuffer;
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait;
import software.amazon.smithy.java.aws.events.AwsEventDecoderFactory;
import software.amazon.smithy.java.aws.events.AwsEventEncoderFactory;
Expand All @@ -17,12 +18,15 @@
import software.amazon.smithy.java.client.http.binding.HttpBindingClientProtocol;
import software.amazon.smithy.java.client.http.binding.HttpBindingErrorFactory;
import software.amazon.smithy.java.context.Context;
import software.amazon.smithy.java.core.error.ModeledException;
import software.amazon.smithy.java.core.schema.InputEventStreamingApiOperation;
import software.amazon.smithy.java.core.schema.OutputEventStreamingApiOperation;
import software.amazon.smithy.java.core.serde.Codec;
import software.amazon.smithy.java.core.serde.TypeRegistry;
import software.amazon.smithy.java.core.serde.event.EventDecoderFactory;
import software.amazon.smithy.java.core.serde.event.EventEncoderFactory;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.http.api.HttpResponse;
import software.amazon.smithy.java.xml.XmlCodec;
import software.amazon.smithy.model.shapes.ShapeId;

Expand All @@ -45,6 +49,7 @@ public RestXmlClientProtocol(ShapeId service) {
this.errorDeserializer = HttpErrorDeserializer.builder()
.codec(codec)
.serviceId(service)
.errorPayloadParser(XML_ERROR_PAYLOAD_PARSER)
.knownErrorFactory(new HttpBindingErrorFactory(httpBinding()))
.headerErrorExtractor(new AmznErrorHeaderExtractor())
.build();
Expand Down Expand Up @@ -89,6 +94,25 @@ protected EventDecoderFactory<AwsEventFrame> getEventDecoderFactory(
return AwsEventDecoderFactory.forOutputStream(outputOperation, payloadCodec(), f -> f);
}

private static final HttpErrorDeserializer.ErrorPayloadParser XML_ERROR_PAYLOAD_PARSER = (
Context context,
Codec codec,
HttpErrorDeserializer.KnownErrorFactory knownErrorFactory,
ShapeId serviceId,
TypeRegistry typeRegistry,
HttpResponse response,
ByteBuffer buffer) -> {
var xmlCodec = (XmlCodec) codec;
String code = xmlCodec.parseCodeName(buffer);
var nameSpace = serviceId.getNamespace();
var id = ShapeId.fromOptionalNamespace(nameSpace, code);
var builder = typeRegistry.createBuilder(id, ModeledException.class);
if (builder != null) {
return knownErrorFactory.createError(context, codec, response, builder);
}
return null;
};

public static final class Factory implements ClientProtocolFactory<RestXmlTrait> {
@Override
public ShapeId id() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ default ModeledException createErrorFromDocument(
}
}

/**
* To create an error, the {@link ShapeId} of the error is required to retrieve the corresponding {@link ShapeBuilder}.
* Different protocols need different parsers to extract the ShapeId given their different response structures.
* If no parser specified, {@link #DEFAULT_ERROR_PAYLOAD_PARSER} will be picked.
*/
@FunctionalInterface
public interface ErrorPayloadParser {
/**
* This method should parse the response payload and extract error's ShapeId,and
* create the corresponding error with the {@link KnownErrorFactory}.
*
* @param context Context of the call.
* @param codec Codec used to deserialize payloads.
* @param knownErrorFactory The knownErrorFactory to create error.
* @param serviceId The ShapeId of the service.
* @param typeRegistry The error typeRegistry to retrieve builder for the error.
* @param response Response to parse.
* @param buffer Bytebuffer of the payload.
*
* @return the created error.
*/
CallException parsePayload(
Context context,
Codec codec,
KnownErrorFactory knownErrorFactory,
ShapeId serviceId,
TypeRegistry typeRegistry,
HttpResponse response,
ByteBuffer buffer
) throws SerializationException, DiscriminatorException;
}

// Does not check for any error headers by default.
private static final HeaderErrorExtractor DEFAULT_EXTRACTOR = new HeaderErrorExtractor() {
@Override
Expand Down Expand Up @@ -159,24 +191,52 @@ public ModeledException createErrorFromDocument(
}
};

// This default parser should work for most protocols, but other protocols
// that do not support document types will need a custom parser to extract error ShapeId.
private static final ErrorPayloadParser DEFAULT_ERROR_PAYLOAD_PARSER = (
Context context,
Codec codec,
KnownErrorFactory knownErrorFactory,
ShapeId serviceId,
TypeRegistry typeRegistry,
HttpResponse response,
ByteBuffer buffer) -> {
var document = codec.createDeserializer(buffer).readDocument();
var id = document.discriminator();
var builder = typeRegistry.createBuilder(id, ModeledException.class);
if (builder != null) {
return knownErrorFactory.createErrorFromDocument(
context,
codec,
response,
buffer,
document,
builder);
}
return null;
};

private final Codec codec;
private final HeaderErrorExtractor headerErrorExtractor;
private final ShapeId serviceId;
private final UnknownErrorFactory unknownErrorFactory;
private final KnownErrorFactory knownErrorFactory;
private final ErrorPayloadParser errorPayloadParser;

private HttpErrorDeserializer(
Codec codec,
HeaderErrorExtractor headerErrorExtractor,
ShapeId serviceId,
UnknownErrorFactory unknownErrorFactory,
KnownErrorFactory knownErrorFactory
KnownErrorFactory knownErrorFactory,
ErrorPayloadParser errorPayloadParser
) {
this.codec = Objects.requireNonNull(codec, "Missing codec");
this.serviceId = Objects.requireNonNull(serviceId, "Missing serviceId");
this.headerErrorExtractor = headerErrorExtractor;
this.unknownErrorFactory = unknownErrorFactory;
this.knownErrorFactory = knownErrorFactory;
this.errorPayloadParser = errorPayloadParser;
}

public static Builder builder() {
Expand All @@ -200,13 +260,14 @@ public CallException createError(
// No error header, no __type: it's an unknown error.
return createErrorFromHints(operation, response, unknownErrorFactory);
} else {
// Look for __type in the payload.
return makeErrorFromPayload(
context,
codec,
knownErrorFactory,
unknownErrorFactory,
errorPayloadParser,
operation,
serviceId,
typeRegistry,
response,
content);
Expand Down Expand Up @@ -240,7 +301,9 @@ private static CallException makeErrorFromPayload(
Codec codec,
KnownErrorFactory knownErrorFactory,
UnknownErrorFactory unknownErrorFactory,
ErrorPayloadParser errorPayloadParser,
ShapeId operationId,
ShapeId serviceId,
TypeRegistry typeRegistry,
HttpResponse response,
DataStream content
Expand All @@ -251,17 +314,16 @@ private static CallException makeErrorFromPayload(
ByteBuffer buffer = content.asByteBuffer();

if (buffer.remaining() > 0) {
var document = codec.createDeserializer(buffer).readDocument();
var id = document.discriminator();
var builder = typeRegistry.createBuilder(id, ModeledException.class);
if (builder != null) {
return knownErrorFactory.createErrorFromDocument(
context,
codec,
response,
buffer,
document,
builder);
var error = errorPayloadParser.parsePayload(
context,
codec,
knownErrorFactory,
serviceId,
typeRegistry,
response,
buffer);
if (error != null) {
return error;
}
}
} catch (SerializationException | DiscriminatorException ignored) {
Expand Down Expand Up @@ -291,6 +353,7 @@ public static final class Builder {
private ShapeId serviceId;
private UnknownErrorFactory unknownErrorFactory = DEFAULT_UNKNOWN_FACTORY;
private KnownErrorFactory knownErrorFactory = DEFAULT_KNOWN_FACTORY;
private ErrorPayloadParser errorPayloadParser = DEFAULT_ERROR_PAYLOAD_PARSER;

private Builder() {}

Expand All @@ -300,7 +363,8 @@ public HttpErrorDeserializer build() {
headerErrorExtractor,
serviceId,
unknownErrorFactory,
knownErrorFactory);
knownErrorFactory,
errorPayloadParser);
}

/**
Expand Down Expand Up @@ -362,5 +426,19 @@ public Builder unknownErrorFactory(UnknownErrorFactory unknownErrorFactory) {
this.unknownErrorFactory = Objects.requireNonNull(unknownErrorFactory, "unknownErrorFactory is null");
return this;
}

/**
* The parser to parse the shapeId from the payload.
*
* <p>The default parser implementation will parse the payload into a {@link Document} and
* use {@link Document#discriminator()} to extract its {@code __type} field as ShapeId
*
* @param errorPayloadParser Parser used to parse the payload.
* @return the builder.
*/
public Builder errorPayloadParser(ErrorPayloadParser errorPayloadParser) {
this.errorPayloadParser = Objects.requireNonNull(errorPayloadParser, "ErrorPayloadParser is null");
return this;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ public ShapeDeserializer createDeserializer(ByteBuffer source) {
}
}

/**
* Retrieve the Code element value from the error response.
*
* @param source Response payload source
* @return String value of the Code element if found
*/
public String parseCodeName(ByteBuffer source) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems out of place on XmlCodec. Maybe move to XmlDeserializer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XmlDeserializer is package private and we cannot access it if we put it to XmlDeserializer. I think XmlCodec is the only public class under this package, so I put the method here.

try (var deserializer = (XmlDeserializer) createDeserializer(source)) {
return deserializer.parseCodeName();
}
}

/**
* Builder used to create an XML codec.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ private void enter(Schema schema) {
expected = trait.getValue();
} else if (schema.isMember()) {
expected = schema.memberTarget().id().getName();
} else if (name != null && (name.equals("ErrorResponse") || name.equals("Error"))) {
skipToCodeElement(name);
return;
} else {
expected = schema.id().getName();
}
Expand All @@ -100,6 +103,51 @@ private void exit() {
}
}

private void skipToCodeElement(String name) throws XMLStreamException {
if (name.equals("ErrorResponse")) {
reader.nextMemberElement(); // Move to Error element
}
String element;
while ((element = reader.nextMemberElement()) != null) {
if (element.equals("Code")) {
reader.closeElement();
return;
}
reader.closeElement();
}
}

String parseCodeName() {
try {
var element = reader.nextMemberElement();
if (element == null || (!element.equals("ErrorResponse") && !element.equals("Error"))) {
throw new SerializationException(
"Expected element <ErrorResponse> or <Error> for restXml error response");
}
if (element.equals("ErrorResponse")) {
element = reader.nextMemberElement();
if (element == null || !element.equals("Error")) {
throw new SerializationException("Expected <Error> element inside <ErrorResponse>");
}
}
String childElement;
while ((childElement = reader.nextMemberElement()) != null) {
if (childElement.equals("Code")) {
if (reader.getText() == null) {
throw new SerializationException("Expected shape name inside <Code>");
}
var code = reader.getText();
reader.closeElement();
return code;
}
reader.closeElement();
}
throw new SerializationException("Expected <Code> element inside <Error>");
} catch (XMLStreamException e) {
throw new SerializationException(e);
}
}

@Override
public boolean readBoolean(Schema schema) {
enter(schema);
Expand Down
Loading