diff --git a/src/obspec_utils/stores/_aiohttp.py b/src/obspec_utils/stores/_aiohttp.py
index 1f19432..fcb9986 100644
--- a/src/obspec_utils/stores/_aiohttp.py
+++ b/src/obspec_utils/stores/_aiohttp.py
@@ -26,10 +26,11 @@
 from __future__ import annotations
 
 import asyncio
-from collections.abc import AsyncIterator, Iterator, Sequence
+import threading
+from collections.abc import AsyncIterator, Coroutine, Iterator, Sequence
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypeVar
 
 from obspec import GetResult, GetResultAsync
 
@@ -38,6 +39,8 @@
 if TYPE_CHECKING:
     from obspec import Attributes, GetOptions, ObjectMeta
 
+T = TypeVar("T")
+
 try:
     import aiohttp
 except ImportError as e:
@@ -196,6 +199,10 @@ def __init__(
         self.headers = headers or {}
         self.timeout = aiohttp.ClientTimeout(total=timeout)
         self._session: aiohttp.ClientSession | None = None
+        # Event loop for sync methods when called from async context (e.g., Jupyter)
+        self._sync_loop: asyncio.AbstractEventLoop | None = None
+        self._sync_thread: threading.Thread | None = None
+        self._sync_lock = threading.Lock()
 
     async def __aenter__(self) -> "AiohttpStore":
         """Enter the async context manager, creating a reusable session."""
@@ -211,6 +218,71 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
             await self._session.close()
             self._session = None
 
+    def _get_sync_loop(self) -> asyncio.AbstractEventLoop:
+        """Get or create event loop for sync operations when inside a running loop."""
+        if self._sync_loop is None:
+            with self._sync_lock:
+                if self._sync_loop is None:
+                    loop = asyncio.new_event_loop()
+                    thread = threading.Thread(
+                        target=loop.run_forever,
+                        name=f"aiohttp_store_{id(self)}",
+                        daemon=True,
+                    )
+                    thread.start()
+                    self._sync_loop = loop
+                    self._sync_thread = thread
+        return self._sync_loop
+
+    def _run_sync(self, coro: Coroutine[None, None, T]) -> T:
+        """Run coroutine synchronously, handling nested event loops (e.g., Jupyter)."""
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            # No running loop - use asyncio.run() directly
+            return asyncio.run(coro)
+
+        # Inside running loop - use store's dedicated loop
+        loop = self._get_sync_loop()
+        future = asyncio.run_coroutine_threadsafe(coro, loop)
+        return future.result()
+
+    def _cleanup_sync_loop(self) -> None:
+        """Stop the sync loop, join its thread, and close the loop.
+
+        Safe to call more than once, and on an instance whose ``__init__``
+        raised before the sync attributes were assigned (``__del__`` still
+        runs in that case).
+        """
+        lock = getattr(self, "_sync_lock", None)
+        if lock is None:
+            return
+        # Detach under the lock so a concurrent caller never sees a loop
+        # that is in the middle of being torn down.
+        with lock:
+            loop, thread = self._sync_loop, self._sync_thread
+            self._sync_loop = None
+            self._sync_thread = None
+        if loop is None:
+            return
+        loop.call_soon_threadsafe(loop.stop)
+        # A thread cannot join itself (RuntimeError), e.g. when __del__
+        # happens to run on the loop thread.
+        if thread is not None and thread is not threading.current_thread():
+            thread.join(timeout=1.0)
+        # close() raises on a running loop, so only close once the thread
+        # has exited; this releases the selector and self-pipe sockets.
+        if thread is None or not thread.is_alive():
+            loop.close()
+
+    def close(self) -> None:
+        """Close the store and release resources."""
+        self._cleanup_sync_loop()
+
+    def __del__(self) -> None:
+        """Clean up on garbage collection."""
+        self._cleanup_sync_loop()
+
     def _build_url(self, path: str) -> str:
         """Build the full URL from base URL and path."""
         path = path.removeprefix("/")
@@ -497,7 +569,7 @@ def get(
         AiohttpGetResult
             Result object with buffer() method and metadata.
         """
-        result = asyncio.run(self.get_async(path, options=options))
+        result = self._run_sync(self.get_async(path, options=options))
         return AiohttpGetResult(
             _data=result._data,
             _meta=result._meta,
@@ -534,7 +606,7 @@ def get_range(
         bytes
             The requested byte range.
         """
-        return asyncio.run(
+        return self._run_sync(
             self.get_range_async(path, start=start, end=end, length=length)
         )
 
@@ -567,7 +639,7 @@ def get_ranges(
         Sequence[bytes]
             The requested byte ranges.
         """
-        return asyncio.run(
+        return self._run_sync(
             self.get_ranges_async(path, starts=starts, ends=ends, lengths=lengths)
         )
 
@@ -627,7 +699,7 @@ def head(self, path: str) -> ObjectMeta:
         ObjectMeta
             File metadata including size, last_modified, e_tag, etc.
         """
-        return asyncio.run(self.head_async(path))
+        return self._run_sync(self.head_async(path))
 
 
 __all__ = ["AiohttpStore", "AiohttpGetResult", "AiohttpGetResultAsync"]
diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py
index 3ee6429..9a2e2dc 100644
--- a/tests/test_aiohttp.py
+++ b/tests/test_aiohttp.py
@@ -657,3 +657,83 @@ def test_head_sync(minio_test_file):
     assert meta["path"] == minio_test_file["path"]
     assert meta["e_tag"] is not None
     assert meta["last_modified"] is not None
+
+
+# --- Nested Event Loop Handling (Jupyter compatibility) ---
+
+
+@requires_minio
+@pytest.mark.asyncio
+async def test_sync_methods_from_running_loop(minio_test_file):
+    """
+    Sync methods work when called from within a running event loop.
+
+    This simulates the Jupyter notebook environment where an event loop
+    is already running. The per-store event loop design handles this by
+    creating a dedicated thread with its own event loop for sync operations.
+    """
+    store = AiohttpStore(minio_test_file["base_url"])
+
+    # We're inside an async function, so there's a running event loop.
+    # Calling sync methods would fail with asyncio.run() but should
+    # work with the per-store event loop implementation.
+    try:
+        # Test head (sync)
+        meta = store.head(minio_test_file["path"])
+        assert meta["size"] == len(minio_test_file["content"])
+
+        # Test get (sync)
+        result = store.get(minio_test_file["path"])
+        assert result.buffer() == minio_test_file["content"]
+
+        # Test get_range (sync)
+        data = store.get_range(minio_test_file["path"], start=0, length=5)
+        assert bytes(data) == b"01234"
+
+        # Test get_ranges (sync)
+        results = store.get_ranges(
+            minio_test_file["path"], starts=[0, 10], lengths=[5, 6]
+        )
+        assert [bytes(r) for r in results] == [b"01234", b"ABCDEF"]
+
+    finally:
+        store.close()
+
+
+def test_sync_loop_not_created_outside_async():
+    """Sync loop is not created when not inside a running event loop."""
+    store = AiohttpStore("https://example.com")
+
+    # Before any sync call
+    assert store._sync_loop is None
+    assert store._sync_thread is None
+
+    # close() should be safe even if loop was never created
+    store.close()
+    assert store._sync_loop is None
+
+
+@requires_minio
+@pytest.mark.asyncio
+async def test_sync_loop_created_inside_async(minio_test_file):
+    """Sync loop is lazily created when sync method called from async context."""
+    store = AiohttpStore(minio_test_file["base_url"])
+
+    # Before sync call
+    assert store._sync_loop is None
+    assert store._sync_thread is None
+
+    # Call sync method from async context
+    _ = store.head(minio_test_file["path"])
+
+    # Sync loop should now exist
+    assert store._sync_loop is not None
+    assert store._sync_thread is not None
+    assert store._sync_thread.is_alive()
+
+    # Cleanup
+    loop = store._sync_loop
+    store.close()
+    assert loop.is_closed()
+    assert store._sync_loop is None
+    assert store._sync_thread is None