diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 46c8b4d48..82bd8537d 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -56,6 +56,56 @@ Global statistics .. autoclass:: RunStatistics() +.. _trio_contexts: + +Checking for Trio +----------------- + +If you want to interact with an active Trio run -- perhaps you need to +know the :func:`~trio.current_time` or the +:func:`~trio.lowlevel.current_task` -- then Trio needs to have certain +state available to it or else you will get a +``RuntimeError("must be called from async context")``. +This requires that you either be: + +* indirectly inside (and on the same thread as) a call to + :func:`trio.run`, for run-level information such as the + :func:`~trio.current_time` or :func:`~trio.lowlevel.current_clock`; + or + +* indirectly inside a Trio task, for task-level information such as + the :func:`~trio.lowlevel.current_task` or + :func:`~trio.current_effective_deadline`. + +Internally, this state is provided by thread-local variables tracking +the current run and the current task. Sometimes, it's useful to know +in advance whether a call will fail or to have dynamic information for +safeguards against running something inside or outside Trio. To do so, +call :func:`trio.lowlevel.in_trio_run` or +:func:`trio.lowlevel.in_trio_task`, which will provide answers +according to the following table. + + ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| situation | :func:`trio.lowlevel.in_trio_run` | :func:`trio.lowlevel.in_trio_task` | ++========================================================+===================================+====================================+ +| inside a Trio-flavored async function | `True` | `True` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a thread without an active call to :func:`trio.run` | `False` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a guest run's host loop | `True` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| inside an instrument call | `True` | depends | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| in a thread created by :func:`trio.to_thread.run_sync` | `False` | `False` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ +| inside an abort function | `True` | `True` | ++--------------------------------------------------------+-----------------------------------+------------------------------------+ + +.. autofunction:: in_trio_run + +.. autofunction:: in_trio_task + The current clock ----------------- diff --git a/newsfragments/2757.feature.rst b/newsfragments/2757.feature.rst new file mode 100644 index 000000000..317299561 --- /dev/null +++ b/newsfragments/2757.feature.rst @@ -0,0 +1 @@ +Add :func:`trio.lowlevel.in_trio_run` and :func:`trio.lowlevel.in_trio_task` and document the semantics (and differences) thereof. See :ref:`the documentation `. diff --git a/src/trio/_core/__init__.py b/src/trio/_core/__init__.py index fdef90292..d21aefb3e 100644 --- a/src/trio/_core/__init__.py +++ b/src/trio/_core/__init__.py @@ -45,6 +45,8 @@ current_task, current_time, current_trio_token, + in_trio_run, + in_trio_task, notify_closing, open_nursery, remove_instrument, diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 0dc3ced5d..eedb99644 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2283,7 +2283,7 @@ def setup_runner( # It wouldn't be *hard* to support nested calls to run(), but I can't # think of a single good reason for it, so let's be conservative for # now: - if hasattr(GLOBAL_RUN_CONTEXT, "runner"): + if in_trio_run(): raise RuntimeError("Attempted to call run() from inside a run()") if clock is None: @@ -2832,8 +2832,9 @@ def unrolled_run( except BaseException as exc: raise TrioInternalError("internal error in Trio - please file a bug!") from exc finally: - GLOBAL_RUN_CONTEXT.__dict__.clear() runner.close() + GLOBAL_RUN_CONTEXT.__dict__.clear() + # Have to do this after runner.close() has disabled KI protection, # because otherwise there's a race where ki_pending could get set # after we check it. @@ -2952,6 +2953,24 @@ async def checkpoint_if_cancelled() -> None: task._cancel_points += 1 +def in_trio_run() -> bool: + """Check whether we are in a Trio run. + This returns `True` if and only if :func:`~trio.current_time` will succeed. + + See also the discussion of differing ways of :ref:`detecting Trio `. + """ + return hasattr(GLOBAL_RUN_CONTEXT, "runner") + + +def in_trio_task() -> bool: + """Check whether we are in a Trio task. + This returns `True` if and only if :func:`~trio.lowlevel.current_task` will succeed. + + See also the discussion of differing ways of :ref:`detecting Trio `. + """ + return hasattr(GLOBAL_RUN_CONTEXT, "task") + + if sys.platform == "win32": from ._generated_io_windows import * from ._io_windows import ( diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index b455175f4..81b7a07d8 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -264,6 +264,26 @@ async def synchronize() -> None: sniffio_library.name = None +def test_guest_mode_trio_context_detection() -> None: + def check(thing: bool) -> None: + assert thing + + assert not trio.lowlevel.in_trio_run() + assert not trio.lowlevel.in_trio_task() + + async def trio_main(in_host: InHost) -> None: + for _ in range(2): + assert trio.lowlevel.in_trio_run() + assert trio.lowlevel.in_trio_task() + + in_host(lambda: check(trio.lowlevel.in_trio_run())) + in_host(lambda: check(not trio.lowlevel.in_trio_task())) + + trivial_guest_run(trio_main) + assert not trio.lowlevel.in_trio_run() + assert not trio.lowlevel.in_trio_task() + + def test_warn_set_wakeup_fd_overwrite() -> None: assert signal.set_wakeup_fd(-1) == -1 diff --git a/src/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py index 220ac9314..60c54307e 100644 --- a/src/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -266,3 +266,50 @@ async def main() -> None: assert "task_exited" not in runner.instruments _core.run(main) + + +def test_instrument_call_trio_context() -> None: + called = set() + + class Instrument(_abc.Instrument): + pass + + hooks = { + # not run in task context + "after_io_wait": (True, False), + "before_io_wait": (True, False), + "before_run": (True, False), + "after_run": (True, False), + # run in task context + "before_task_step": (True, True), + "after_task_step": (True, True), + "task_exited": (True, True), + # depends + "task_scheduled": (True, None), + "task_spawned": (True, None), + } + for hook, val in hooks.items(): + + def h( + self: Instrument, + *args: object, + hook: str = hook, + val: tuple[bool, bool | None] = val, + ) -> None: + fail_str = f"failed in {hook}" + + assert _core.in_trio_run() == val[0], fail_str + if val[1] is not None: + assert _core.in_trio_task() == val[1], fail_str + called.add(hook) + + setattr(Instrument, hook, h) + + async def main() -> None: + await _core.checkpoint() + + async with _core.open_nursery() as nursery: + nursery.start_soon(_core.checkpoint) + + _core.run(main, instruments=[Instrument()]) + assert called == set(hooks) diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 75e5457d7..576b807f9 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -2855,3 +2855,34 @@ def run(self, fn: Callable[[], object]) -> object: with mock.patch("trio._core._run.copy_context", return_value=Context()): assert _count_context_run_tb_frames() == 1 + + +@restore_unraisablehook() +def test_trio_context_detection() -> None: + assert not _core.in_trio_run() + assert not _core.in_trio_task() + + def inner() -> None: + assert _core.in_trio_run() + assert _core.in_trio_task() + + def sync_inner() -> None: + assert not _core.in_trio_run() + assert not _core.in_trio_task() + + def inner_abort(_: object) -> _core.Abort: + assert _core.in_trio_run() + assert _core.in_trio_task() + return _core.Abort.SUCCEEDED + + async def main() -> None: + assert _core.in_trio_run() + assert _core.in_trio_task() + + inner() + + await to_thread_run_sync(sync_inner) + with _core.CancelScope(deadline=_core.current_time() - 1): + await _core.wait_task_rescheduled(inner_abort) + + _core.run(main) diff --git a/src/trio/lowlevel.py b/src/trio/lowlevel.py index 9e385a004..bbeab6af1 100644 --- a/src/trio/lowlevel.py +++ b/src/trio/lowlevel.py @@ -37,6 +37,8 @@ currently_ki_protected as currently_ki_protected, disable_ki_protection as disable_ki_protection, enable_ki_protection as enable_ki_protection, + in_trio_run as in_trio_run, + in_trio_task as in_trio_task, notify_closing as notify_closing, permanently_detach_coroutine_object as permanently_detach_coroutine_object, reattach_detached_coroutine_object as reattach_detached_coroutine_object,