diff --git a/packages/smithy-http/src/smithy_http/aio/protocols.py b/packages/smithy-http/src/smithy_http/aio/protocols.py index e5591923..e842bec8 100644 --- a/packages/smithy-http/src/smithy_http/aio/protocols.py +++ b/packages/smithy-http/src/smithy_http/aio/protocols.py @@ -1,9 +1,10 @@ import os +from collections.abc import AsyncIterable from inspect import iscoroutinefunction -from io import BytesIO from typing import Any -from smithy_core.aio.interfaces import ClientProtocol +from smithy_core.aio.interfaces import AsyncByteStream, ClientProtocol +from smithy_core.aio.interfaces import StreamingBlob as AsyncStreamingBlob from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape from smithy_core.documents import TypeRegistry @@ -109,35 +110,45 @@ async def deserialize_response[ error_registry: TypeRegistry, context: TypedProperties, ) -> OperationOutput: - body = response.body - - # if body is not streaming and is async, we have to buffer it - if not operation.output_stream_member and not is_streaming_blob(body): - if ( - read := getattr(body, "read", None) - ) is not None and iscoroutinefunction(read): - body = BytesIO(await read()) - if not self._is_success(operation, context, response): raise await self._create_error( operation=operation, request=request, response=response, - response_body=body, # type: ignore + response_body=await self._buffer_async_body(response.body), error_registry=error_registry, context=context, ) + # if body is not streaming and is async, we have to buffer it + body: SyncStreamingBlob | None = None + if not operation.output_stream_member and not is_streaming_blob(body): + body = await self._buffer_async_body(response.body) + # TODO(optimization): response binding cache like done in SJ deserializer = HTTPResponseDeserializer( payload_codec=self.payload_codec, http_trait=operation.schema.expect_trait(HTTPTrait), response=response, - body=body, # type: ignore + body=body, ) return operation.output.deserialize(deserializer) + async def _buffer_async_body(self, stream: AsyncStreamingBlob) -> SyncStreamingBlob: + match stream: + case AsyncByteStream(): + if not iscoroutinefunction(stream.read): + return stream # type: ignore + return await stream.read() + case AsyncIterable(): + full = b"" + async for chunk in stream: + full += chunk + return full + case _: + return stream + def _is_success( self, operation: APIOperation[Any, Any],