Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions openhands-sdk/openhands/sdk/utils/async_executor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Reusable async-to-sync execution utility."""

import asyncio
import concurrent.futures
import inspect
import threading
import time
from collections.abc import Callable
from typing import Any

from openhands.sdk.logger import get_logger

logger = get_logger(__name__)

class AsyncExecutor:
"""
Expand All @@ -21,12 +26,21 @@ def __init__(self):
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._lock = threading.Lock()
self._shutdown = threading.Event()

def _ensure_loop(self) -> asyncio.AbstractEventLoop:
def _safe_execute_on_loop(self, callback: Callable[[asyncio.AbstractEventLoop], Any]) -> Any:
"""Ensure the background event loop is running."""
with self._lock:
if self._shutdown.is_set():
raise RuntimeError("AsyncExecutor has been shut down")

if self._loop is not None:
return self._loop
if self._loop.is_running():
return callback(self._loop)

logger.warning("The loop is not empty, but it is not in a running state. "
"Under normal circumstances, this should not happen.")
self._loop.close()

loop = asyncio.new_event_loop()

Expand All @@ -39,18 +53,22 @@ def _runner():

# Wait for loop to start
while not loop.is_running():
pass
time.sleep(0.01)

self._loop = loop
self._thread = t
return loop
return callback(self._loop)

def _shutdown_loop(self) -> None:
"""Shutdown the background event loop."""
if self._shutdown.is_set():
logger.info("AsyncExecutor has been shutdown")
return
with self._lock:
loop, t = self._loop, self._thread
self._loop = None
self._thread = None
self._shutdown.set()

if loop and loop.is_running():
try:
Expand All @@ -59,6 +77,8 @@ def _shutdown_loop(self) -> None:
pass
if t and t.is_alive():
t.join(timeout=1.0)
if t.is_alive():
logger.warning("AsyncExecutor thread did not terminate gracefully")

def run_async(
self,
Expand All @@ -83,15 +103,18 @@ def run_async(
TypeError: If awaitable_or_fn is not a coroutine or async function
asyncio.TimeoutError: If the operation times out
"""
if self._shutdown.is_set():
raise RuntimeError("AsyncExecutor has been shut down")
if inspect.iscoroutine(awaitable_or_fn):
coro = awaitable_or_fn
elif inspect.iscoroutinefunction(awaitable_or_fn):
coro = awaitable_or_fn(*args, **kwargs)
else:
raise TypeError("run_async expects a coroutine or async function")
def submit_task(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future[Any]:
return asyncio.run_coroutine_threadsafe(coro, loop)

loop = self._ensure_loop()
fut = asyncio.run_coroutine_threadsafe(coro, loop)
fut = self._safe_execute_on_loop(submit_task)
return fut.result(timeout)

def close(self):
Expand Down
Loading