Skip to content

Fix: convert gRPC stream termination to YDB errors in async query client (issue #696) #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/aio/query/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from unittest import mock

import grpc
import pytest
from grpc._cython import cygrpc

from ydb.aio.query.session import QuerySession
from ydb.aio.query.pool import QuerySessionPool

Expand Down Expand Up @@ -32,3 +37,22 @@ async def tx(session):
async def pool(driver):
async with QuerySessionPool(driver) as pool:
yield pool


@pytest.fixture
async def ydb_terminates_streams_with_unavailable():
async def _patch(self):
message = await self._read() # Read the first message
while message is not cygrpc.EOF: # While the message is not empty, continue reading the stream
yield message
message = await self._read()

# Emulate stream termination
raise grpc.aio.AioRpcError(
code=grpc.StatusCode.UNAVAILABLE,
initial_metadata=await self.initial_metadata(),
trailing_metadata=await self.trailing_metadata(),
)

with mock.patch.object(grpc.aio._call._StreamResponseMixin, "_fetch_stream_responses", _patch):
yield
12 changes: 12 additions & 0 deletions tests/aio/query/test_query_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest

import ydb
from ydb.aio.query.session import QuerySession


Expand Down Expand Up @@ -113,3 +115,13 @@ async def test_two_results(self, session: QuerySession):

assert res == [[1], [2]]
assert counter == 2

@pytest.mark.asyncio
@pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable")
async def test_terminated_stream_raises_ydb_error(self, session: QuerySession):
await session.create()

with pytest.raises(ydb.Unavailable):
async with await session.execute("select 1") as results:
async for _ in results:
pass
11 changes: 11 additions & 0 deletions tests/aio/query/test_query_transaction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import ydb
from ydb.aio.query.transaction import QueryTxContext
from ydb.query.transaction import QueryTxStateEnum

Expand Down Expand Up @@ -107,3 +108,13 @@ async def test_execute_two_results(self, tx: QueryTxContext):

assert res == [[1], [2]]
assert counter == 2

@pytest.mark.asyncio
@pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable")
async def test_terminated_stream_raises_ydb_error(self, tx: QueryTxContext):
await tx.begin()

with pytest.raises(ydb.Unavailable):
async with await tx.execute("select 1") as results:
async for _ in results:
pass
27 changes: 26 additions & 1 deletion ydb/_errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import grpc

from . import issues

Expand Down Expand Up @@ -52,3 +54,26 @@ def check_retriable_error(err, retry_settings, attempt):
class ErrorRetryInfo:
is_retriable: bool
sleep_timeout_seconds: Optional[float]


def stream_error_converter(exc: BaseException) -> Union[issues.Error, BaseException]:
"""Converts gRPC stream errors to appropriate YDB exception types.

This function takes a base exception and converts specific gRPC aio stream errors
to their corresponding YDB exception types for better error handling and semantic
clarity.

Args:
exc (BaseException): The original exception to potentially convert.

Returns:
BaseException: Either a converted YDB exception or the original exception
if no specific conversion rule applies.
"""
if isinstance(exc, (grpc.RpcError, grpc.aio.AioRpcError)):
if exc.code() == grpc.StatusCode.UNAVAILABLE:
return issues.Unavailable(exc.details() or "")
if exc.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
return issues.DeadlineExceed("Deadline exceeded on request")
return issues.Error("Stream has been terminated. Original exception: {}".format(str(exc.details())))
return exc
11 changes: 9 additions & 2 deletions ydb/aio/_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@


class AsyncResponseIterator(object):
def __init__(self, it, wrapper):
def __init__(self, it, wrapper, error_converter=None):
self.it = it.__aiter__()
self.wrapper = wrapper
self.error_converter = error_converter

def cancel(self):
self.it.cancel()
Expand All @@ -17,7 +18,13 @@ def __aiter__(self):
return self

async def _next(self):
res = self.wrapper(await self.it.__anext__())
try:
res = self.wrapper(await self.it.__anext__())
except BaseException as e:
if self.error_converter:
raise self.error_converter(e) from e
raise e

if res is not None:
return res
return await self._next()
Expand Down
6 changes: 4 additions & 2 deletions ydb/aio/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from ..._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT
from ..._errors import stream_error_converter


class QuerySession(BaseQuerySession):
Expand Down Expand Up @@ -151,12 +152,13 @@ async def execute(
)

return AsyncResponseContextIterator(
stream_it,
lambda resp: base.wrap_execute_query_response(
it=stream_it,
wrapper=lambda resp: base.wrap_execute_query_response(
rpc_state=None,
response_pb=resp,
session_state=self._state,
session=self,
settings=self._settings,
),
error_converter=stream_error_converter,
)
6 changes: 4 additions & 2 deletions ydb/aio/query/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BaseQueryTxContext,
QueryTxStateEnum,
)
from ..._errors import stream_error_converter

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -181,14 +182,15 @@ async def execute(
)

self._prev_stream = AsyncResponseContextIterator(
stream_it,
lambda resp: base.wrap_execute_query_response(
it=stream_it,
wrapper=lambda resp: base.wrap_execute_query_response(
rpc_state=None,
response_pb=resp,
session_state=self._session_state,
tx=self,
commit_tx=commit_tx,
settings=self.session._settings,
),
error_converter=stream_error_converter,
)
return self._prev_stream
Loading