Skip to content

Commit 36247d1

Browse files
Add support for polling.
1 parent 545e813 commit 36247d1

File tree

2 files changed

+43
-54
lines changed

2 files changed

+43
-54
lines changed

presto/client.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -423,43 +423,15 @@ def process(self, http_response):
423423
)
424424

425425

426-
class PrestoResult(object):
426+
class PrestoQuery(object):
427427
"""
428-
Represent the result of a Presto query as an iterator on rows.
428+
Represent the execution of a SQL statement by Presto.
429429
430-
This class implements the iterator protocol as a generator type
430+
Results of the query can be extracted by iterating over this class, since it
431+
implements the iterator protocol as a generator type
431432
https://docs.python.org/3/library/stdtypes.html#generator-types
432433
"""
433434

434-
def __init__(self, query, rows=None):
435-
self._query = query
436-
self._rows = rows or []
437-
self._rownumber = 0
438-
439-
@property
440-
def rownumber(self):
441-
# type: () -> int
442-
return self._rownumber
443-
444-
def __iter__(self):
445-
# Initial fetch from the first POST request
446-
for row in self._rows:
447-
self._rownumber += 1
448-
yield row
449-
self._rows = None
450-
451-
# Subsequent fetches from GET requests until next_uri is empty.
452-
while not self._query.is_finished():
453-
rows = self._query.fetch()
454-
for row in rows:
455-
self._rownumber += 1
456-
logger.debug("row {}".format(row))
457-
yield row
458-
459-
460-
class PrestoQuery(object):
461-
"""Represent the execution of a SQL statement by Presto."""
462-
463435
def __init__(
464436
self,
465437
request, # type: PrestoRequest
@@ -476,7 +448,9 @@ def __init__(
476448
self._cancelled = False
477449
self._request = request
478450
self._sql = sql
479-
self._result = PrestoResult(self)
451+
452+
self._rows = []
453+
self._rownumber = 0
480454

481455
@property
482456
def columns(self):
@@ -490,10 +464,6 @@ def stats(self):
490464
def warnings(self):
491465
return self._warnings
492466

493-
@property
494-
def result(self):
495-
return self._result
496-
497467
def execute(self):
498468
# type: () -> PrestoResult
499469
"""Initiate a Presto query by sending the SQL statement
@@ -514,10 +484,10 @@ def execute(self):
514484
self._warnings = getattr(status, "warnings", [])
515485
if status.next_uri is None:
516486
self._finished = True
517-
self._result = PrestoResult(self, status.rows)
518-
return self._result
487+
self._rows = status.rows
488+
return self
519489

520-
def fetch(self):
490+
def _fetch(self):
521491
# type: () -> List[List[Any]]
522492
"""Continue fetching data for the current query_id"""
523493
response = self._request.get(self._request.next_uri)
@@ -530,6 +500,14 @@ def fetch(self):
530500
self._finished = True
531501
return status.rows
532502

503+
def poll(self):
504+
# type: () -> Dict
505+
"""Retrieve the current status of a presto query, caching any results."""
506+
if not self.query_id or self._finished:
507+
return self.stats
508+
self._rows.extend(self._fetch())
509+
return self.stats
510+
533511
def cancel(self):
534512
# type: () -> None
535513
"""Cancel the current query"""
@@ -549,3 +527,12 @@ def cancel(self):
549527
def is_finished(self):
550528
# type: () -> bool
551529
return self._finished
530+
531+
def __iter__(self):
532+
while self._rows or not self.is_finished():
533+
for row in self._rows:
534+
self._rownumber += 1
535+
logger.debug('row {}'.format(row))
536+
yield row
537+
self._rows = []
538+
self.poll()

presto/dbapi.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -225,17 +225,18 @@ def warnings(self):
225225
return self._query.warnings
226226
return None
227227

228+
def poll(self):
229+
return self._query.poll()
230+
228231
def setinputsizes(self, sizes):
229232
raise presto.exceptions.NotSupportedError
230233

231234
def setoutputsize(self, size, column):
232235
raise presto.exceptions.NotSupportedError
233236

234237
def execute(self, operation, params=None):
235-
self._query = presto.client.PrestoQuery(self._request, sql=operation)
236-
result = self._query.execute()
237-
self._iterator = iter(result)
238-
return result
238+
self._query = presto.client.PrestoQuery(self._request, sql=operation).execute()
239+
return self._query
239240

240241
def executemany(self, operation, seq_of_params):
241242
raise presto.exceptions.NotSupportedError
@@ -250,13 +251,10 @@ def fetchone(self):
250251
An Error (or subclass) exception is raised if the previous call to
251252
.execute*() did not produce any result set or no call was issued yet.
252253
"""
253-
254-
try:
255-
return next(self._iterator)
256-
except StopIteration:
254+
result = self.fetchmany(1)
255+
if len(result) != 1:
257256
return None
258-
except presto.exceptions.HttpError as err:
259-
raise presto.exceptions.OperationalError(str(err))
257+
return result[0]
260258

261259
def fetchmany(self, size=None):
262260
# type: (Optional[int]) -> List[List[Any]]
@@ -284,16 +282,20 @@ def fetchmany(self, size=None):
284282
size = self.arraysize
285283

286284
result = []
285+
iterator = iter(self._query)
286+
287287
for _ in range(size):
288-
row = self.fetchone()
289-
if row is None:
288+
try:
289+
result.append(next(iterator))
290+
except StopIteration:
290291
break
291-
result.append(row)
292+
except prestodb.exceptions.HttpError as err:
293+
raise prestodb.exceptions.OperationalError(str(err))
292294

293295
return result
294296

295297
def genall(self):
296-
return self._query.result
298+
return self._query
297299

298300
def fetchall(self):
299301
# type: () -> List[List[Any]]

0 commit comments

Comments
 (0)