Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 31a635c

Browse files
committedMay 9, 2025
Implement the request pipeline in pure python
1 parent f00644a commit 31a635c

File tree

3 files changed

+576
-0
lines changed

3 files changed

+576
-0
lines changed
 
Lines changed: 566 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,566 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
4+
import logging
5+
from asyncio import Future, sleep
6+
from collections.abc import Awaitable, Callable, Sequence
7+
from copy import copy
8+
from dataclasses import dataclass, replace
9+
from typing import Any
10+
11+
from .. import URI
12+
from ..auth import AuthParams
13+
from ..deserializers import DeserializeableShape, ShapeDeserializer
14+
from ..endpoints import EndpointResolverParams
15+
from ..exceptions import RetryError, SmithyError
16+
from ..interceptors import (
17+
InputContext,
18+
Interceptor,
19+
OutputContext,
20+
RequestContext,
21+
ResponseContext,
22+
)
23+
from ..interfaces import Endpoint, TypedProperties
24+
from ..interfaces.auth import AuthOption, AuthSchemeResolver
25+
from ..interfaces.retries import RetryStrategy
26+
from ..schemas import APIOperation
27+
from ..serializers import SerializeableShape
28+
from ..shapes import ShapeID
29+
from ..types import PropertyKey
30+
from .eventstream import DuplexEventStream, InputEventStream, OutputEventStream
31+
from .interfaces import (
32+
ClientProtocol,
33+
ClientTransport,
34+
EndpointResolver,
35+
Request,
36+
Response,
37+
)
38+
from .interfaces.auth import AuthScheme
39+
from .interfaces.eventstream import EventReceiver
40+
from .utils import seek
41+
42+
AUTH_SCHEME = PropertyKey(key="auth_scheme", value_type=AuthScheme[Any, Any, Any, Any])
43+
44+
_UNRESOLVED = URI(host="", path="/")
45+
_LOGGER = logging.getLogger(__name__)
46+
47+
48+
@dataclass(kw_only=True, frozen=True)
49+
class ClientCall[I: SerializeableShape, O: DeserializeableShape]:
50+
"""A data class containing all the initial information about an operation
51+
invocation."""
52+
53+
input: I
54+
"""The input of the operation."""
55+
56+
operation: APIOperation[I, O]
57+
"""The schema of the operation."""
58+
59+
context: TypedProperties
60+
"""The initial context of the operation."""
61+
62+
interceptor: Interceptor[I, O, Any, Any]
63+
"""The interceptor to use in the course of the operation invocation.
64+
65+
This SHOULD be an InterceptorChain.
66+
"""
67+
68+
auth_scheme_resolver: AuthSchemeResolver
69+
"""The auth scheme resolver for the operation."""
70+
71+
supported_auth_schemes: dict[ShapeID, AuthScheme[Any, Any, Any, Any]]
72+
"""The supported auth schemes for the operation."""
73+
74+
endpoint_resolver: EndpointResolver
75+
"""The endpoint resolver for the operation."""
76+
77+
retry_strategy: RetryStrategy
78+
"""The retry strategy to use for the operation."""
79+
80+
retry_scope: str | None = None
81+
"""The retry scope for the operation."""
82+
83+
def retryable(self) -> bool:
84+
# TODO: check to see if the stream is seekable
85+
return self.operation.input_stream_member is None
86+
87+
88+
class RequestPipeline[TRequest: Request, TResponse: Response]:
89+
"""Invokes client operations asynchronously."""
90+
91+
protocol: ClientProtocol[TRequest, TResponse]
92+
"""The protocol to use to serialize the request and deserialize the response."""
93+
94+
transport: ClientTransport[TRequest, TResponse]
95+
"""The transport to use to send the request and receive the response (e.g. an HTTP
96+
Client)."""
97+
98+
def __init__(
99+
self,
100+
protocol: ClientProtocol[TRequest, TResponse],
101+
transport: ClientTransport[TRequest, TResponse],
102+
) -> None:
103+
self.protocol = protocol
104+
self.transport = transport
105+
106+
async def __call__[I: SerializeableShape, O: DeserializeableShape](
107+
self, call: ClientCall[I, O], /
108+
) -> O:
109+
"""Invoke an operation asynchronously.
110+
111+
:param call: The operation to invoke and associated context.
112+
"""
113+
output, _ = await self._execute_request(call, None)
114+
return output
115+
116+
async def input_stream[
117+
I: SerializeableShape,
118+
O: DeserializeableShape,
119+
E: SerializeableShape,
120+
](self, call: ClientCall[I, O], event_type: type[E], /) -> InputEventStream[E, O]:
121+
"""Invoke an input stream operation asynchronously.
122+
123+
:param call: The operation to invoke and associated context.
124+
:param event_type: The event type to send in the input stream.
125+
"""
126+
request_future = Future[RequestContext[I, TRequest]]()
127+
output_future = asyncio.create_task(
128+
self._await_output(self._execute_request(call, request_future))
129+
)
130+
request_context = await request_future
131+
input_stream = self.protocol.create_event_publisher(
132+
operation=call.operation,
133+
request=request_context.transport_request,
134+
event_type=event_type,
135+
context=request_context.properties,
136+
auth_scheme=request_context.properties.get(AUTH_SCHEME),
137+
)
138+
return InputEventStream(input_stream=input_stream, output_future=output_future)
139+
140+
async def _await_output[I: SerializeableShape, O: DeserializeableShape](
141+
self,
142+
execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest, TResponse]]],
143+
) -> O:
144+
output, _ = await execute_task
145+
return output
146+
147+
async def output_stream[
148+
I: SerializeableShape,
149+
O: DeserializeableShape,
150+
E: DeserializeableShape,
151+
](
152+
self,
153+
call: ClientCall[I, O],
154+
event_type: type[E],
155+
event_deserializer: Callable[[ShapeDeserializer], E],
156+
/,
157+
) -> OutputEventStream[E, O]:
158+
"""Invoke an input stream operation asynchronously.
159+
160+
:param call: The operation to invoke and associated context.
161+
:param event_type: The event type to receive in the output stream.
162+
:param event_deserializer: The method used to deserialize events.
163+
"""
164+
output, output_context = await self._execute_request(call, None)
165+
output_stream = self.protocol.create_event_receiver(
166+
operation=call.operation,
167+
request=output_context.transport_request,
168+
response=output_context.transport_response,
169+
event_type=event_type,
170+
event_deserializer=event_deserializer,
171+
context=output_context.properties,
172+
)
173+
return OutputEventStream(output_stream=output_stream, output=output)
174+
175+
async def duplex_stream[
176+
I: SerializeableShape,
177+
O: DeserializeableShape,
178+
IE: SerializeableShape,
179+
OE: DeserializeableShape,
180+
](
181+
self,
182+
call: ClientCall[I, O],
183+
input_event_type: type[IE],
184+
output_event_type: type[OE],
185+
event_deserializer: Callable[[ShapeDeserializer], OE],
186+
/,
187+
) -> DuplexEventStream[IE, OE, O]:
188+
"""Invoke an input stream operation asynchronously.
189+
190+
:param call: The operation to invoke and associated context.
191+
:param input_event_type: The event type to send in the input stream.
192+
:param output_event_type: The event type to receive in the output stream.
193+
:param event_deserializer: The method used to deserialize events.
194+
"""
195+
request_future = Future[RequestContext[I, TRequest]]()
196+
execute_task = asyncio.create_task(self._execute_request(call, request_future))
197+
request_context = await request_future
198+
input_stream = self.protocol.create_event_publisher(
199+
operation=call.operation,
200+
request=request_context.transport_request,
201+
event_type=input_event_type,
202+
context=request_context.properties,
203+
auth_scheme=request_context.properties.get(AUTH_SCHEME),
204+
)
205+
output_future = asyncio.create_task(
206+
self._await_output_stream(
207+
call=call,
208+
execute_task=execute_task,
209+
output_event_type=output_event_type,
210+
event_deserializer=event_deserializer,
211+
)
212+
)
213+
return DuplexEventStream(input_stream=input_stream, output_future=output_future)
214+
215+
async def _await_output_stream[
216+
I: SerializeableShape,
217+
O: DeserializeableShape,
218+
OE: DeserializeableShape,
219+
](
220+
self,
221+
call: ClientCall[I, O],
222+
execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest, TResponse]]],
223+
output_event_type: type[OE],
224+
event_deserializer: Callable[[ShapeDeserializer], OE],
225+
) -> tuple[O, EventReceiver[OE]]:
226+
output, output_context = await execute_task
227+
output_stream = self.protocol.create_event_receiver(
228+
operation=call.operation,
229+
request=output_context.transport_request,
230+
response=output_context.transport_response,
231+
event_type=output_event_type,
232+
event_deserializer=event_deserializer,
233+
context=output_context.properties,
234+
)
235+
return output, output_stream
236+
237+
async def _execute_request[I: SerializeableShape, O: DeserializeableShape](
238+
self,
239+
call: ClientCall[I, O],
240+
request_future: Future[RequestContext[I, TRequest]] | None,
241+
) -> tuple[O, OutputContext[I, O, TRequest, TResponse]]:
242+
_LOGGER.debug(
243+
'Making request for operation "%s" with parameters: %s',
244+
call.operation.schema.id.name,
245+
call.input,
246+
)
247+
output_context = await self._handle_execution(call, request_future)
248+
output_context = self._finalize_execution(call, output_context)
249+
250+
if isinstance(output_context.response, Exception):
251+
e = output_context.response
252+
if not isinstance(e, SmithyError):
253+
raise SmithyError(e) from e
254+
raise e
255+
256+
return output_context.response, output_context # type: ignore
257+
258+
async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
259+
self,
260+
call: ClientCall[I, O],
261+
request_future: Future[RequestContext[I, TRequest]] | None,
262+
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
263+
try:
264+
interceptor = call.interceptor
265+
266+
input_context = InputContext(request=call.input, properties=call.context)
267+
interceptor.read_before_execution(input_context)
268+
269+
input_context = replace(
270+
input_context,
271+
request=interceptor.modify_before_serialization(input_context),
272+
)
273+
274+
interceptor.read_before_serialization(input_context)
275+
_LOGGER.debug("Serializing request for: %s", input_context.request)
276+
277+
transport_request = self.protocol.serialize_request(
278+
operation=call.operation,
279+
input=call.input,
280+
endpoint=_UNRESOLVED,
281+
context=input_context.properties,
282+
)
283+
request_context = RequestContext(
284+
request=input_context.request,
285+
transport_request=transport_request,
286+
properties=input_context.properties,
287+
)
288+
289+
_LOGGER.debug(
290+
"Serialization complete. Transport request: %s", transport_request
291+
)
292+
except Exception as e:
293+
return OutputContext(
294+
request=call.input,
295+
response=e,
296+
transport_request=None,
297+
transport_response=None,
298+
properties=call.context,
299+
)
300+
301+
try:
302+
interceptor.read_after_serialization(request_context)
303+
request_context = replace(
304+
request_context,
305+
transport_request=interceptor.modify_before_retry_loop(request_context),
306+
)
307+
308+
return await self._retry(call, request_context, request_future)
309+
except Exception as e:
310+
return OutputContext(
311+
request=request_context.request,
312+
response=e,
313+
transport_request=request_context.transport_request,
314+
transport_response=None,
315+
properties=request_context.properties,
316+
)
317+
318+
async def _retry[I: SerializeableShape, O: DeserializeableShape](
319+
self,
320+
call: ClientCall[I, O],
321+
request_context: RequestContext[I, TRequest],
322+
request_future: Future[RequestContext[I, TRequest]] | None,
323+
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
324+
if not call.retryable():
325+
return await self._handle_attempt(call, request_context, request_future)
326+
327+
retry_strategy = call.retry_strategy
328+
retry_token = retry_strategy.acquire_initial_retry_token(
329+
token_scope=call.retry_scope
330+
)
331+
332+
while True:
333+
if retry_token.retry_delay:
334+
await sleep(retry_token.retry_delay)
335+
336+
output_context = await self._handle_attempt(
337+
call,
338+
replace(
339+
request_context,
340+
transport_request=copy(request_context.transport_request),
341+
),
342+
request_future,
343+
)
344+
345+
if isinstance(output_context.response, Exception):
346+
try:
347+
retry_strategy.refresh_retry_token_for_retry(
348+
token_to_renew=retry_token,
349+
error=output_context.response,
350+
)
351+
except RetryError:
352+
raise output_context.response
353+
354+
_LOGGER.debug(
355+
"Retry needed. Attempting request #%s in %.4f seconds.",
356+
retry_token.retry_count + 1,
357+
retry_token.retry_delay,
358+
)
359+
360+
await seek(request_context.transport_request.body, 0)
361+
else:
362+
retry_strategy.record_success(token=retry_token)
363+
return output_context
364+
365+
async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
366+
self,
367+
call: ClientCall[I, O],
368+
request_context: RequestContext[I, TRequest],
369+
request_future: Future[RequestContext[I, TRequest]] | None,
370+
) -> OutputContext[I, O, TRequest, TResponse | None]:
371+
output_context: OutputContext[I, O, TRequest, TResponse | None]
372+
try:
373+
interceptor = call.interceptor
374+
interceptor.read_before_attempt(request_context)
375+
376+
endpoint_params = EndpointResolverParams(
377+
operation=call.operation,
378+
input=call.input,
379+
context=request_context.properties,
380+
)
381+
_LOGGER.debug("Calling endpoint resolver.")
382+
endpoint: Endpoint = await call.endpoint_resolver.resolve_endpoint(
383+
endpoint_params
384+
)
385+
_LOGGER.debug("Endpoint resolver result: %s", endpoint)
386+
387+
request_context = replace(
388+
request_context,
389+
transport_request=self.protocol.set_service_endpoint(
390+
request=request_context.transport_request, endpoint=endpoint
391+
),
392+
)
393+
394+
request_context = replace(
395+
request_context,
396+
transport_request=interceptor.modify_before_signing(request_context),
397+
)
398+
interceptor.read_before_signing(request_context)
399+
400+
auth_params = AuthParams[I, O](
401+
protocol_id=self.protocol.id,
402+
operation=call.operation,
403+
context=request_context.properties,
404+
)
405+
auth = self._resolve_auth(call, auth_params)
406+
if auth is not None:
407+
option, scheme = auth
408+
request_context.properties[AUTH_SCHEME] = scheme
409+
identity_resolver = scheme.identity_resolver(context=call.context)
410+
identity = await identity_resolver.get_identity(
411+
properties=option.identity_properties
412+
)
413+
414+
_LOGGER.debug("Request to sign: %s", request_context.transport_request)
415+
_LOGGER.debug("Signer properties: %s", option.signer_properties)
416+
417+
signer = scheme.signer()
418+
request_context = replace(
419+
request_context,
420+
transport_request=await signer.sign(
421+
request=request_context.transport_request,
422+
identity=identity,
423+
properties=option.signer_properties,
424+
),
425+
)
426+
427+
interceptor.read_after_signing(request_context)
428+
request_context = replace(
429+
request_context,
430+
transport_request=interceptor.modify_before_transmit(request_context),
431+
)
432+
interceptor.read_before_transmit(request_context)
433+
434+
_LOGGER.debug("Sending request %s", request_context.transport_request)
435+
436+
if request_future is not None:
437+
# If we have an input event stream (or duplex event stream) then we
438+
# need to let the client return ASAP so that it can start sending
439+
# events. So here we start the transport send in a background task
440+
# then set the result of the request future. It's important to sequence
441+
# it just like that so that the client gets a stream that's ready
442+
# to send.
443+
transport_task = asyncio.create_task(
444+
self.transport.send(request=request_context.transport_request)
445+
)
446+
request_future.set_result(request_context)
447+
transport_response = await transport_task
448+
else:
449+
# If we don't have an input stream, there's no point in creating a
450+
# task, so we just immediately await the coroutine.
451+
transport_response = await self.transport.send(
452+
request=request_context.transport_request
453+
)
454+
455+
_LOGGER.debug("Received response: %s", transport_response)
456+
457+
response_context = ResponseContext(
458+
request=request_context.request,
459+
transport_request=request_context.transport_request,
460+
transport_response=transport_response,
461+
properties=request_context.properties,
462+
)
463+
464+
interceptor.read_after_transmit(response_context)
465+
466+
response_context = replace(
467+
response_context,
468+
transport_response=interceptor.modify_before_deserialization(
469+
response_context
470+
),
471+
)
472+
473+
interceptor.read_before_deserialization(response_context)
474+
475+
_LOGGER.debug(
476+
"Deserializing response: %s", response_context.transport_response
477+
)
478+
479+
output = await self.protocol.deserialize_response(
480+
operation=call.operation,
481+
request=response_context.transport_request,
482+
response=response_context.transport_response,
483+
error_registry=call.operation.error_registry,
484+
context=response_context.properties,
485+
)
486+
487+
_LOGGER.debug("Deserialization complete. Output: %s", output)
488+
489+
output_context = OutputContext(
490+
request=response_context.request,
491+
response=output,
492+
transport_request=response_context.transport_request,
493+
transport_response=response_context.transport_response,
494+
properties=response_context.properties,
495+
)
496+
497+
interceptor.read_after_deserialization(output_context)
498+
except Exception as e:
499+
output_context = OutputContext(
500+
request=request_context.request,
501+
response=e,
502+
transport_request=request_context.transport_request,
503+
transport_response=None,
504+
properties=request_context.properties,
505+
)
506+
507+
return self._finalize_attempt(call, output_context)
508+
509+
def _resolve_auth[I: SerializeableShape, O: DeserializeableShape](
510+
self, call: ClientCall[Any, Any], params: AuthParams[I, O]
511+
) -> tuple[AuthOption, AuthScheme[TRequest, Any, Any, Any]] | None:
512+
auth_options: Sequence[AuthOption] = (
513+
call.auth_scheme_resolver.resolve_auth_scheme(auth_parameters=params)
514+
)
515+
516+
for option in auth_options:
517+
if (
518+
scheme := call.supported_auth_schemes.get(option.scheme_id)
519+
) is not None:
520+
return option, scheme
521+
522+
return None
523+
524+
def _finalize_attempt[I: SerializeableShape, O: DeserializeableShape](
525+
self,
526+
call: ClientCall[I, O],
527+
output_context: OutputContext[I, O, TRequest, TResponse | None],
528+
) -> OutputContext[I, O, TRequest, TResponse | None]:
529+
interceptor = call.interceptor
530+
try:
531+
output_context = replace(
532+
output_context,
533+
response=interceptor.modify_before_attempt_completion(output_context),
534+
)
535+
except Exception as e:
536+
output_context = replace(output_context, response=e)
537+
538+
try:
539+
interceptor.read_after_attempt(output_context)
540+
except Exception as e:
541+
output_context = replace(output_context, response=e)
542+
543+
return output_context
544+
545+
def _finalize_execution[I: SerializeableShape, O: DeserializeableShape](
546+
self,
547+
call: ClientCall[I, O],
548+
output_context: OutputContext[I, O, TRequest | None, TResponse | None],
549+
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
550+
interceptor = call.interceptor
551+
try:
552+
output_context = replace(
553+
output_context,
554+
response=interceptor.modify_before_completion(output_context),
555+
)
556+
557+
# TODO trace probe
558+
except Exception as e:
559+
output_context = replace(output_context, response=e)
560+
561+
try:
562+
interceptor.read_after_execution(output_context)
563+
except Exception as e:
564+
output_context = replace(output_context, response=e)
565+
566+
return output_context

‎packages/smithy-core/src/smithy_core/aio/interfaces/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ...schemas import APIOperation
1616
from ...serializers import SerializeableShape
1717
from ...shapes import ShapeID
18+
from .auth import AuthScheme
1819

1920

2021
@runtime_checkable
@@ -164,13 +165,15 @@ def create_event_publisher[
164165
request: I,
165166
event_type: type[Event],
166167
context: TypedProperties,
168+
auth_scheme: "AuthScheme[Any, Any, Any, Any] | None" = None,
167169
) -> EventPublisher[Event]:
168170
"""Creates an event publisher for a protocol event stream.
169171
170172
:param operation: The event stream operation.
171173
:param request: The transport request that was sent for this stream.
172174
:param event_type: The type of event to publish.
173175
:param context: A context bag for the request.
176+
:param auth_scheme: The optional auth scheme used to sign events.
174177
"""
175178
raise UnsupportedStreamError()
176179

‎packages/smithy-core/src/smithy_core/aio/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,10 @@ async def close(stream: Any) -> None:
5858
if (close := getattr(stream, "close", None)) is not None:
5959
if iscoroutine(result := close()):
6060
await result
61+
62+
63+
async def seek(stream: Any, to: int) -> None:
64+
"""Seek a stream to a specified point."""
65+
if (seek := getattr(stream, "seek", None)) is not None:
66+
if iscoroutine(result := seek(to)):
67+
await result

0 commit comments

Comments
 (0)
Please sign in to comment.