
Commit 1c84fbe

refactor(incremental): introduce BoxedPromiseOrValue to save awaited results
Replicates graphql/graphql-js@062785e
1 parent db0237f commit 1c84fbe

9 files changed: +270 -183 lines changed

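The diffs below wrap results in a new BoxedAwaitableOrValue container from graphql.pyutils. The module that defines it (presumably src/graphql/pyutils/boxed_awaitable_or_value.py, judging by the import added in the __init__.py diff at the bottom) is among the 9 changed files but not shown in this excerpt, so the following is only a minimal sketch of what the wrapper looks like, inferred from how the diffs use it: it is constructed around an awaitable or a plain value and exposes a mutable value attribute that can later be overwritten with the awaited result.

from __future__ import annotations

from typing import Awaitable, Generic, TypeVar, Union

T = TypeVar("T")


class BoxedAwaitableOrValue(Generic[T]):
    """Sketch: box holding either an awaitable or an already resolved value.

    Because the box is mutable, code that awaits the value once can write the
    resolved result back into the box ("save awaited results"), so later
    readers see the plain value instead of a pending awaitable.
    """

    value: Union[Awaitable[T], T]

    __slots__ = ("value",)

    def __init__(self, value: Union[Awaitable[T], T]) -> None:
        self.value = value

This presumably mirrors the BoxedPromiseOrValue helper named in the commit title, which plays the same role on the graphql-js side.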

src/graphql/execution/execute.py

Lines changed: 76 additions & 46 deletions
@@ -6,6 +6,7 @@
     CancelledError,
     TimeoutError,  # only needed for Python < 3.11  # noqa: A004
     ensure_future,
+    sleep,
 )
 from contextlib import suppress
 from copy import copy
@@ -40,6 +41,7 @@
 )
 from ..pyutils import (
     AwaitableOrValue,
+    BoxedAwaitableOrValue,
     Path,
     RefMap,
     Undefined,
@@ -1627,7 +1629,8 @@ def execute_deferred_grouped_field_sets(
         )

         deferred_grouped_field_set_record = DeferredGroupedFieldSetRecord(
-            deferred_fragment_records, cast("DeferredGroupedFieldSetResult", None)
+            deferred_fragment_records,
+            cast("BoxedAwaitableOrValue[DeferredGroupedFieldSetResult]", None),
         )

         if should_defer(parent_defer_usages, defer_usage_set):
@@ -1650,10 +1653,12 @@ async def executor(
                     return await result
                 return cast("DeferredGroupedFieldSetResult", result)

-            deferred_grouped_field_set_record.result = executor(
-                deferred_grouped_field_set_record,
-                grouped_field_set,
-                defer_usage_set,
+            deferred_grouped_field_set_record.result = BoxedAwaitableOrValue(
+                executor(
+                    deferred_grouped_field_set_record,
+                    grouped_field_set,
+                    defer_usage_set,
+                )
             )
         else:
             executed = self.execute_deferred_grouped_field_set(
@@ -1665,7 +1670,9 @@ async def executor(
                 IncrementalContext(defer_usage_set),
                 defer_map,
             )
-            deferred_grouped_field_set_record.result = executed
+            deferred_grouped_field_set_record.result = BoxedAwaitableOrValue(
+                executed
+            )

         append_record(deferred_grouped_field_set_record)

@@ -1743,59 +1750,71 @@ async def await_result() -> StreamItemsResult:
             path = stream_record.path
             initial_path = Path(path, initial_index, None)

-            result = self.complete_stream_items(
-                stream_record,
-                initial_path,
-                initial_item,
-                IncrementalContext(),
-                field_group,
-                info,
-                item_type,
+            result: BoxedAwaitableOrValue[StreamItemsResult] = BoxedAwaitableOrValue(
+                self.complete_stream_items(
+                    stream_record,
+                    initial_path,
+                    initial_item,
+                    IncrementalContext(),
+                    field_group,
+                    info,
+                    item_type,
+                )
             )
             first_stream_items = StreamItemsRecord(stream_record, result)
             current_stream_items = first_stream_items
             current_index = initial_index
             errored_synchronously = False
             for item in iterator:
-                if not is_awaitable(result) and not is_reconcilable_stream_items_result(
-                    result  # type: ignore
+                value = result.value
+                if not is_awaitable(value) and not is_reconcilable_stream_items_result(
+                    value
                 ):
                     errored_synchronously = True
                     break
                 current_index += 1
                 current_path = Path(path, current_index, None)
-                result = self.complete_stream_items(
-                    stream_record,
-                    current_path,
-                    item,
-                    IncrementalContext(),
-                    field_group,
-                    info,
-                    item_type,
+                result = BoxedAwaitableOrValue(
+                    self.complete_stream_items(
+                        stream_record,
+                        current_path,
+                        item,
+                        IncrementalContext(),
+                        field_group,
+                        info,
+                        item_type,
+                    )
                 )

                 next_stream_items = StreamItemsRecord(stream_record, result)
-                current_stream_items.result = prepend_next_stream_items(
-                    current_stream_items.result, next_stream_items
+                current_stream_items.result = BoxedAwaitableOrValue(
+                    prepend_next_stream_items(
+                        current_stream_items.result.value, next_stream_items
+                    )
                 )
                 current_stream_items = next_stream_items

             # If a non-reconcilable stream items result was encountered,
             # then the stream terminates in error. Otherwise, add a stream terminator.
             if not errored_synchronously:
-                current_stream_items.result = prepend_next_stream_items(
-                    current_stream_items.result,
-                    StreamItemsRecord(
-                        stream_record, TerminatingStreamItemsResult(stream_record)
-                    ),
+                current_stream_items.result = BoxedAwaitableOrValue(
+                    prepend_next_stream_items(
+                        current_stream_items.result.value,
+                        StreamItemsRecord(
+                            stream_record,
+                            BoxedAwaitableOrValue(
+                                TerminatingStreamItemsResult(stream_record),
+                            ),
+                        ),
+                    )
                 )

-            result = first_stream_items.result
-            if is_awaitable(result):
-                return await result
-            return cast("StreamItemsResult", result)
+            value = first_stream_items.result.value
+            if is_awaitable(value):
+                return await value
+            return value

-        return StreamItemsRecord(stream_record, await_result())
+        return StreamItemsRecord(stream_record, BoxedAwaitableOrValue(await_result()))

     def first_async_stream_items(
         self,
@@ -1810,14 +1829,16 @@ def first_async_stream_items(
         """Get the first async stream items."""
         return StreamItemsRecord(
             stream_record,
-            self.get_next_async_stream_items_result(
-                stream_record,
-                path,
-                initial_index,
-                async_iterator,
-                field_group,
-                info,
-                item_type,
+            BoxedAwaitableOrValue(
+                self.get_next_async_stream_items_result(
+                    stream_record,
+                    path,
+                    initial_index,
+                    async_iterator,
+                    field_group,
+                    info,
+                    item_type,
+                )
             ),
         )

@@ -1856,14 +1877,23 @@ async def get_next_async_stream_items_result(

         next_stream_items_record = StreamItemsRecord(
             stream_record,
-            self.get_next_async_stream_items_result(
-                stream_record, path, index, async_iterator, field_group, info, item_type
+            BoxedAwaitableOrValue(
+                self.get_next_async_stream_items_result(
+                    stream_record,
+                    path,
+                    index,
+                    async_iterator,
+                    field_group,
+                    info,
+                    item_type,
+                )
             ),
         )

         result = self.prepend_next_stream_items(result, next_stream_items_record)

         if self.is_awaitable(result):
+            await sleep(0)
             return await result
         return cast("StreamItemsResult", result)

src/graphql/execution/incremental_graph.py

Lines changed: 19 additions & 11 deletions
@@ -2,7 +2,14 @@

 from __future__ import annotations

-from asyncio import CancelledError, Future, Task, ensure_future
+from asyncio import (
+    CancelledError,
+    Future,
+    Task,
+    ensure_future,
+    get_running_loop,
+    isfuture,
+)
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -20,8 +27,6 @@
     is_deferred_grouped_field_set_record,
 )

-from ..pyutils.is_awaitable import is_awaitable
-
 if TYPE_CHECKING:
     from graphql.execution.types import (
         DeferredFragmentRecord,
@@ -163,17 +168,17 @@ def get_new_pending(self) -> list[SubsequentResultRecord]:

         enqueue = self._enqueue
         for incremental_data_record in _new_incremental_data_records:
-            result = incremental_data_record.result
-            if is_awaitable(result):
+            value = incremental_data_record.result.value
+            if isfuture(value):

-                async def enqueue_incremental(
-                    result: Awaitable[IncrementalDataRecordResult],
+                async def enqueue_later(
+                    value: Awaitable[IncrementalDataRecordResult],
                 ) -> None:
-                    enqueue(await result)
+                    enqueue(await value)

-                self._add_task(enqueue_incremental(result))
+                self._add_task(enqueue_later(value))
             else:
-                enqueue(result)  # type: ignore
+                enqueue(value)
         _new_incremental_data_records.clear()

         return new_pending
@@ -182,12 +187,15 @@ async def completed_incremental_data(
         self,
     ) -> AsyncGenerator[Iterable[IncrementalDataRecordResult], None]:
         """Asynchronously yield completed incremental data record results."""
+        loop = get_running_loop()
         while True:
             if self._completed_queue:
                 first_result = self._completed_queue.pop(0)
                 yield self._yield_current_completed_incremental_data(first_result)
             else:
-                future: Future[Iterable[IncrementalDataRecordResult]] = Future()
+                future: Future[Iterable[IncrementalDataRecordResult]] = (
+                    loop.create_future()
+                )
                 self._next_queue.append(future)
                 try:
                     yield await future
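In the loop above, get_new_pending now reads incremental_data_record.result.value and checks isfuture(value) rather than testing the record's result directly with is_awaitable. The point of the box, per the commit title, is that a result which has already been awaited can be stored back as a plain value and read cheaply afterwards. A hypothetical illustration of that pattern follows; the resolve_boxed helper is not part of graphql-core, and the import assumes a graphql-core build that already contains this commit.

import asyncio
from inspect import isawaitable

from graphql.pyutils import BoxedAwaitableOrValue  # assumes this commit is included


async def resolve_boxed(boxed: BoxedAwaitableOrValue) -> object:
    """Await the boxed value once and save the result back into the box."""
    value = boxed.value
    if isawaitable(value):
        value = await value
        boxed.value = value  # later readers now get the plain value
    return value


async def main() -> None:
    boxed = BoxedAwaitableOrValue(asyncio.sleep(0, result=42))
    print(await resolve_boxed(boxed))  # awaits the coroutine -> 42
    print(await resolve_boxed(boxed))  # reads the saved value, no second await


asyncio.run(main())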

src/graphql/execution/incremental_publisher.py

Lines changed: 2 additions & 2 deletions
@@ -156,7 +156,7 @@ async def _subscribe(
             for completed_result in completed_results:
                 await handle_completed_incremental_data(completed_result, context)

-            if context.incremental or context.completed:
+            if context.incremental or context.completed:  # pragma: no branch
                 has_next = check_has_next()

             if not has_next:
@@ -192,6 +192,7 @@ async def _handle_completed_incremental_data(
         completed_incremental_data: IncrementalDataRecordResult,
         context: SubsequentIncrementalExecutionResultContext,
     ) -> None:
+        """Handle completed incremental data."""
         if is_deferred_grouped_field_set_result(completed_incremental_data):
             self._handle_completed_deferred_grouped_field_set(
                 completed_incremental_data, context
@@ -203,7 +204,6 @@
             await self._handle_completed_stream_items(
                 completed_incremental_data, context
             )
-
         new_pending = self._incremental_graph.get_new_pending()
         context.pending.extend(self._pending_sources_to_results(new_pending))

src/graphql/execution/types.py

Lines changed: 5 additions & 5 deletions
@@ -19,7 +19,7 @@

 if TYPE_CHECKING:
     from ..error import GraphQLError, GraphQLFormattedError
-    from ..pyutils import AwaitableOrValue, Path
+    from ..pyutils import BoxedAwaitableOrValue, Path

 try:
     from typing import TypeGuard
@@ -783,14 +783,14 @@ class DeferredGroupedFieldSetRecord:
     """Deferred grouped field set record"""

     deferred_fragment_records: list[DeferredFragmentRecord]
-    result: AwaitableOrValue[DeferredGroupedFieldSetResult]
+    result: BoxedAwaitableOrValue[DeferredGroupedFieldSetResult]

     __slots__ = "deferred_fragment_records", "result"

     def __init__(
         self,
         deferred_fragment_records: list[DeferredFragmentRecord],
-        result: AwaitableOrValue[DeferredGroupedFieldSetResult],
+        result: BoxedAwaitableOrValue[DeferredGroupedFieldSetResult],
     ) -> None:
         self.result = result
         self.deferred_fragment_records = deferred_fragment_records
@@ -923,12 +923,12 @@ class StreamItemsRecord:
     __slots__ = "result", "stream_record"

     stream_record: StreamRecord
-    result: AwaitableOrValue[StreamItemsResult]
+    result: BoxedAwaitableOrValue[StreamItemsResult]

     def __init__(
         self,
         stream_record: StreamRecord,
-        result: AwaitableOrValue[StreamItemsResult],
+        result: BoxedAwaitableOrValue[StreamItemsResult],
     ) -> None:
         self.stream_record = stream_record
         self.result = result

src/graphql/pyutils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 """

 from .async_reduce import async_reduce
+from .boxed_awaitable_or_value import BoxedAwaitableOrValue
 from .gather_with_cancel import gather_with_cancel
 from .convert_case import camel_to_snake, snake_to_camel
 from .cached_property import cached_property
@@ -39,6 +40,7 @@

 __all__ = [
     "AwaitableOrValue",
+    "BoxedAwaitableOrValue",
     "Description",
     "FrozenError",
     "Path",
