Skip to content

Commit f67f43a

Browse files
committed
added callback injection into session that will return query metrics
1 parent 2b45ed9 commit f67f43a

File tree

3 files changed

+201
-4
lines changed

3 files changed

+201
-4
lines changed

README.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,74 @@ assert rows[0][0] == "-2001-08-22"
566566
assert cur.description[0][1] == "date"
567567
```
568568

569+
## Progress Callback
570+
571+
The Trino client supports progress callbacks to track query execution progress in real-time. you can provide a callback function that gets called whenever the query status is updated.
572+
573+
### Basic Usage
574+
575+
```python
576+
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoStatus
577+
from typing import Dict, Any
578+
579+
def progress_callback(status: TrinoStatus, stats: Dict[str, Any]) -> None:
580+
"""Progress callback function that gets called whenever the query status is updated."""
581+
state = stats.get('state', 'UNKNOWN')
582+
processed_bytes = stats.get('processedBytes', 0)
583+
processed_rows = stats.get('processedRows', 0)
584+
completed_splits = stats.get('completedSplits', 0)
585+
total_splits = stats.get('totalSplits', 0)
586+
587+
print(f"Query {status.id}: {state} - {processed_bytes} bytes, {processed_rows} rows")
588+
if total_splits > 0:
589+
progress = (completed_splits / total_splits) * 100.0
590+
print(f"Progress: {progress:.1f}% ({completed_splits}/{total_splits} splits)")
591+
592+
session = ClientSession(user="test_user", catalog="memory", schema="default")
593+
594+
request = TrinoRequest(
595+
host="localhost",
596+
port=8080,
597+
client_session=session,
598+
http_scheme="http"
599+
)
600+
601+
query = TrinoQuery(
602+
request=request,
603+
query="SELECT * FROM large_table",
604+
progress_callback=progress_callback
605+
)
606+
607+
result = query.execute()
608+
609+
while not query.finished:
610+
rows = query.fetch()
611+
```
612+
613+
### Progress Calculation
614+
615+
The callback receives a `stats` dictionary containing various metrics that can be used to calculate progress:
616+
617+
- `state`: Query state (RUNNING, FINISHED, FAILED, etc.)
618+
- `processedBytes`: Total bytes processed
619+
- `processedRows`: Total rows processed
620+
- `completedSplits`: Number of completed splits
621+
- `totalSplits`: Total number of splits
622+
623+
The most accurate progress calculation is based on splits completion:
624+
625+
```python
626+
def calculate_progress(stats: Dict[str, Any]) -> float:
627+
"""Calculate progress percentage based on splits completion."""
628+
completed_splits = stats.get('completedSplits', 0)
629+
total_splits = stats.get('totalSplits', 0)
630+
if total_splits > 0:
631+
return min(100.0, (completed_splits / total_splits) * 100.0)
632+
elif stats.get('state') == 'FINISHED':
633+
return 100.0
634+
return 0.0
635+
```
636+
569637
### Trino to Python type mappings
570638

571639
| Trino type | Python type |

tests/unit/test_client.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,3 +1593,94 @@ class Resp:
15931593
query._stop_heartbeat()
15941594
# Heartbeat should have stopped after query cancelled
15951595
assert head_call_count >= 1
1596+
1597+
1598+
# Progress Callback Tests
1599+
def test_progress_callback_initialization():
1600+
"""Test that progress callback is properly initialized."""
1601+
req = TrinoRequest(
1602+
host="coordinator",
1603+
port=8080,
1604+
client_session=ClientSession(user="test"),
1605+
http_scheme="http",
1606+
)
1607+
1608+
def callback(status, stats):
1609+
pass
1610+
1611+
query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback)
1612+
assert query._progress_callback == callback
1613+
1614+
# Test without callback
1615+
query_no_callback = TrinoQuery(request=req, query="SELECT 1")
1616+
assert query_no_callback._progress_callback is None
1617+
1618+
1619+
def test_calculate_progress_percentage():
1620+
"""Test progress percentage calculation."""
1621+
req = TrinoRequest(
1622+
host="coordinator",
1623+
port=8080,
1624+
client_session=ClientSession(user="test"),
1625+
http_scheme="http",
1626+
)
1627+
query = TrinoQuery(request=req, query="SELECT 1")
1628+
1629+
# Test splits-based calculation
1630+
assert query.calculate_progress_percentage({'completedSplits': 5, 'totalSplits': 10}) == 50.0
1631+
assert query.calculate_progress_percentage({'completedSplits': 10, 'totalSplits': 10}) == 100.0
1632+
assert query.calculate_progress_percentage({'completedSplits': 15, 'totalSplits': 10}) == 100.0 # Cap at 100%
1633+
1634+
# Test state-based calculation
1635+
assert query.calculate_progress_percentage({'state': 'FINISHED'}) == 100.0
1636+
assert query.calculate_progress_percentage({'state': 'RUNNING'}) == 5.0
1637+
assert query.calculate_progress_percentage({'state': 'FAILED'}) == 0.0
1638+
1639+
# Test empty stats
1640+
assert query.calculate_progress_percentage({}) == 0.0
1641+
1642+
1643+
@mock.patch("trino.client.TrinoRequest.post")
1644+
@mock.patch("trino.client.TrinoRequest.get")
1645+
def test_progress_callback_execution(mock_get, mock_post):
1646+
"""Test that progress callback is called during query execution."""
1647+
callback_calls = []
1648+
1649+
def callback(status, stats):
1650+
callback_calls.append((status, stats))
1651+
1652+
# Mock responses
1653+
mock_post_response = Mock()
1654+
mock_post_response.json.return_value = {
1655+
'id': 'test_query_id',
1656+
'nextUri': 'http://localhost:8080/v1/statement/test_query_id/1',
1657+
'stats': {'state': 'RUNNING', 'completedSplits': 0, 'totalSplits': 10},
1658+
'data': [],
1659+
'columns': []
1660+
}
1661+
mock_post.return_value = mock_post_response
1662+
1663+
mock_get_response = Mock()
1664+
mock_get_response.json.return_value = {
1665+
'id': 'test_query_id',
1666+
'nextUri': None,
1667+
'stats': {'state': 'FINISHED', 'completedSplits': 10, 'totalSplits': 10},
1668+
'data': [[1]],
1669+
'columns': [{'name': 'col1', 'type': 'integer'}]
1670+
}
1671+
mock_get.return_value = mock_get_response
1672+
1673+
req = TrinoRequest(
1674+
host="coordinator",
1675+
port=8080,
1676+
client_session=ClientSession(user="test"),
1677+
http_scheme="http",
1678+
)
1679+
1680+
query = TrinoQuery(request=req, query="SELECT 1", progress_callback=callback)
1681+
result = query.execute()
1682+
1683+
# Verify callback was called
1684+
assert len(callback_calls) > 0
1685+
assert isinstance(callback_calls[0][0], TrinoStatus)
1686+
assert isinstance(callback_calls[0][1], dict)

trino/client.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,21 @@
5454
from email.utils import parsedate_to_datetime
5555
from enum import Enum
5656
from time import sleep
57-
from typing import Any
57+
from typing import Any, Callable, Optional
5858
from typing import cast
5959
from typing import Dict
6060
from typing import List
6161
from typing import Literal
62-
from typing import Optional
63-
from typing import Tuple
64-
from typing import TypedDict
6562
from typing import Union
63+
from typing import TypedDict
64+
from typing import Tuple
6665
from zoneinfo import ZoneInfo
6766

6867
import lz4.block
6968
import requests
69+
70+
# Progress callback type definition
71+
ProgressCallback = Callable[['TrinoStatus', Dict[str, Any]], None]
7072
import zstandard
7173
from requests import Response
7274
from requests import Session
@@ -810,6 +812,7 @@ def __init__(
810812
legacy_primitive_types: bool = False,
811813
fetch_mode: Literal["mapped", "segments"] = "mapped",
812814
heartbeat_interval: float = 60.0, # seconds
815+
progress_callback: Optional[ProgressCallback] = None,
813816
) -> None:
814817
self._query_id: Optional[str] = None
815818
self._stats: Dict[Any, Any] = {}
@@ -832,6 +835,8 @@ def __init__(
832835
self._heartbeat_stop_event = threading.Event()
833836
self._heartbeat_failures = 0
834837
self._heartbeat_enabled = True
838+
self._progress_callback = progress_callback
839+
self._last_progress_stats = {}
835840

836841
@property
837842
def query_id(self) -> Optional[str]:
@@ -952,6 +957,10 @@ def _update_state(self, status):
952957
legacy_primitive_types=self._legacy_primitive_types)
953958
if status.columns:
954959
self._columns = status.columns
960+
961+
# Call progress callback if provided
962+
if self._progress_callback is not None:
963+
self._progress_callback(status, self._stats)
955964

956965
def fetch(self) -> List[Union[List[Any]], Any]:
957966
"""Continue fetching data for the current query_id"""
@@ -1034,6 +1043,35 @@ def is_running(self) -> bool:
10341043
"""Return True if the query is still running (not finished or cancelled)."""
10351044
return not self.finished and not self.cancelled
10361045

1046+
def calculate_progress_percentage(self, stats: Dict[str, Any]) -> float:
1047+
"""
1048+
Calculate progress percentage based on available statistics.
1049+
1050+
Args:
1051+
stats: The current query statistics from Trino
1052+
1053+
Returns:
1054+
Progress percentage as a float between 0.0 and 100.0
1055+
"""
1056+
# Try to calculate progress based on splits completion
1057+
if 'completedSplits' in stats and 'totalSplits' in stats:
1058+
completed_splits = stats.get('completedSplits', 0)
1059+
total_splits = stats.get('totalSplits', 0)
1060+
if total_splits > 0:
1061+
return min(100.0, (completed_splits / total_splits) * 100.0)
1062+
1063+
# Fallback: check if query is finished
1064+
if stats.get('state') == 'FINISHED':
1065+
return 100.0
1066+
1067+
# If query is running but we don't have split info, estimate based on time
1068+
# This is a rough estimate and may not be accurate
1069+
if stats.get('state') == 'RUNNING':
1070+
# Return a conservative estimate - could be enhanced with more sophisticated logic
1071+
return 5.0 # Assume some progress has been made
1072+
1073+
return 0.0
1074+
10371075

10381076
def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
10391077
def wrapper(func):

0 commit comments

Comments
 (0)