# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
# flake8: noqa: F811
import asyncio
+ from collections import deque
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable
from concurrent.futures import Future
from copy import deepcopy
- from io import BytesIO
+ from io import BytesIO, BufferedIOBase
from threading import Lock
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from awscrt import http as crt_http
    from awscrt import io as crt_io

+     # Both of these are types that essentially are "castable to bytes/memoryview".
+     # Unfortunately they're not exposed anywhere, so we have to import them from
+     # _typeshed.
+     from _typeshed import WriteableBuffer, ReadableBuffer
+
try:
    from awscrt import http as crt_http
    from awscrt import io as crt_io
@@ -304,7 +310,7 @@ async def _marshal_request(
        # If the body is async, or potentially very large, start up a task to read
        # it into the BytesIO object that CRT needs. By using asyncio.create_task
        # we'll start the coroutine without having to explicitly await it.
-         crt_body = BytesIO()
+         crt_body = BufferableByteStream()
        if not isinstance(body, AsyncIterable):
            # If the body isn't already an async iterable, wrap it in one. Objects
            # with read methods will be read in chunks so as not to exhaust memory.
@@ -327,15 +333,92 @@ async def _marshal_request(
        return crt_request

    async def _consume_body_async(
-         self, source: AsyncIterable[bytes], dest: BytesIO
+         self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
    ) -> None:
        async for chunk in source:
            dest.write(chunk)
-         # Should we call close here? Or will that make the crt unable to read the last
-         # chunk?
+         dest.end_stream()

    def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":
        return AWSCRTHTTPClient(
            eventloop=self._eventloop,
            client_config=deepcopy(self._config),
        )
+
+
+ # This is adapted from the transcribe streaming SDK
+ class BufferableByteStream(BufferedIOBase):
+     """A non-blocking bytes buffer."""
+
+     def __init__(self) -> None:
+         # We're always manipulating the front and back of the buffer, so a deque
+         # will be much more efficient than a list.
+         self._chunks: deque[bytes] = deque()
+         self._closed = False
+         self._done = False
+
+     def read(self, size: int | None = -1) -> bytes:
+         if self._closed:
+             return b""
+
+         if len(self._chunks) == 0:
+             # When the CRT receives this, it'll try again later.
+             raise BlockingIOError("read")
+
+         # We could compile all the chunks here instead of just returning
+         # one, BUT the CRT will keep calling read until empty bytes are
+         # returned. So it's actually better to just return one chunk,
+         # since combining them would have some potentially bad memory
+         # usage issues.
+         result = self._chunks.popleft()
+         if size is not None and size > 0:
+             remainder = result[size:]
+             result = result[:size]
+             if remainder:
+                 self._chunks.appendleft(remainder)
+
+         if self._done and len(self._chunks) == 0:
+             self.close()
+
+         return result
+
+     def read1(self, size: int = -1) -> bytes:
+         return self.read(size)
+
+     def readinto(self, buffer: "WriteableBuffer") -> int:
+         if not isinstance(buffer, memoryview):
+             buffer = memoryview(buffer).cast("B")
+
+         data = self.read(len(buffer))  # type: ignore
+         n = len(data)
+         buffer[:n] = data
+         return n
+
+     def write(self, buffer: "ReadableBuffer") -> int:
+         if not isinstance(buffer, bytes):
+             raise ValueError(
+                 f"Unexpected value written to BufferableByteStream. "
+                 f"Only bytes are supported, but {type(buffer)} was provided."
+             )
+
+         if self._closed:
+             raise IOError("Stream is completed and doesn't support further writes.")
+
+         if buffer:
+             self._chunks.append(buffer)
+         return len(buffer)
+
+     @property
+     def closed(self) -> bool:
+         return self._closed
+
+     def close(self) -> None:
+         self._closed = True
+         self._done = True
+
+         # Clear out the remaining chunks so that they don't sit around in memory.
+         self._chunks.clear()
+
+     def end_stream(self) -> None:
+         """End the stream, letting any remaining chunks be read before it is closed."""
+         self._done = True
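
As a quick illustration of the semantics above, here is a minimal sketch of the buffer used on its own. It assumes BufferableByteStream as defined in this diff is in scope; the variable names are illustrative.

stream = BufferableByteStream()

stream.write(b"hello")
assert stream.read() == b"hello"

try:
    stream.read()  # nothing buffered yet
except BlockingIOError:
    pass  # the CRT treats this as "try again later"

stream.write(b"world")
stream.end_stream()  # no more writes are coming

assert stream.read() == b"world"  # the last chunk drains, then the stream closes itself
assert stream.closed
assert stream.read() == b""  # a closed stream reads as empty bytes, which tells the CRT it is done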
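And here is roughly how the two halves of this change fit together: an asyncio task (like _consume_body_async above) pushes chunks into the stream while a reader polls it without blocking. The reader loop and chunk generator below are illustrative stand-ins for the CRT's internal read loop, not code from this file.

import asyncio
from collections.abc import AsyncIterable


async def produce(source: AsyncIterable[bytes], dest: BufferableByteStream) -> None:
    # Mirrors _consume_body_async: push every chunk, then signal end-of-stream.
    async for chunk in source:
        dest.write(chunk)
    dest.end_stream()


async def consume(stream: BufferableByteStream) -> bytes:
    # Stand-in for the CRT's read loop: retry on BlockingIOError, stop on b"".
    collected = b""
    while True:
        try:
            chunk = stream.read()
        except BlockingIOError:
            await asyncio.sleep(0)  # nothing buffered yet; yield to the event loop and retry
            continue
        if chunk == b"":
            return collected
        collected += chunk


async def main() -> None:
    async def chunks() -> AsyncIterable[bytes]:
        yield b"hello, "
        yield b"world"

    stream = BufferableByteStream()
    producer = asyncio.create_task(produce(chunks(), stream))
    result = await consume(stream)
    await producer
    assert result == b"hello, world"


asyncio.run(main())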