Skip to content

Commit 52e8ce0

Browse files
committed
ADD: Live client callback exception warning
1 parent 5ab1145 commit 52e8ce0

File tree

5 files changed

+116
-41
lines changed

5 files changed

+116
-41
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#### Enhancements
66
- Added a property `Live.session_id` which returns the streaming session ID when the client is connected
77
- Streams added with `Live.add_stream()` which do not define an exception handler will now emit a warning if an exception is raised while executing the callback
8+
- Callback functions added with `Live.add_callback()` which do not define an exception handler will now emit a warning if an exception is raised while executing the callback
89
- Upgraded `databento-dbn` to 0.44.0
910
- Added logic to set `code` when upgrading version 1 `SystemMsg` to newer versions
1011

databento/common/types.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import datetime as dt
22
import logging
3-
import warnings
43
from collections.abc import Callable
54
from os import PathLike
65
import pathlib
7-
from typing import IO, Generic
6+
from typing import Generic
7+
from typing import IO
88
from typing import TypedDict
99
from typing import TypeVar
10+
import warnings
1011

1112
import databento_dbn
1213
import pandas as pd
@@ -39,10 +40,6 @@
3940
| databento_dbn.ErrorMsgV1
4041
)
4142

42-
RecordCallback = Callable[[DBNRecord], None]
43-
ExceptionCallback = Callable[[Exception], None]
44-
ReconnectCallback = Callable[[pd.Timestamp, pd.Timestamp], None]
45-
4643
_T = TypeVar("_T")
4744

4845

@@ -97,6 +94,11 @@ class MappingIntervalDict(TypedDict):
9794
symbol: str
9895

9996

97+
RecordCallback = Callable[[DBNRecord], None]
98+
ExceptionCallback = Callable[[Exception], None]
99+
ReconnectCallback = Callable[[pd.Timestamp, pd.Timestamp], None]
100+
101+
100102
class ClientStream:
101103
def __init__(
102104
self,
@@ -213,3 +215,73 @@ def _warn(self, msg: str) -> None:
213215
BentoWarning,
214216
stacklevel=3,
215217
)
218+
219+
220+
class ClientRecordCallback:
221+
def __init__(
222+
self,
223+
fn: RecordCallback,
224+
exc_fn: ExceptionCallback | None = None,
225+
max_warnings: int = 10,
226+
) -> None:
227+
if not callable(fn):
228+
raise ValueError(f"{fn} is not callable")
229+
if exc_fn is not None and not callable(exc_fn):
230+
raise ValueError(f"{exc_fn} is not callable")
231+
232+
self._fn = fn
233+
self._exc_fn = exc_fn
234+
self._max_warnings = max(0, max_warnings)
235+
self._warning_count = 0
236+
237+
@property
238+
def callback_name(self) -> str:
239+
return getattr(self._fn, "__name__", str(self._fn))
240+
241+
@property
242+
def exc_callback_name(self) -> str:
243+
return getattr(self._exc_fn, "__name__", str(self._exc_fn))
244+
245+
def call(self, record: DBNRecord) -> None:
246+
"""
247+
Execute the callback function, passing `record` in as the first
248+
argument. Any exceptions encountered will be dispatched to the
249+
exception callback, if defined.
250+
251+
Parameters
252+
----------
253+
record : DBNRecord
254+
255+
"""
256+
try:
257+
self._fn(record)
258+
except Exception as exc:
259+
if self._exc_fn is None:
260+
self._warn(
261+
f"callback '{self.callback_name}' encountered an exception without an exception callback: {repr(exc)}",
262+
)
263+
else:
264+
try:
265+
self._exc_fn(exc)
266+
except Exception as inner_exc:
267+
self._warn(
268+
f"exception callback '{self.exc_callback_name}' encountered an exception: {repr(inner_exc)}",
269+
)
270+
raise inner_exc from exc
271+
raise exc
272+
273+
def _warn(self, msg: str) -> None:
274+
logger.warning(msg)
275+
if self._warning_count < self._max_warnings:
276+
self._warning_count += 1
277+
warnings.warn(
278+
msg,
279+
BentoWarning,
280+
stacklevel=3,
281+
)
282+
if self._warning_count == self._max_warnings:
283+
warnings.warn(
284+
f"suppressing further warnings for '{self.callback_name}'",
285+
BentoWarning,
286+
stacklevel=3,
287+
)

databento/live/client.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from databento.common.error import BentoError
2424
from databento.common.parsing import optional_datetime_to_unix_nanoseconds
2525
from databento.common.publishers import Dataset
26-
from databento.common.types import ClientStream, DBNRecord
26+
from databento.common.types import ClientRecordCallback
27+
from databento.common.types import ClientStream
28+
from databento.common.types import DBNRecord
2729
from databento.common.types import ExceptionCallback
2830
from databento.common.types import ReconnectCallback
2931
from databento.common.types import RecordCallback
@@ -110,7 +112,7 @@ def __init__(
110112
reconnect_policy=reconnect_policy,
111113
)
112114

113-
self._session._user_callbacks.append((self._map_symbol, None))
115+
self._session._user_callbacks.append(ClientRecordCallback(self._map_symbol))
114116

115117
with Live._lock:
116118
if not Live._thread.is_alive():
@@ -269,7 +271,9 @@ def add_callback(
269271
A callback to register for handling live records as they arrive.
270272
exception_callback : Callable[[Exception], None], optional
271273
An error handling callback to process exceptions that are raised
272-
in `record_callback`.
274+
in `record_callback`. If no exception callback is provided,
275+
any exceptions encountered will be logged and raised as warnings
276+
for visibility.
273277
274278
Raises
275279
------
@@ -282,15 +286,13 @@ def add_callback(
282286
Live.add_stream
283287
284288
"""
285-
if not callable(record_callback):
286-
raise ValueError(f"{record_callback} is not callable")
287-
288-
if exception_callback is not None and not callable(exception_callback):
289-
raise ValueError(f"{exception_callback} is not callable")
289+
client_callback = ClientRecordCallback(
290+
fn=record_callback,
291+
exc_fn=exception_callback,
292+
)
290293

291-
callback_name = getattr(record_callback, "__name__", str(record_callback))
292-
logger.info("adding user callback %s", callback_name)
293-
self._session._user_callbacks.append((record_callback, exception_callback))
294+
logger.info("adding user callback %s", client_callback.callback_name)
295+
self._session._user_callbacks.append(client_callback)
294296

295297
def add_stream(
296298
self,

databento/live/session.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
from databento.common.enums import ReconnectPolicy
2020
from databento.common.error import BentoError
2121
from databento.common.publishers import Dataset
22-
from databento.common.types import ClientStream, DBNRecord
22+
from databento.common.types import ClientRecordCallback
23+
from databento.common.types import ClientStream
24+
from databento.common.types import DBNRecord
2325
from databento.common.types import ExceptionCallback
2426
from databento.common.types import ReconnectCallback
25-
from databento.common.types import RecordCallback
2627
from databento.live.gateway import SubscriptionRequest
2728
from databento.live.protocol import DatabentoLiveProtocol
2829

@@ -195,8 +196,8 @@ def __init__(
195196
api_key: str,
196197
dataset: Dataset | str,
197198
dbn_queue: DBNQueue,
198-
user_callbacks: list[tuple[RecordCallback, ExceptionCallback | None]],
199199
user_streams: list[ClientStream],
200+
user_callbacks: list[ClientRecordCallback],
200201
loop: asyncio.AbstractEventLoop,
201202
metadata: SessionMetadata,
202203
ts_out: bool = False,
@@ -237,18 +238,16 @@ def received_record(self, record: DBNRecord) -> None:
237238
return super().received_record(record)
238239

239240
def _dispatch_callbacks(self, record: DBNRecord) -> None:
240-
for callback, exc_callback in self._user_callbacks:
241+
for callback in self._user_callbacks:
241242
try:
242-
callback(record)
243+
callback.call(record)
243244
except Exception as exc:
244245
logger.error(
245246
"error dispatching %s to `%s` callback",
246247
type(record).__name__,
247-
getattr(callback, "__name__", str(callback)),
248+
callback.callback_name,
248249
exc_info=exc,
249250
)
250-
if exc_callback is not None:
251-
exc_callback(exc)
252251

253252
def _dispatch_writes(self, record: DBNRecord) -> None:
254253
record_bytes = bytes(record)
@@ -315,8 +314,8 @@ def __init__(
315314
self._loop = loop
316315
self._metadata = SessionMetadata()
317316
self._user_gateway: str | None = user_gateway
318-
self._user_callbacks: list[tuple[RecordCallback, ExceptionCallback | None]] = []
319317
self._user_streams: list[ClientStream] = []
318+
self._user_callbacks: list[ClientRecordCallback] = []
320319
self._user_reconnect_callbacks: list[tuple[ReconnectCallback, ExceptionCallback | None]] = (
321320
[]
322321
)
@@ -527,19 +526,20 @@ async def wait_for_close(self) -> None:
527526
return
528527

529528
try:
530-
await self._protocol.authenticated
531-
except Exception as exc:
532-
raise BentoError(exc) from None
533-
534-
try:
535-
if self._reconnect_task is not None:
536-
await self._reconnect_task
537-
else:
538-
await self._protocol.disconnected
539-
except Exception as exc:
540-
raise BentoError(exc) from None
529+
try:
530+
await self._protocol.authenticated
531+
except Exception as exc:
532+
raise BentoError(exc) from None
541533

542-
self._cleanup()
534+
try:
535+
if self._reconnect_task is not None:
536+
await self._reconnect_task
537+
else:
538+
await self._protocol.disconnected
539+
except Exception as exc:
540+
raise BentoError(exc) from None
541+
finally:
542+
self._cleanup()
543543

544544
def _cleanup(self) -> None:
545545
logger.debug("cleaning up session_id=%s", self.session_id)

tests/test_live_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -983,15 +983,15 @@ def test_live_add_callback(
983983
"""
984984

985985
# Arrange
986-
def callback(_: object) -> None:
986+
def test_callback(_: object) -> None:
987987
pass
988988

989989
# Act
990-
live_client.add_callback(callback)
990+
live_client.add_callback(test_callback)
991991

992992
# Assert
993993
assert len(live_client._session._user_callbacks) == 2 # include map_symbols callback
994-
assert (callback, None) in live_client._session._user_callbacks
994+
assert live_client._session._user_callbacks[-1].callback_name == "test_callback"
995995

996996

997997
def test_live_add_stream(

0 commit comments

Comments
 (0)