|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +import asyncio |
| 4 | +import logging |
| 5 | +from asyncio import Future, sleep |
| 6 | +from collections.abc import Awaitable, Callable, Sequence |
| 7 | +from copy import copy |
| 8 | +from dataclasses import dataclass, replace |
| 9 | +from typing import Any |
| 10 | + |
| 11 | +from .. import URI |
| 12 | +from ..auth import AuthParams |
| 13 | +from ..deserializers import DeserializeableShape, ShapeDeserializer |
| 14 | +from ..endpoints import EndpointResolverParams |
| 15 | +from ..exceptions import RetryError, SmithyError |
| 16 | +from ..interceptors import ( |
| 17 | + InputContext, |
| 18 | + Interceptor, |
| 19 | + OutputContext, |
| 20 | + RequestContext, |
| 21 | + ResponseContext, |
| 22 | +) |
| 23 | +from ..interfaces import Endpoint, TypedProperties |
| 24 | +from ..interfaces.auth import AuthOption, AuthSchemeResolver |
| 25 | +from ..interfaces.retries import RetryStrategy |
| 26 | +from ..schemas import APIOperation |
| 27 | +from ..serializers import SerializeableShape |
| 28 | +from ..shapes import ShapeID |
| 29 | +from ..types import PropertyKey |
| 30 | +from .eventstream import DuplexEventStream, InputEventStream, OutputEventStream |
| 31 | +from .interfaces import ( |
| 32 | + ClientProtocol, |
| 33 | + ClientTransport, |
| 34 | + EndpointResolver, |
| 35 | + Request, |
| 36 | + Response, |
| 37 | +) |
| 38 | +from .interfaces.auth import AuthScheme |
| 39 | +from .interfaces.eventstream import EventReceiver |
| 40 | +from .utils import seek |
| 41 | + |
| 42 | +AUTH_SCHEME = PropertyKey(key="auth_scheme", value_type=AuthScheme[Any, Any, Any, Any]) |
| 43 | + |
| 44 | +_UNRESOLVED = URI(host="", path="/") |
| 45 | +_LOGGER = logging.getLogger(__name__) |
| 46 | + |
| 47 | + |
| 48 | +@dataclass(kw_only=True, frozen=True) |
| 49 | +class ClientCall[I: SerializeableShape, O: DeserializeableShape]: |
| 50 | + """A data class containing all the initial information about an operation |
| 51 | + invocation.""" |
| 52 | + |
| 53 | + input: I |
| 54 | + """The input of the operation.""" |
| 55 | + |
| 56 | + operation: APIOperation[I, O] |
| 57 | + """The schema of the operation.""" |
| 58 | + |
| 59 | + context: TypedProperties |
| 60 | + """The initial context of the operation.""" |
| 61 | + |
| 62 | + interceptor: Interceptor[I, O, Any, Any] |
| 63 | + """The interceptor to use in the course of the operation invocation. |
| 64 | +
|
| 65 | + This SHOULD be an InterceptorChain. |
| 66 | + """ |
| 67 | + |
| 68 | + auth_scheme_resolver: AuthSchemeResolver |
| 69 | + """The auth scheme resolver for the operation.""" |
| 70 | + |
| 71 | + supported_auth_schemes: dict[ShapeID, AuthScheme[Any, Any, Any, Any]] |
| 72 | + """The supported auth schemes for the operation.""" |
| 73 | + |
| 74 | + endpoint_resolver: EndpointResolver |
| 75 | + """The endpoint resolver for the operation.""" |
| 76 | + |
| 77 | + retry_strategy: RetryStrategy |
| 78 | + """The retry strategy to use for the operation.""" |
| 79 | + |
| 80 | + retry_scope: str | None = None |
| 81 | + """The retry scope for the operation.""" |
| 82 | + |
| 83 | + def retryable(self) -> bool: |
| 84 | + # TODO: check to see if the stream is seekable |
| 85 | + return self.operation.input_stream_member is None |
| 86 | + |
| 87 | + |
| 88 | +class RequestPipeline[TRequest: Request, TResponse: Response]: |
| 89 | + """Invokes client operations asynchronously.""" |
| 90 | + |
| 91 | + protocol: ClientProtocol[TRequest, TResponse] |
| 92 | + """The protocol to use to serialize the request and deserialize the response.""" |
| 93 | + |
| 94 | + transport: ClientTransport[TRequest, TResponse] |
| 95 | + """The transport to use to send the request and receive the response (e.g. an HTTP |
| 96 | + Client).""" |
| 97 | + |
| 98 | + def __init__( |
| 99 | + self, |
| 100 | + protocol: ClientProtocol[TRequest, TResponse], |
| 101 | + transport: ClientTransport[TRequest, TResponse], |
| 102 | + ) -> None: |
| 103 | + self.protocol = protocol |
| 104 | + self.transport = transport |
| 105 | + |
| 106 | + async def __call__[I: SerializeableShape, O: DeserializeableShape]( |
| 107 | + self, call: ClientCall[I, O], / |
| 108 | + ) -> O: |
| 109 | + """Invoke an operation asynchronously. |
| 110 | +
|
| 111 | + :param call: The operation to invoke and associated context. |
| 112 | + """ |
| 113 | + output, _ = await self._execute_request(call, None) |
| 114 | + return output |
| 115 | + |
| 116 | + async def input_stream[ |
| 117 | + I: SerializeableShape, |
| 118 | + O: DeserializeableShape, |
| 119 | + E: SerializeableShape, |
| 120 | + ](self, call: ClientCall[I, O], event_type: type[E], /) -> InputEventStream[E, O]: |
| 121 | + """Invoke an input stream operation asynchronously. |
| 122 | +
|
| 123 | + :param call: The operation to invoke and associated context. |
| 124 | + :param event_type: The event type to send in the input stream. |
| 125 | + """ |
| 126 | + request_future = Future[RequestContext[I, TRequest]]() |
| 127 | + output_future = asyncio.create_task( |
| 128 | + self._await_output(self._execute_request(call, request_future)) |
| 129 | + ) |
| 130 | + request_context = await request_future |
| 131 | + input_stream = self.protocol.create_event_publisher( |
| 132 | + operation=call.operation, |
| 133 | + request=request_context.transport_request, |
| 134 | + event_type=event_type, |
| 135 | + context=request_context.properties, |
| 136 | + auth_scheme=request_context.properties.get(AUTH_SCHEME), |
| 137 | + ) |
| 138 | + return InputEventStream(input_stream=input_stream, output_future=output_future) |
| 139 | + |
| 140 | + async def _await_output[I: SerializeableShape, O: DeserializeableShape]( |
| 141 | + self, |
| 142 | + execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest, TResponse]]], |
| 143 | + ) -> O: |
| 144 | + output, _ = await execute_task |
| 145 | + return output |
| 146 | + |
| 147 | + async def output_stream[ |
| 148 | + I: SerializeableShape, |
| 149 | + O: DeserializeableShape, |
| 150 | + E: DeserializeableShape, |
| 151 | + ]( |
| 152 | + self, |
| 153 | + call: ClientCall[I, O], |
| 154 | + event_type: type[E], |
| 155 | + event_deserializer: Callable[[ShapeDeserializer], E], |
| 156 | + /, |
| 157 | + ) -> OutputEventStream[E, O]: |
| 158 | + """Invoke an input stream operation asynchronously. |
| 159 | +
|
| 160 | + :param call: The operation to invoke and associated context. |
| 161 | + :param event_type: The event type to receive in the output stream. |
| 162 | + :param event_deserializer: The method used to deserialize events. |
| 163 | + """ |
| 164 | + output, output_context = await self._execute_request(call, None) |
| 165 | + output_stream = self.protocol.create_event_receiver( |
| 166 | + operation=call.operation, |
| 167 | + request=output_context.transport_request, |
| 168 | + response=output_context.transport_response, |
| 169 | + event_type=event_type, |
| 170 | + event_deserializer=event_deserializer, |
| 171 | + context=output_context.properties, |
| 172 | + ) |
| 173 | + return OutputEventStream(output_stream=output_stream, output=output) |
| 174 | + |
| 175 | + async def duplex_stream[ |
| 176 | + I: SerializeableShape, |
| 177 | + O: DeserializeableShape, |
| 178 | + IE: SerializeableShape, |
| 179 | + OE: DeserializeableShape, |
| 180 | + ]( |
| 181 | + self, |
| 182 | + call: ClientCall[I, O], |
| 183 | + input_event_type: type[IE], |
| 184 | + output_event_type: type[OE], |
| 185 | + event_deserializer: Callable[[ShapeDeserializer], OE], |
| 186 | + /, |
| 187 | + ) -> DuplexEventStream[IE, OE, O]: |
| 188 | + """Invoke an input stream operation asynchronously. |
| 189 | +
|
| 190 | + :param call: The operation to invoke and associated context. |
| 191 | + :param input_event_type: The event type to send in the input stream. |
| 192 | + :param output_event_type: The event type to receive in the output stream. |
| 193 | + :param event_deserializer: The method used to deserialize events. |
| 194 | + """ |
| 195 | + request_future = Future[RequestContext[I, TRequest]]() |
| 196 | + execute_task = asyncio.create_task(self._execute_request(call, request_future)) |
| 197 | + request_context = await request_future |
| 198 | + input_stream = self.protocol.create_event_publisher( |
| 199 | + operation=call.operation, |
| 200 | + request=request_context.transport_request, |
| 201 | + event_type=input_event_type, |
| 202 | + context=request_context.properties, |
| 203 | + auth_scheme=request_context.properties.get(AUTH_SCHEME), |
| 204 | + ) |
| 205 | + output_future = asyncio.create_task( |
| 206 | + self._await_output_stream( |
| 207 | + call=call, |
| 208 | + execute_task=execute_task, |
| 209 | + output_event_type=output_event_type, |
| 210 | + event_deserializer=event_deserializer, |
| 211 | + ) |
| 212 | + ) |
| 213 | + return DuplexEventStream(input_stream=input_stream, output_future=output_future) |
| 214 | + |
| 215 | + async def _await_output_stream[ |
| 216 | + I: SerializeableShape, |
| 217 | + O: DeserializeableShape, |
| 218 | + OE: DeserializeableShape, |
| 219 | + ]( |
| 220 | + self, |
| 221 | + call: ClientCall[I, O], |
| 222 | + execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest, TResponse]]], |
| 223 | + output_event_type: type[OE], |
| 224 | + event_deserializer: Callable[[ShapeDeserializer], OE], |
| 225 | + ) -> tuple[O, EventReceiver[OE]]: |
| 226 | + output, output_context = await execute_task |
| 227 | + output_stream = self.protocol.create_event_receiver( |
| 228 | + operation=call.operation, |
| 229 | + request=output_context.transport_request, |
| 230 | + response=output_context.transport_response, |
| 231 | + event_type=output_event_type, |
| 232 | + event_deserializer=event_deserializer, |
| 233 | + context=output_context.properties, |
| 234 | + ) |
| 235 | + return output, output_stream |
| 236 | + |
| 237 | + async def _execute_request[I: SerializeableShape, O: DeserializeableShape]( |
| 238 | + self, |
| 239 | + call: ClientCall[I, O], |
| 240 | + request_future: Future[RequestContext[I, TRequest]] | None, |
| 241 | + ) -> tuple[O, OutputContext[I, O, TRequest, TResponse]]: |
| 242 | + _LOGGER.debug( |
| 243 | + 'Making request for operation "%s" with parameters: %s', |
| 244 | + call.operation.schema.id.name, |
| 245 | + call.input, |
| 246 | + ) |
| 247 | + output_context = await self._handle_execution(call, request_future) |
| 248 | + output_context = self._finalize_execution(call, output_context) |
| 249 | + |
| 250 | + if isinstance(output_context.response, Exception): |
| 251 | + e = output_context.response |
| 252 | + if not isinstance(e, SmithyError): |
| 253 | + raise SmithyError(e) from e |
| 254 | + raise e |
| 255 | + |
| 256 | + return output_context.response, output_context # type: ignore |
| 257 | + |
| 258 | + async def _handle_execution[I: SerializeableShape, O: DeserializeableShape]( |
| 259 | + self, |
| 260 | + call: ClientCall[I, O], |
| 261 | + request_future: Future[RequestContext[I, TRequest]] | None, |
| 262 | + ) -> OutputContext[I, O, TRequest | None, TResponse | None]: |
| 263 | + try: |
| 264 | + interceptor = call.interceptor |
| 265 | + |
| 266 | + input_context = InputContext(request=call.input, properties=call.context) |
| 267 | + interceptor.read_before_execution(input_context) |
| 268 | + |
| 269 | + input_context = replace( |
| 270 | + input_context, |
| 271 | + request=interceptor.modify_before_serialization(input_context), |
| 272 | + ) |
| 273 | + |
| 274 | + interceptor.read_before_serialization(input_context) |
| 275 | + _LOGGER.debug("Serializing request for: %s", input_context.request) |
| 276 | + |
| 277 | + transport_request = self.protocol.serialize_request( |
| 278 | + operation=call.operation, |
| 279 | + input=call.input, |
| 280 | + endpoint=_UNRESOLVED, |
| 281 | + context=input_context.properties, |
| 282 | + ) |
| 283 | + request_context = RequestContext( |
| 284 | + request=input_context.request, |
| 285 | + transport_request=transport_request, |
| 286 | + properties=input_context.properties, |
| 287 | + ) |
| 288 | + |
| 289 | + _LOGGER.debug( |
| 290 | + "Serialization complete. Transport request: %s", transport_request |
| 291 | + ) |
| 292 | + except Exception as e: |
| 293 | + return OutputContext( |
| 294 | + request=call.input, |
| 295 | + response=e, |
| 296 | + transport_request=None, |
| 297 | + transport_response=None, |
| 298 | + properties=call.context, |
| 299 | + ) |
| 300 | + |
| 301 | + try: |
| 302 | + interceptor.read_after_serialization(request_context) |
| 303 | + request_context = replace( |
| 304 | + request_context, |
| 305 | + transport_request=interceptor.modify_before_retry_loop(request_context), |
| 306 | + ) |
| 307 | + |
| 308 | + return await self._retry(call, request_context, request_future) |
| 309 | + except Exception as e: |
| 310 | + return OutputContext( |
| 311 | + request=request_context.request, |
| 312 | + response=e, |
| 313 | + transport_request=request_context.transport_request, |
| 314 | + transport_response=None, |
| 315 | + properties=request_context.properties, |
| 316 | + ) |
| 317 | + |
| 318 | + async def _retry[I: SerializeableShape, O: DeserializeableShape]( |
| 319 | + self, |
| 320 | + call: ClientCall[I, O], |
| 321 | + request_context: RequestContext[I, TRequest], |
| 322 | + request_future: Future[RequestContext[I, TRequest]] | None, |
| 323 | + ) -> OutputContext[I, O, TRequest | None, TResponse | None]: |
| 324 | + if not call.retryable(): |
| 325 | + return await self._handle_attempt(call, request_context, request_future) |
| 326 | + |
| 327 | + retry_strategy = call.retry_strategy |
| 328 | + retry_token = retry_strategy.acquire_initial_retry_token( |
| 329 | + token_scope=call.retry_scope |
| 330 | + ) |
| 331 | + |
| 332 | + while True: |
| 333 | + if retry_token.retry_delay: |
| 334 | + await sleep(retry_token.retry_delay) |
| 335 | + |
| 336 | + output_context = await self._handle_attempt( |
| 337 | + call, |
| 338 | + replace( |
| 339 | + request_context, |
| 340 | + transport_request=copy(request_context.transport_request), |
| 341 | + ), |
| 342 | + request_future, |
| 343 | + ) |
| 344 | + |
| 345 | + if isinstance(output_context.response, Exception): |
| 346 | + try: |
| 347 | + retry_strategy.refresh_retry_token_for_retry( |
| 348 | + token_to_renew=retry_token, |
| 349 | + error=output_context.response, |
| 350 | + ) |
| 351 | + except RetryError: |
| 352 | + raise output_context.response |
| 353 | + |
| 354 | + _LOGGER.debug( |
| 355 | + "Retry needed. Attempting request #%s in %.4f seconds.", |
| 356 | + retry_token.retry_count + 1, |
| 357 | + retry_token.retry_delay, |
| 358 | + ) |
| 359 | + |
| 360 | + await seek(request_context.transport_request.body, 0) |
| 361 | + else: |
| 362 | + retry_strategy.record_success(token=retry_token) |
| 363 | + return output_context |
| 364 | + |
| 365 | + async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape]( |
| 366 | + self, |
| 367 | + call: ClientCall[I, O], |
| 368 | + request_context: RequestContext[I, TRequest], |
| 369 | + request_future: Future[RequestContext[I, TRequest]] | None, |
| 370 | + ) -> OutputContext[I, O, TRequest, TResponse | None]: |
| 371 | + output_context: OutputContext[I, O, TRequest, TResponse | None] |
| 372 | + try: |
| 373 | + interceptor = call.interceptor |
| 374 | + interceptor.read_before_attempt(request_context) |
| 375 | + |
| 376 | + endpoint_params = EndpointResolverParams( |
| 377 | + operation=call.operation, |
| 378 | + input=call.input, |
| 379 | + context=request_context.properties, |
| 380 | + ) |
| 381 | + _LOGGER.debug("Calling endpoint resolver.") |
| 382 | + endpoint: Endpoint = await call.endpoint_resolver.resolve_endpoint( |
| 383 | + endpoint_params |
| 384 | + ) |
| 385 | + _LOGGER.debug("Endpoint resolver result: %s", endpoint) |
| 386 | + |
| 387 | + request_context = replace( |
| 388 | + request_context, |
| 389 | + transport_request=self.protocol.set_service_endpoint( |
| 390 | + request=request_context.transport_request, endpoint=endpoint |
| 391 | + ), |
| 392 | + ) |
| 393 | + |
| 394 | + request_context = replace( |
| 395 | + request_context, |
| 396 | + transport_request=interceptor.modify_before_signing(request_context), |
| 397 | + ) |
| 398 | + interceptor.read_before_signing(request_context) |
| 399 | + |
| 400 | + auth_params = AuthParams[I, O]( |
| 401 | + protocol_id=self.protocol.id, |
| 402 | + operation=call.operation, |
| 403 | + context=request_context.properties, |
| 404 | + ) |
| 405 | + auth = self._resolve_auth(call, auth_params) |
| 406 | + if auth is not None: |
| 407 | + option, scheme = auth |
| 408 | + request_context.properties[AUTH_SCHEME] = scheme |
| 409 | + identity_resolver = scheme.identity_resolver(context=call.context) |
| 410 | + identity = await identity_resolver.get_identity( |
| 411 | + properties=option.identity_properties |
| 412 | + ) |
| 413 | + |
| 414 | + _LOGGER.debug("Request to sign: %s", request_context.transport_request) |
| 415 | + _LOGGER.debug("Signer properties: %s", option.signer_properties) |
| 416 | + |
| 417 | + signer = scheme.signer() |
| 418 | + request_context = replace( |
| 419 | + request_context, |
| 420 | + transport_request=await signer.sign( |
| 421 | + request=request_context.transport_request, |
| 422 | + identity=identity, |
| 423 | + properties=option.signer_properties, |
| 424 | + ), |
| 425 | + ) |
| 426 | + |
| 427 | + interceptor.read_after_signing(request_context) |
| 428 | + request_context = replace( |
| 429 | + request_context, |
| 430 | + transport_request=interceptor.modify_before_transmit(request_context), |
| 431 | + ) |
| 432 | + interceptor.read_before_transmit(request_context) |
| 433 | + |
| 434 | + _LOGGER.debug("Sending request %s", request_context.transport_request) |
| 435 | + |
| 436 | + if request_future is not None: |
| 437 | + # If we have an input event stream (or duplex event stream) then we |
| 438 | + # need to let the client return ASAP so that it can start sending |
| 439 | + # events. So here we start the transport send in a background task |
| 440 | + # then set the result of the request future. It's important to sequence |
| 441 | + # it just like that so that the client gets a stream that's ready |
| 442 | + # to send. |
| 443 | + transport_task = asyncio.create_task( |
| 444 | + self.transport.send(request=request_context.transport_request) |
| 445 | + ) |
| 446 | + request_future.set_result(request_context) |
| 447 | + transport_response = await transport_task |
| 448 | + else: |
| 449 | + # If we don't have an input stream, there's no point in creating a |
| 450 | + # task, so we just immediately await the coroutine. |
| 451 | + transport_response = await self.transport.send( |
| 452 | + request=request_context.transport_request |
| 453 | + ) |
| 454 | + |
| 455 | + _LOGGER.debug("Received response: %s", transport_response) |
| 456 | + |
| 457 | + response_context = ResponseContext( |
| 458 | + request=request_context.request, |
| 459 | + transport_request=request_context.transport_request, |
| 460 | + transport_response=transport_response, |
| 461 | + properties=request_context.properties, |
| 462 | + ) |
| 463 | + |
| 464 | + interceptor.read_after_transmit(response_context) |
| 465 | + |
| 466 | + response_context = replace( |
| 467 | + response_context, |
| 468 | + transport_response=interceptor.modify_before_deserialization( |
| 469 | + response_context |
| 470 | + ), |
| 471 | + ) |
| 472 | + |
| 473 | + interceptor.read_before_deserialization(response_context) |
| 474 | + |
| 475 | + _LOGGER.debug( |
| 476 | + "Deserializing response: %s", response_context.transport_response |
| 477 | + ) |
| 478 | + |
| 479 | + output = await self.protocol.deserialize_response( |
| 480 | + operation=call.operation, |
| 481 | + request=response_context.transport_request, |
| 482 | + response=response_context.transport_response, |
| 483 | + error_registry=call.operation.error_registry, |
| 484 | + context=response_context.properties, |
| 485 | + ) |
| 486 | + |
| 487 | + _LOGGER.debug("Deserialization complete. Output: %s", output) |
| 488 | + |
| 489 | + output_context = OutputContext( |
| 490 | + request=response_context.request, |
| 491 | + response=output, |
| 492 | + transport_request=response_context.transport_request, |
| 493 | + transport_response=response_context.transport_response, |
| 494 | + properties=response_context.properties, |
| 495 | + ) |
| 496 | + |
| 497 | + interceptor.read_after_deserialization(output_context) |
| 498 | + except Exception as e: |
| 499 | + output_context = OutputContext( |
| 500 | + request=request_context.request, |
| 501 | + response=e, |
| 502 | + transport_request=request_context.transport_request, |
| 503 | + transport_response=None, |
| 504 | + properties=request_context.properties, |
| 505 | + ) |
| 506 | + |
| 507 | + return self._finalize_attempt(call, output_context) |
| 508 | + |
| 509 | + def _resolve_auth[I: SerializeableShape, O: DeserializeableShape]( |
| 510 | + self, call: ClientCall[Any, Any], params: AuthParams[I, O] |
| 511 | + ) -> tuple[AuthOption, AuthScheme[TRequest, Any, Any, Any]] | None: |
| 512 | + auth_options: Sequence[AuthOption] = ( |
| 513 | + call.auth_scheme_resolver.resolve_auth_scheme(auth_parameters=params) |
| 514 | + ) |
| 515 | + |
| 516 | + for option in auth_options: |
| 517 | + if ( |
| 518 | + scheme := call.supported_auth_schemes.get(option.scheme_id) |
| 519 | + ) is not None: |
| 520 | + return option, scheme |
| 521 | + |
| 522 | + return None |
| 523 | + |
| 524 | + def _finalize_attempt[I: SerializeableShape, O: DeserializeableShape]( |
| 525 | + self, |
| 526 | + call: ClientCall[I, O], |
| 527 | + output_context: OutputContext[I, O, TRequest, TResponse | None], |
| 528 | + ) -> OutputContext[I, O, TRequest, TResponse | None]: |
| 529 | + interceptor = call.interceptor |
| 530 | + try: |
| 531 | + output_context = replace( |
| 532 | + output_context, |
| 533 | + response=interceptor.modify_before_attempt_completion(output_context), |
| 534 | + ) |
| 535 | + except Exception as e: |
| 536 | + output_context = replace(output_context, response=e) |
| 537 | + |
| 538 | + try: |
| 539 | + interceptor.read_after_attempt(output_context) |
| 540 | + except Exception as e: |
| 541 | + output_context = replace(output_context, response=e) |
| 542 | + |
| 543 | + return output_context |
| 544 | + |
| 545 | + def _finalize_execution[I: SerializeableShape, O: DeserializeableShape]( |
| 546 | + self, |
| 547 | + call: ClientCall[I, O], |
| 548 | + output_context: OutputContext[I, O, TRequest | None, TResponse | None], |
| 549 | + ) -> OutputContext[I, O, TRequest | None, TResponse | None]: |
| 550 | + interceptor = call.interceptor |
| 551 | + try: |
| 552 | + output_context = replace( |
| 553 | + output_context, |
| 554 | + response=interceptor.modify_before_completion(output_context), |
| 555 | + ) |
| 556 | + |
| 557 | + # TODO trace probe |
| 558 | + except Exception as e: |
| 559 | + output_context = replace(output_context, response=e) |
| 560 | + |
| 561 | + try: |
| 562 | + interceptor.read_after_execution(output_context) |
| 563 | + except Exception as e: |
| 564 | + output_context = replace(output_context, response=e) |
| 565 | + |
| 566 | + return output_context |
0 commit comments