Skip to content

Commit 01f4363

Browse files
committed
Minor changes in MapAsyncIterator and its test
1 parent 49e7124 commit 01f4363

File tree

2 files changed

+62
-60
lines changed

2 files changed

+62
-60
lines changed

graphql/subscription/map_async_iterator.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
__all__ = ['MapAsyncIterator']
77

88

9+
# noinspection PyAttributeOutsideInit
910
class MapAsyncIterator:
1011
"""Map an AsyncIterable over a callback function.
1112
@@ -23,75 +24,72 @@ def __init__(self, iterable: AsyncIterable, callback: Callable,
2324
self.reject_callback = reject_callback
2425
self._close_event = Event()
2526

26-
@property
27-
def closed(self) -> bool:
28-
return self._close_event.is_set()
29-
30-
@closed.setter
31-
def closed(self, value: bool) -> None:
32-
if value:
33-
self._close_event.set()
34-
else:
35-
self._close_event.clear()
36-
3727
def __aiter__(self):
3828
return self
3929

4030
async def __anext__(self):
41-
if self.closed:
31+
if self.is_closed:
4232
if not isasyncgen(self.iterator):
4333
raise StopAsyncIteration
44-
result = await self.iterator.__anext__()
45-
return self.callback(result)
34+
value = await self.iterator.__anext__()
35+
result = self.callback(value)
4636

47-
_close = ensure_future(self._close_event.wait())
48-
_next = ensure_future(self.iterator.__anext__())
49-
done, pending = await wait(
50-
[_close, _next],
51-
return_when=FIRST_COMPLETED,
52-
)
37+
else:
38+
aclose = ensure_future(self._close_event.wait())
39+
anext = ensure_future(self.iterator.__anext__())
5340

54-
for task in pending:
55-
task.cancel()
41+
done, pending = await wait(
42+
[aclose, anext], return_when=FIRST_COMPLETED)
43+
for task in pending:
44+
task.cancel()
5645

57-
if _close.done():
58-
raise StopAsyncIteration
46+
if aclose.done():
47+
raise StopAsyncIteration
5948

60-
if _next.done():
61-
error = _next.exception()
49+
error = anext.exception()
6250
if error:
6351
if not self.reject_callback or isinstance(error, (
6452
StopAsyncIteration, GeneratorExit)):
6553
raise error
6654
result = self.reject_callback(error)
6755
else:
68-
result = self.callback(_next.result())
56+
value = anext.result()
57+
result = self.callback(value)
6958

70-
return (await result) if isawaitable(result) else result
59+
return await result if isawaitable(result) else result
7160

7261
async def athrow(self, type_, value=None, traceback=None):
73-
if self.closed:
74-
return
75-
athrow = getattr(self.iterator, 'athrow', None)
76-
if athrow:
77-
await athrow(type_, value, traceback)
78-
else:
79-
self.closed = True
80-
if value is None:
81-
if traceback is None:
82-
raise type_
83-
value = type_()
84-
if traceback is not None:
85-
value = value.with_traceback(traceback)
86-
raise value
62+
if not self.is_closed:
63+
athrow = getattr(self.iterator, 'athrow', None)
64+
if athrow:
65+
await athrow(type_, value, traceback)
66+
else:
67+
self.is_closed = True
68+
if value is None:
69+
if traceback is None:
70+
raise type_
71+
value = type_()
72+
if traceback is not None:
73+
value = value.with_traceback(traceback)
74+
raise value
8775

8876
async def aclose(self):
89-
if self.closed:
90-
return
91-
aclose = getattr(self.iterator, 'aclose', None)
92-
if aclose:
93-
try:
94-
await aclose()
95-
except RuntimeError:
96-
pass
97-
self.closed = True
77+
if not self.is_closed:
78+
aclose = getattr(self.iterator, 'aclose', None)
79+
if aclose:
80+
try:
81+
await aclose()
82+
except RuntimeError:
83+
pass
84+
self.is_closed = True
85+
86+
@property
87+
def is_closed(self) -> bool:
88+
return self._close_event.is_set()
89+
90+
@is_closed.setter
91+
def is_closed(self, value: bool) -> None:
92+
if value:
93+
self._close_event.set()
94+
else:
95+
self._close_event.clear()

tests/subscription/test_map_async_iterator.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,22 @@ async def source():
257257
yield 2
258258
yield 3
259259

260-
doubles = MapAsyncIterator(source(), lambda x: x * 2)
260+
singles = source()
261+
doubles = MapAsyncIterator(singles, lambda x: x * 2)
261262

262263
result = await anext(doubles)
263264
assert result == 2
264265

265-
# Block at event.wait()
266-
fut = ensure_future(anext(doubles))
267-
await sleep(.01)
268-
assert not fut.done()
266+
# Make sure it is blocked
267+
doubles_future = ensure_future(anext(doubles))
268+
await sleep(.05)
269+
assert not doubles_future.done()
269270

270-
# Trigger cancellation and watch StopAsyncIteration propogate
271+
# Unblock and watch StopAsyncIteration propagate
271272
await doubles.aclose()
272-
await sleep(.01)
273-
assert fut.done()
274-
assert isinstance(fut.exception(), StopAsyncIteration)
273+
await sleep(.05)
274+
assert doubles_future.done()
275+
assert isinstance(doubles_future.exception(), StopAsyncIteration)
276+
277+
with raises(StopAsyncIteration):
278+
await anext(singles)

0 commit comments

Comments
 (0)