Skip to content

Handle AWS error codes in documents #495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Any, Final

from smithy_core.codecs import Codec
from smithy_core.exceptions import DiscriminatorError
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID
from smithy_core.shapes import ShapeID, ShapeType
from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse
from smithy_http.aio.protocols import HttpBindingClientProtocol
from smithy_json import JSONCodec
from smithy_json import JSONCodec, JSONDocument

from ..traits import RestJson1Trait
from ..utils import parse_document_discriminator, parse_error_code


class AWSErrorIdentifier(HTTPErrorIdentifier):
Expand All @@ -24,20 +26,29 @@ def identify(

error_field = response.fields[self._HEADER_KEY]
code = error_field.values[0] if len(error_field.values) > 0 else None
if not code:
return None
if code is not None:
return parse_error_code(code, operation.schema.id.namespace)
return None


code = code.split(":")[0]
if "#" in code:
return ShapeID(code)
return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace)
class AWSJSONDocument(JSONDocument):
@property
def discriminator(self) -> ShapeID:
if self.shape_type is ShapeType.STRUCTURE:
return self._schema.id
parsed = parse_document_discriminator(self, self._settings.default_namespace)
if parsed is None:
raise DiscriminatorError(
f"Unable to parse discriminator for {self.shape_type} document."
)
return parsed


class RestJsonClientProtocol(HttpBindingClientProtocol):
"""An implementation of the aws.protocols#restJson1 protocol."""

_id: Final = RestJson1Trait.id
_codec: Final = JSONCodec()
_codec: Final = JSONCodec(document_class=AWSJSONDocument)
_contentType: Final = "application/json"
_error_identifier: Final = AWSErrorIdentifier()

Expand Down
32 changes: 32 additions & 0 deletions packages/smithy-aws-core/src/smithy_aws_core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from smithy_core.documents import Document
from smithy_core.shapes import ShapeID, ShapeType


def parse_document_discriminator(
document: Document, default_namespace: str | None
) -> ShapeID | None:
if document.shape_type is ShapeType.MAP:
map_document = document.as_map()
code = map_document.get("__type")
if code is None:
code = map_document.get("code")
if code is not None and code.shape_type is ShapeType.STRING:
return parse_error_code(code.as_string(), default_namespace)

return None


def parse_error_code(code: str, default_namespace: str | None) -> ShapeID | None:
if not code:
return None

code = code.split(":")[0]
if "#" in code:
return ShapeID(code)

if not code or not default_namespace:
return None

return ShapeID.from_parts(name=code, namespace=default_namespace)
57 changes: 56 additions & 1 deletion packages/smithy-aws-core/tests/unit/aio/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from unittest.mock import Mock

import pytest
from smithy_aws_core.aio.protocols import AWSErrorIdentifier
from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument
from smithy_core.exceptions import DiscriminatorError
from smithy_core.schemas import APIOperation, Schema
from smithy_core.shapes import ShapeID, ShapeType
from smithy_http import Fields, tuples_to_fields
from smithy_http.aio import HTTPResponse
from smithy_json import JSONSettings


@pytest.mark.parametrize(
Expand All @@ -24,6 +26,7 @@
"com.test#FooError",
),
("", None),
(":", None),
(None, None),
],
)
Expand All @@ -42,3 +45,55 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N
actual = error_identifier.identify(operation=operation, response=http_response)

assert actual == expected


@pytest.mark.parametrize(
"document, expected",
[
({"__type": "FooError"}, "com.test#FooError"),
({"__type": "com.test#FooError"}, "com.test#FooError"),
(
{
"__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
},
"com.test#FooError",
),
(
{
"__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
},
"com.test#FooError",
),
({"code": "FooError"}, "com.test#FooError"),
({"code": "com.test#FooError"}, "com.test#FooError"),
(
{
"code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
},
"com.test#FooError",
),
(
{
"code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
},
"com.test#FooError",
),
({"__type": "FooError", "code": "BarError"}, "com.test#FooError"),
("FooError", None),
({"__type": None}, None),
({"__type": ""}, None),
({"__type": ":"}, None),
],
)
def test_aws_json_document_discriminator(
document: dict[str, str], expected: ShapeID | None
) -> None:
settings = JSONSettings(
document_class=AWSJSONDocument, default_namespace="com.test"
)
if expected is None:
with pytest.raises(DiscriminatorError):
AWSJSONDocument(document, settings=settings).discriminator
else:
discriminator = AWSJSONDocument(document, settings=settings).discriminator
assert discriminator == expected
78 changes: 78 additions & 0 deletions packages/smithy-aws-core/tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import pytest
from smithy_aws_core.utils import parse_document_discriminator, parse_error_code
from smithy_core.documents import Document
from smithy_core.shapes import ShapeID


@pytest.mark.parametrize(
"document, expected",
[
({"__type": "FooError"}, "com.test#FooError"),
({"__type": "com.test#FooError"}, "com.test#FooError"),
(
{
"__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
},
"com.test#FooError",
),
(
{
"__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
},
"com.test#FooError",
),
({"code": "FooError"}, "com.test#FooError"),
({"code": "com.test#FooError"}, "com.test#FooError"),
(
{
"code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
},
"com.test#FooError",
),
(
{
"code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
},
"com.test#FooError",
),
({"__type": "FooError", "code": "BarError"}, "com.test#FooError"),
("FooError", None),
({"__type": None}, None),
({"__type": ""}, None),
({"__type": ":"}, None),
],
)
def test_aws_json_document_discriminator(
document: dict[str, str], expected: ShapeID | None
) -> None:
actual = parse_document_discriminator(Document(document), "com.test")
assert actual == expected


@pytest.mark.parametrize(
"code, expected",
[
("FooError", "com.test#FooError"),
(
"FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/",
"com.test#FooError",
),
(
"com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate",
"com.test#FooError",
),
("", None),
(":", None),
],
)
def test_parse_error_code(code: str, expected: ShapeID | None) -> None:
actual = parse_error_code(code, "com.test")
assert actual == expected


def test_parse_error_code_without_default_namespace() -> None:
actual = parse_error_code("FooError", None)
assert actual is None
6 changes: 4 additions & 2 deletions packages/smithy-core/src/smithy_core/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TypeGuard, override

from .deserializers import DeserializeableShape, ShapeDeserializer
from .exceptions import ExpectationNotMetError, SmithyError
from .exceptions import DiscriminatorError, ExpectationNotMetError, SmithyError
from .schemas import Schema
from .serializers import (
InterceptingSerializer,
Expand Down Expand Up @@ -146,7 +146,9 @@ def shape_type(self) -> ShapeType:
@property
def discriminator(self) -> ShapeID:
"""The shape ID that corresponds to the contents of the document."""
return self._schema.id
if self._type is ShapeType.STRUCTURE:
return self._schema.id
raise DiscriminatorError(f"{self._type} document has no discriminator.")

def is_none(self) -> bool:
"""Indicates whether the document contains a null value."""
Expand Down
5 changes: 5 additions & 0 deletions packages/smithy-core/src/smithy_core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class SerializationError(SmithyError):
"""Base exception type for exceptions raised during serialization."""


class DiscriminatorError(SmithyError):
"""Exception indicating something went wrong when attempting to find the
discriminator in a document."""


class RetryError(SmithyError):
"""Base exception type for all exceptions raised in retry strategies."""

Expand Down
12 changes: 11 additions & 1 deletion packages/smithy-core/tests/unit/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_DocumentDeserializer,
_DocumentSerializer,
)
from smithy_core.exceptions import ExpectationNotMetError
from smithy_core.exceptions import DiscriminatorError, ExpectationNotMetError
from smithy_core.prelude import (
BIG_DECIMAL,
BLOB,
Expand Down Expand Up @@ -938,3 +938,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer):
actual = given.as_shape(DocumentSerdeShape)
case _:
raise Exception(f"Unexpected type: {type(given)}")


def test_document_has_no_discriminator_by_default() -> None:
with pytest.raises(DiscriminatorError):
Document().discriminator


def test_struct_document_has_discriminator() -> None:
document = Document({"integerMember": 1}, schema=SCHEMA)
assert document.discriminator == SCHEMA.id
13 changes: 10 additions & 3 deletions packages/smithy-core/tests/unit/test_type_registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
from smithy_core.documents import Document, TypeRegistry
from smithy_core.prelude import STRING
from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeID, ShapeType
from smithy_core.shapes import ShapeID
from smithy_core.traits import RequiredTrait


def test_get():
Expand Down Expand Up @@ -59,11 +61,16 @@ def test_deserialize():

class TestShape(DeserializeableShape):
__test__ = False
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)
schema = Schema.collection(
id=ShapeID("com.example#Test"),
members={"value": {"index": 0, "target": STRING, "traits": [RequiredTrait()]}},
)

def __init__(self, value: str):
self.value = value

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
return TestShape(deserializer.read_string(schema=TestShape.schema))
return TestShape(
value=deserializer.read_string(schema=cls.schema.members["value"])
)
Loading