5454from email .utils import parsedate_to_datetime
5555from enum import Enum
5656from time import sleep
57- from typing import Any
57+ from typing import Any , Callable , Optional
5858from typing import cast
5959from typing import Dict
6060from typing import List
6161from typing import Literal
62- from typing import Optional
63- from typing import Tuple
64- from typing import TypedDict
6562from typing import Union
63+ from typing import TypedDict
64+ from typing import Tuple
6665from zoneinfo import ZoneInfo
6766
6867import lz4 .block
6968import requests
69+
70+ # Progress callback type definition
71+ ProgressCallback = Callable [['TrinoStatus' , Dict [str , Any ]], None ]
7072import zstandard
7173from requests import Response
7274from requests import Session
@@ -810,6 +812,7 @@ def __init__(
810812 legacy_primitive_types : bool = False ,
811813 fetch_mode : Literal ["mapped" , "segments" ] = "mapped" ,
812814 heartbeat_interval : float = 60.0 , # seconds
815+ progress_callback : Optional [ProgressCallback ] = None ,
813816 ) -> None :
814817 self ._query_id : Optional [str ] = None
815818 self ._stats : Dict [Any , Any ] = {}
@@ -832,6 +835,8 @@ def __init__(
832835 self ._heartbeat_stop_event = threading .Event ()
833836 self ._heartbeat_failures = 0
834837 self ._heartbeat_enabled = True
838+ self ._progress_callback = progress_callback
839+ self ._last_progress_stats = {}
835840
836841 @property
837842 def query_id (self ) -> Optional [str ]:
@@ -952,6 +957,10 @@ def _update_state(self, status):
952957 legacy_primitive_types = self ._legacy_primitive_types )
953958 if status .columns :
954959 self ._columns = status .columns
960+
961+ # Call progress callback if provided
962+ if self ._progress_callback is not None :
963+ self ._progress_callback (status , self ._stats )
955964
956965 def fetch (self ) -> List [Union [List [Any ]], Any ]:
957966 """Continue fetching data for the current query_id"""
@@ -1034,6 +1043,35 @@ def is_running(self) -> bool:
10341043 """Return True if the query is still running (not finished or cancelled)."""
10351044 return not self .finished and not self .cancelled
10361045
1046+ def calculate_progress_percentage (self , stats : Dict [str , Any ]) -> float :
1047+ """
1048+ Calculate progress percentage based on available statistics.
1049+
1050+ Args:
1051+ stats: The current query statistics from Trino
1052+
1053+ Returns:
1054+ Progress percentage as a float between 0.0 and 100.0
1055+ """
1056+ # Try to calculate progress based on splits completion
1057+ if 'completedSplits' in stats and 'totalSplits' in stats :
1058+ completed_splits = stats .get ('completedSplits' , 0 )
1059+ total_splits = stats .get ('totalSplits' , 0 )
1060+ if total_splits > 0 :
1061+ return min (100.0 , (completed_splits / total_splits ) * 100.0 )
1062+
1063+ # Fallback: check if query is finished
1064+ if stats .get ('state' ) == 'FINISHED' :
1065+ return 100.0
1066+
1067+ # If query is running but we don't have split info, estimate based on time
1068+ # This is a rough estimate and may not be accurate
1069+ if stats .get ('state' ) == 'RUNNING' :
1070+ # Return a conservative estimate - could be enhanced with more sophisticated logic
1071+ return 5.0 # Assume some progress has been made
1072+
1073+ return 0.0
1074+
10371075
10381076def _retry_with (handle_retry , handled_exceptions , conditions , max_attempts ):
10391077 def wrapper (func ):
0 commit comments