Skip to content

Commit 8297c56

Browse files
Add non-blocking buffer to interface with CRT
1 parent 5369a62 commit 8297c56

File tree

2 files changed

+184
-6
lines changed

2 files changed

+184
-6
lines changed

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
44
# flake8: noqa: F811
55
import asyncio
6+
from collections import deque
67
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable
78
from concurrent.futures import Future
89
from copy import deepcopy
9-
from io import BytesIO
10+
from io import BytesIO, BufferedIOBase
1011
from threading import Lock
1112
from typing import TYPE_CHECKING, Any
1213

@@ -17,6 +18,11 @@
1718
from awscrt import http as crt_http
1819
from awscrt import io as crt_io
1920

21+
# Both of these are types that essentially are "castable to bytes/memoryview"
22+
# Unfortunately they're not exposed anywhere so we have to import them from
23+
# _typeshed.
24+
from _typeshed import WriteableBuffer, ReadableBuffer
25+
2026
try:
2127
from awscrt import http as crt_http
2228
from awscrt import io as crt_io
@@ -304,7 +310,7 @@ async def _marshal_request(
304310
# If the body is async, or potentially very large, start up a task to read
305311
# it into the non-blocking stream that CRT reads from. By using asyncio.create_task
306312
# we'll start the coroutine without having to explicitly await it.
307-
crt_body = BytesIO()
313+
crt_body = BufferableByteStream()
308314
if not isinstance(body, AsyncIterable):
309315
# If the body isn't already an async iterable, wrap it in one. Objects
310316
# with read methods will be read in chunks so as not to exhaust memory.
@@ -327,15 +333,92 @@ async def _marshal_request(
327333
return crt_request
328334

329335
async def _consume_body_async(
    self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
) -> None:
    """Copy every chunk from ``source`` into ``dest``, then mark ``dest`` finished.

    :param source: Async iterable yielding the body as ``bytes`` chunks.
    :param dest: Stream that receives each chunk via ``write``.
    """
    async for piece in source:
        dest.write(piece)

    # Signal completion without closing, so chunks still buffered in ``dest``
    # remain readable.
    dest.end_stream()
336341

337342
def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":
    """Return a new client with a deep-copied config.

    The event loop is passed to the new client as-is (not copied); only the
    client config is deep-copied.

    :param memo: The memo dictionary supplied by ``copy.deepcopy``.
    """
    return AWSCRTHTTPClient(
        eventloop=self._eventloop,
        # Forward ``memo`` so objects already copied during this deepcopy pass
        # are reused rather than duplicated, and reference cycles are handled.
        client_config=deepcopy(self._config, memo),
    )
347+
348+
349+
# This is adapted from the transcribe streaming sdk
350+
class BufferableByteStream(BufferedIOBase):
351+
"""A non-blocking bytes buffer."""
352+
353+
def __init__(self) -> None:
354+
# We're always manipulating the front and back of the buffer, so a deque
355+
# will be much more efficient than a list.
356+
self._chunks: deque[bytes] = deque()
357+
self._closed = False
358+
self._done = False
359+
360+
def read(self, size: int | None = -1) -> bytes:
361+
if self._closed:
362+
return b""
363+
364+
if len(self._chunks) == 0:
365+
# When the CRT recieves this, it'll try again later.
366+
raise BlockingIOError("read")
367+
368+
# We could compile all the chunks here instead of just returning
369+
# the one, BUT the CRT will keep calling read until empty bytes
370+
# are returned. So it's actually better to just return one chunk
371+
# since combining them would have some potentially bad memory
372+
# usage issues.
373+
result = self._chunks.popleft()
374+
if size is not None and size > 0:
375+
remainder = result[size:]
376+
result = result[:size]
377+
if remainder:
378+
self._chunks.appendleft(remainder)
379+
380+
if self._done and len(self._chunks) == 0:
381+
self.close()
382+
383+
return result
384+
385+
def read1(self, size: int = -1) -> bytes:
386+
return self.read(size)
387+
388+
def readinto(self, buffer: "WriteableBuffer") -> int:
389+
if not isinstance(buffer, memoryview):
390+
buffer = memoryview(buffer).cast("B")
391+
392+
data = self.read(len(buffer)) # type: ignore
393+
n = len(data)
394+
buffer[:n] = data
395+
return n
396+
397+
def write(self, buffer: "ReadableBuffer") -> int:
398+
if not isinstance(buffer, bytes):
399+
raise ValueError(
400+
f"Unexpected value written to BufferableByteStream. "
401+
f"Only bytes are support but {type(buffer)} was provided."
402+
)
403+
404+
if self._closed:
405+
raise IOError("Stream is completed and doesn't support further writes.")
406+
407+
if buffer:
408+
self._chunks.append(buffer)
409+
return len(buffer)
410+
411+
@property
412+
def closed(self) -> bool:
413+
return self._closed
414+
415+
def close(self) -> None:
416+
self._closed = True
417+
self._done = True
418+
419+
# Clear out the remaining chunks so that they don't sit around in memory.
420+
self._chunks.clear()
421+
422+
def end_stream(self) -> None:
423+
"""End the stream, letting any remaining chunks be read before it is closed."""
424+
self._done = True
Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,103 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
13
from copy import deepcopy
24

3-
from smithy_http.aio.crt import AWSCRTHTTPClient
5+
import pytest
6+
7+
from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream
48

59

610
def test_deepcopy_client() -> None:
    """Deep-copying a client must complete without raising."""
    original = AWSCRTHTTPClient()
    deepcopy(original)
13+
14+
15+
def test_stream_write() -> None:
    """A written chunk is returned unchanged by the next read."""
    payload = b"foo"
    buf = BufferableByteStream()
    buf.write(payload)
    assert buf.read() == payload
19+
20+
21+
def test_stream_reads_individual_chunks() -> None:
    """Each read returns exactly one written chunk, in write order."""
    payloads = (b"foo", b"bar")
    buf = BufferableByteStream()
    for payload in payloads:
        buf.write(payload)
    for expected in payloads:
        assert buf.read() == expected
27+
28+
29+
def test_stream_empty_read() -> None:
    """Reading with nothing buffered raises instead of blocking."""
    empty = BufferableByteStream()
    with pytest.raises(BlockingIOError):
        empty.read()
33+
34+
35+
def test_stream_partial_chunk_read() -> None:
    """A sized read splits a chunk; the remainder is served by the next read."""
    buf = BufferableByteStream()
    buf.write(b"foobar")
    head = buf.read(3)
    tail = buf.read()
    assert head == b"foo"
    assert tail == b"bar"
40+
41+
42+
def test_stream_write_empty_bytes() -> None:
    """Empty writes are ignored and never surface as chunks."""
    buf = BufferableByteStream()
    for payload in (b"", b"foo", b""):
        buf.write(payload)
    assert buf.read() == b"foo"
48+
49+
50+
def test_stream_write_non_bytes() -> None:
    """Only ``bytes`` may be written; other buffer types are rejected."""
    buf = BufferableByteStream()
    not_bytes = memoryview(b"foo")
    with pytest.raises(ValueError):
        buf.write(not_bytes)
54+
55+
56+
def test_closed_stream_write() -> None:
    """Writing to a closed stream raises an IOError."""
    buf = BufferableByteStream()
    buf.close()
    with pytest.raises(IOError):
        buf.write(b"foo")
61+
62+
63+
def test_closed_stream_read() -> None:
    """Closing discards buffered data, so a later read returns empty bytes."""
    buf = BufferableByteStream()
    buf.write(b"foo")
    buf.close()
    leftover = buf.read()
    assert leftover == b""
68+
69+
70+
def test_stream_read1() -> None:
    """read1 drains one chunk per call; the drained stream then raises on read."""
    payloads = (b"foo", b"bar")
    buf = BufferableByteStream()
    for payload in payloads:
        buf.write(payload)
    for expected in payloads:
        assert buf.read1() == expected
    with pytest.raises(BlockingIOError):
        buf.read()
78+
79+
80+
def test_stream_readinto_memoryview() -> None:
    """readinto fills a memoryview with at most ``len(buffer)`` bytes."""
    # Size the target explicitly: three bytes, so exactly b"foo" fits.
    buffer = memoryview(bytearray(3))
    stream = BufferableByteStream()
    stream.write(b"foobar")
    assert stream.readinto(buffer) == 3
    assert bytes(buffer) == b"foo"
    # The unread tail of the chunk stays buffered for the next read.
    assert stream.read() == b"bar"
86+
87+
88+
def test_stream_readinto_bytearray() -> None:
    """readinto fills a bytearray with at most ``len(buffer)`` bytes."""
    # Size the target explicitly: three bytes, so exactly b"foo" fits.
    buffer = bytearray(3)
    stream = BufferableByteStream()
    stream.write(b"foobar")
    assert stream.readinto(buffer) == 3
    assert bytes(buffer) == b"foo"
    # The unread tail of the chunk stays buffered for the next read.
    assert stream.read() == b"bar"
94+
95+
96+
def test_end_stream() -> None:
    """end_stream defers closing until the remaining chunks have been read."""
    buf = BufferableByteStream()
    buf.write(b"foo")
    buf.end_stream()

    # Still open while data remains buffered.
    assert not buf.closed
    drained = buf.read()
    assert drained == b"foo"
    # Reading the last chunk closes the stream.
    assert buf.closed

0 commit comments

Comments
 (0)