Skip to content

Commit 8a3e034

Browse files
committed
add an option deferred_fetch to Cursor.execute()
1 parent 887bbd0 commit 8a3e034

File tree

4 files changed

+85
-15
lines changed

4 files changed

+85
-15
lines changed

tests/unit/test_client.py

+51
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,57 @@ def json(self):
10181018
assert isinstance(result, TrinoResult)
10191019

10201020

1021+
def test_trino_query_deferred_fetch(sample_get_response_data):
1022+
"""
1023+
Validates that the `TrinoQuery.execute` function deferred_fetch and non-block execution
1024+
"""
1025+
1026+
class MockResponse(mock.Mock):
1027+
# Fake response class
1028+
@property
1029+
def headers(self):
1030+
return {
1031+
'X-Trino-Fake-1': 'one',
1032+
'X-Trino-Fake-2': 'two',
1033+
}
1034+
1035+
def json(self):
1036+
return sample_get_response_data
1037+
1038+
rows = sample_get_response_data['data']
1039+
sample_get_response_data['data'] = []
1040+
sql = 'SELECT 1'
1041+
request = TrinoRequest(
1042+
host="coordinator",
1043+
port=8080,
1044+
client_session=ClientSession(
1045+
user="test",
1046+
source="test",
1047+
catalog="test",
1048+
schema="test",
1049+
properties={},
1050+
),
1051+
http_scheme="http",
1052+
)
1053+
query = TrinoQuery(
1054+
request=request,
1055+
query=sql
1056+
)
1057+
1058+
with \
1059+
mock.patch.object(request, 'post', return_value=MockResponse()), \
1060+
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch:
1061+
result = query.execute()
1062+
mock_fetch.assert_called_once()
1063+
assert result.rows == rows
1064+
1065+
with \
1066+
mock.patch.object(request, 'post', return_value=MockResponse()), \
1067+
mock.patch.object(query, 'fetch', return_value=rows) as mock_fetch:
1068+
result = query.execute(deferred_fetch=True)
1069+
mock_fetch.assert_not_called()
1070+
1071+
10211072
def test_delay_exponential_without_jitter():
10221073
max_delay = 1200.0
10231074
get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay)

trino/client.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -777,13 +777,18 @@ def result(self):
777777
def info_uri(self):
778778
return self._info_uri
779779

780-
def execute(self, additional_http_headers=None) -> TrinoResult:
781-
"""Initiate a Trino query by sending the SQL statement
782-
783-
This is the first HTTP request sent to the coordinator.
784-
It sets the query_id and returns a Result object used to
785-
track the rows returned by the query. To fetch all rows,
786-
call fetch() until finished is true.
780+
def execute(
781+
self,
782+
additional_http_headers: Optional[Dict[str, Any]] = None,
783+
deferred_fetch: bool = False,
784+
) -> TrinoResult:
785+
"""Initiate a Trino query by sending the SQL statement to the coordinator.
786+
To fetch all rows, call fetch() until finished is true.
787+
788+
Parameters:
789+
additional_http_headers: extra headers send to the Trino server.
790+
deferred_fetch: By default, the execution is blocked until at least one row is received
791+
or query is finished or cancelled. To continue without waiting the result, set deferred_fetch=True.
787792
"""
788793
if self.cancelled:
789794
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)
@@ -804,9 +809,11 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
804809
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
805810
self._result = TrinoResult(self, rows)
806811

807-
# Execute should block until at least one row is received or query is finished or cancelled
808-
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
809-
self._result.rows += self.fetch()
812+
if not deferred_fetch:
813+
# Execute should block until at least one row is received or query is finished or cancelled
814+
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
815+
self._result.rows += self.fetch()
816+
810817
return self._result
811818

812819
def _update_state(self, status):

trino/dbapi.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,10 @@ def _deallocate_prepared_statement(self, statement_name: str) -> None:
558558
def _generate_unique_statement_name(self):
559559
return 'st_' + uuid.uuid4().hex.replace('-', '')
560560

561-
def execute(self, operation, params=None):
561+
def execute(self, operation, params=None, **kwargs: Any):
562+
additional_http_headers = kwargs.get("additional_http_headers", None)
563+
deferred_fetch = kwargs.get("deferred_fetch", False)
564+
562565
if params:
563566
assert isinstance(params, (list, tuple)), (
564567
'params must be a list or tuple containing the query '
@@ -575,7 +578,10 @@ def execute(self, operation, params=None):
575578
self._query = self._execute_prepared_statement(
576579
statement_name, params
577580
)
578-
self._iterator = iter(self._query.execute())
581+
self._iterator = iter(self._query.execute(
582+
additional_http_headers=additional_http_headers,
583+
deferred_fetch=deferred_fetch,
584+
))
579585
finally:
580586
# Send deallocate statement
581587
# At this point the query can be deallocated since it has already
@@ -584,12 +590,18 @@ def execute(self, operation, params=None):
584590
self._deallocate_prepared_statement(statement_name)
585591
else:
586592
self._query = self._execute_immediate_statement(operation, params)
587-
self._iterator = iter(self._query.execute())
593+
self._iterator = iter(self._query.execute(
594+
additional_http_headers=additional_http_headers,
595+
deferred_fetch=deferred_fetch,
596+
))
588597

589598
else:
590599
self._query = trino.client.TrinoQuery(self._request, query=operation,
591600
legacy_primitive_types=self._legacy_primitive_types)
592-
self._iterator = iter(self._query.execute())
601+
self._iterator = iter(self._query.execute(
602+
additional_http_headers=additional_http_headers,
603+
deferred_fetch=deferred_fetch,
604+
))
593605
return self
594606

595607
def executemany(self, operation, seq_of_params):

trino/sqlalchemy/dialect.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
396396
def do_execute(
397397
self, cursor: Cursor, statement: str, parameters: Tuple[Any, ...], context: DefaultExecutionContext = None
398398
):
399-
cursor.execute(statement, parameters)
399+
cursor.execute(statement, parameters, **context.execution_options)
400400

401401
def do_rollback(self, dbapi_connection: trino_dbapi.Connection):
402402
if dbapi_connection.transaction is not None:

0 commit comments

Comments
 (0)