Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions src/obspec_utils/stores/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
import threading
from collections.abc import AsyncIterator, Coroutine, Iterator, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar

from obspec import GetResult, GetResultAsync

Expand All @@ -38,6 +39,8 @@
if TYPE_CHECKING:
from obspec import Attributes, GetOptions, ObjectMeta

T = TypeVar("T")

try:
import aiohttp
except ImportError as e:
Expand Down Expand Up @@ -196,6 +199,10 @@ def __init__(
self.headers = headers or {}
self.timeout = aiohttp.ClientTimeout(total=timeout)
self._session: aiohttp.ClientSession | None = None
# Event loop for sync methods when called from async context (e.g., Jupyter)
self._sync_loop: asyncio.AbstractEventLoop | None = None
self._sync_thread: threading.Thread | None = None
self._sync_lock = threading.Lock()

async def __aenter__(self) -> "AiohttpStore":
"""Enter the async context manager, creating a reusable session."""
Expand All @@ -211,6 +218,52 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self._session.close()
self._session = None

def _get_sync_loop(self) -> asyncio.AbstractEventLoop:
    """Return the store's background event loop, starting it on first use.

    The loop is driven by ``run_forever`` on a daemon thread named
    ``aiohttp_store_<id(self)>``. Creation happens under ``self._sync_lock``
    and is re-checked once the lock is held, so the loop/thread pair is only
    built when ``self._sync_loop`` is still ``None``.

    Returns
    -------
    asyncio.AbstractEventLoop
        The loop stored in ``self._sync_loop``.
    """
    if self._sync_loop is not None:
        # Already created: no need to take the lock.
        return self._sync_loop

    with self._sync_lock:
        # Re-check under the lock; another caller may have created it first.
        if self._sync_loop is None:
            background_loop = asyncio.new_event_loop()
            worker = threading.Thread(
                target=background_loop.run_forever,
                name=f"aiohttp_store_{id(self)}",
                daemon=True,
            )
            worker.start()
            self._sync_loop = background_loop
            self._sync_thread = worker
    return self._sync_loop

def _run_sync(self, coro: Coroutine[None, None, T]) -> T:
    """Run ``coro`` to completion from synchronous code and return its result.

    When no event loop is running in the calling thread, the coroutine is
    executed with ``asyncio.run``. When a loop is already running (e.g. in a
    Jupyter notebook), ``asyncio.run`` cannot be used, so the coroutine is
    submitted to the store's dedicated background loop and the calling
    thread blocks until it finishes.

    Parameters
    ----------
    coro
        The coroutine to execute.

    Returns
    -------
    T
        Whatever ``coro`` returns.

    Raises
    ------
    RuntimeError
        If called from code that is itself running on the store's background
        loop. Blocking there would stall the only thread able to run
        ``coro``, so the call could never complete.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread - asyncio.run() is safe to use.
        return asyncio.run(coro)

    # Inside a running loop - hand the work to the store's dedicated loop.
    loop = self._get_sync_loop()
    if running_loop is loop:
        # Waiting on the future here would block the loop that has to run
        # it. Close the coroutine so it does not trigger a
        # "coroutine was never awaited" warning, then fail loudly.
        coro.close()
        raise RuntimeError(
            "Cannot call a synchronous AiohttpStore method from the store's "
            "own background event loop; await the corresponding *_async "
            "method instead."
        )
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    return future.result()

def _cleanup_sync_loop(self) -> None:
    """Stop the background sync loop, wait for its thread, and close the loop.

    Safe to call repeatedly and on an instance whose ``__init__`` did not
    finish (this method is reached from ``__del__``), in which case the
    attributes are simply treated as unset.

    The loop is asked to stop from the calling thread, the worker thread is
    joined for at most one second, and the loop is closed only once nothing
    is running it any more, because closing a running loop raises
    ``RuntimeError``.
    """
    # getattr: __del__ may run on an object whose __init__ raised early.
    loop = getattr(self, "_sync_loop", None)
    thread = getattr(self, "_sync_thread", None)
    self._sync_loop = None
    self._sync_thread = None

    if loop is not None and not loop.is_closed():
        loop.call_soon_threadsafe(loop.stop)

    # A thread cannot join itself; skip the wait if cleanup was triggered
    # from the worker thread (e.g. garbage collection happening there).
    if thread is not None and thread is not threading.current_thread():
        thread.join(timeout=1.0)

    worker_finished = thread is None or not thread.is_alive()
    if (
        loop is not None
        and worker_finished
        and not loop.is_running()
        and not loop.is_closed()
    ):
        # Release the selector and self-pipe held by the stopped loop.
        loop.close()

def close(self) -> None:
    """Release the resources held for synchronous calls.

    Delegates to ``_cleanup_sync_loop``, which stops the background event
    loop and its thread if they were ever started.
    """
    self._cleanup_sync_loop()

def __del__(self) -> None:
    """Stop the background sync loop when the store is garbage-collected.

    Delegates to ``_cleanup_sync_loop`` so a store that was never closed
    explicitly still has its sync loop and thread shut down.
    """
    self._cleanup_sync_loop()

def _build_url(self, path: str) -> str:
"""Build the full URL from base URL and path."""
path = path.removeprefix("/")
Expand Down Expand Up @@ -497,7 +550,7 @@ def get(
AiohttpGetResult
Result object with buffer() method and metadata.
"""
result = asyncio.run(self.get_async(path, options=options))
result = self._run_sync(self.get_async(path, options=options))
return AiohttpGetResult(
_data=result._data,
_meta=result._meta,
Expand Down Expand Up @@ -534,7 +587,7 @@ def get_range(
bytes
The requested byte range.
"""
return asyncio.run(
return self._run_sync(
self.get_range_async(path, start=start, end=end, length=length)
)

Expand Down Expand Up @@ -567,7 +620,7 @@ def get_ranges(
Sequence[bytes]
The requested byte ranges.
"""
return asyncio.run(
return self._run_sync(
self.get_ranges_async(path, starts=starts, ends=ends, lengths=lengths)
)

Expand Down Expand Up @@ -627,7 +680,7 @@ def head(self, path: str) -> ObjectMeta:
ObjectMeta
File metadata including size, last_modified, e_tag, etc.
"""
return asyncio.run(self.head_async(path))
return self._run_sync(self.head_async(path))


__all__ = ["AiohttpStore", "AiohttpGetResult", "AiohttpGetResultAsync"]
78 changes: 78 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,81 @@ def test_head_sync(minio_test_file):
assert meta["path"] == minio_test_file["path"]
assert meta["e_tag"] is not None
assert meta["last_modified"] is not None


# --- Nested Event Loop Handling (Jupyter compatibility) ---


@requires_minio
@pytest.mark.asyncio
async def test_sync_methods_from_running_loop(minio_test_file):
    """
    Every sync accessor succeeds while an event loop is already running.

    Because this test is a coroutine, a loop is running when the sync
    methods are invoked - the same situation as in a Jupyter notebook.
    The store is expected to route those calls through its own dedicated
    thread and event loop instead of calling ``asyncio.run()`` directly.
    """
    store = AiohttpStore(minio_test_file["base_url"])

    try:
        object_path = minio_test_file["path"]
        expected_content = minio_test_file["content"]

        # head (sync)
        head_meta = store.head(object_path)
        assert head_meta["size"] == len(expected_content)

        # get (sync)
        fetched = store.get(object_path)
        assert fetched.buffer() == expected_content

        # get_range (sync)
        first_bytes = store.get_range(object_path, start=0, length=5)
        assert bytes(first_bytes) == b"01234"

        # get_ranges (sync)
        chunks = store.get_ranges(object_path, starts=[0, 10], lengths=[5, 6])
        assert list(map(bytes, chunks)) == [b"01234", b"ABCDEF"]
    finally:
        # Always stop the background loop/thread, even on assertion failure.
        store.close()


def test_sync_loop_not_created_outside_async():
    """A freshly built store has no sync loop, and closing it is harmless."""
    store = AiohttpStore("https://example.com")

    # No sync call has been made yet, so nothing should have been started.
    assert store._sync_thread is None
    assert store._sync_loop is None

    # close() must tolerate a store whose loop was never created.
    store.close()
    assert store._sync_loop is None


@requires_minio
@pytest.mark.asyncio
async def test_sync_loop_created_inside_async(minio_test_file):
    """Sync loop is lazily created when sync method called from async context."""
    store = AiohttpStore(minio_test_file["base_url"])

    try:
        # Before sync call
        assert store._sync_loop is None
        assert store._sync_thread is None

        # Call sync method from async context
        _ = store.head(minio_test_file["path"])

        # Sync loop should now exist
        assert store._sync_loop is not None
        assert store._sync_thread is not None
        assert store._sync_thread.is_alive()
    finally:
        # Cleanup runs even if an assertion above fails, so a failing test
        # does not leave the background thread and loop behind.
        store.close()

    assert store._sync_loop is None
    assert store._sync_thread is None
Loading