From e764a6a809e20beccfe873e1041612a6a97f6d0c Mon Sep 17 00:00:00 2001
From: shaheer <hshaheer99@gmail.com>
Date: Tue, 18 Jul 2023 23:55:56 +0530
Subject: [PATCH] Add defer_connect config to allow eagerly verifying
 connection

This commit adds a new connection parameter `defer_connect` which can be
set to False to force creating a connection when `trino.dbapi.connect`
is called. Any connection errors as a result of that get rewrapped into
`trino.exceptions.TrinoConnectionError`.

By default `defer_connect` is set to `True` so users can explicitly call
`trino.dbapi.Connection.connect` to do the connection check.

This doesn't end up actually executing a query on the server because
after the initial POST request the nextUri in the response is not
followed which leaves the query in QUEUED state. This is not documented
in the Trino REST API but the server does behave like this today. The
benefit is that we can very cheaply verify if the connection is valid
without polluting the server's query history or adding queries to queue.

Some unit tests today relied on the lazy connection behaviour so they
have been adjusted accrodingly.
---
 tests/unit/sqlalchemy/test_dialect.py |  3 ++-
 tests/unit/test_dbapi.py              | 18 +++++++++--------
 trino/dbapi.py                        | 28 +++++++++++++++++++++++++++
 3 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py
index c385ab90..cb62fc01 100644
--- a/tests/unit/sqlalchemy/test_dialect.py
+++ b/tests/unit/sqlalchemy/test_dialect.py
@@ -252,7 +252,8 @@ def test_get_default_isolation_level(self):
         assert isolation_level == "AUTOCOMMIT"
 
     def test_isolation_level(self):
-        dbapi_conn = Connection(host="localhost")
+        # The test only verifies that isolation level is correctly set, no need to attempt actual connection
+        dbapi_conn = Connection(host="localhost", defer_connect=True)
 
         self.dialect.set_isolation_level(dbapi_conn, "SERIALIZABLE")
         assert dbapi_conn._isolation_level == IsolationLevel.SERIALIZABLE
diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py
index b56466a2..e0367c86 100644
--- a/tests/unit/test_dbapi.py
+++ b/tests/unit/test_dbapi.py
@@ -184,7 +184,8 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post
         conn2.cursor().execute("SELECT 2")
         conn2.cursor().execute("SELECT 3")
 
-    assert len(_post_statement_requests()) == 7
+    assert len(_post_statement_requests()) == 9
+    # assert only a single token request was sent
     assert len(_get_token_requests(challenge_id)) == 1
 
 
@@ -275,37 +276,38 @@ def test_role_is_set_when_specified(mock_client):
 
 
 def test_hostname_parsing():
-    https_server_with_port = Connection("https://mytrinoserver.domain:9999")
+    # Since this test only verifies URL parsing there is no need to attempt actual connection
+    https_server_with_port = Connection("https://mytrinoserver.domain:9999", defer_connect=True)
     assert https_server_with_port.host == "mytrinoserver.domain"
     assert https_server_with_port.port == 9999
     assert https_server_with_port.http_scheme == constants.HTTPS
 
-    https_server_without_port = Connection("https://mytrinoserver.domain")
+    https_server_without_port = Connection("https://mytrinoserver.domain", defer_connect=True)
     assert https_server_without_port.host == "mytrinoserver.domain"
     assert https_server_without_port.port == 8080
     assert https_server_without_port.http_scheme == constants.HTTPS
 
-    http_server_with_port = Connection("http://mytrinoserver.domain:9999")
+    http_server_with_port = Connection("http://mytrinoserver.domain:9999", defer_connect=True)
     assert http_server_with_port.host == "mytrinoserver.domain"
     assert http_server_with_port.port == 9999
     assert http_server_with_port.http_scheme == constants.HTTP
 
-    http_server_without_port = Connection("http://mytrinoserver.domain")
+    http_server_without_port = Connection("http://mytrinoserver.domain", defer_connect=True)
     assert http_server_without_port.host == "mytrinoserver.domain"
     assert http_server_without_port.port == 8080
     assert http_server_without_port.http_scheme == constants.HTTP
 
-    http_server_with_path = Connection("http://mytrinoserver.domain/some_path")
+    http_server_with_path = Connection("http://mytrinoserver.domain/some_path", defer_connect=True)
     assert http_server_with_path.host == "mytrinoserver.domain/some_path"
     assert http_server_with_path.port == 8080
     assert http_server_with_path.http_scheme == constants.HTTP
 
-    only_hostname = Connection("mytrinoserver.domain")
+    only_hostname = Connection("mytrinoserver.domain", defer_connect=True)
     assert only_hostname.host == "mytrinoserver.domain"
     assert only_hostname.port == 8080
     assert only_hostname.http_scheme == constants.HTTP
 
-    only_hostname_with_path = Connection("mytrinoserver.domain/some_path")
+    only_hostname_with_path = Connection("mytrinoserver.domain/some_path", defer_connect=True)
     assert only_hostname_with_path.host == "mytrinoserver.domain/some_path"
     assert only_hostname_with_path.port == 8080
     assert only_hostname_with_path.http_scheme == constants.HTTP
diff --git a/trino/dbapi.py b/trino/dbapi.py
index 62ce893b..ae1348a3 100644
--- a/trino/dbapi.py
+++ b/trino/dbapi.py
@@ -28,6 +28,8 @@
 from typing import Any, Dict, List, NamedTuple, Optional  # NOQA for mypy types
 from urllib.parse import urlparse
 
+from requests.exceptions import RequestException
+
 try:
     from zoneinfo import ZoneInfo
 except ModuleNotFoundError:
@@ -157,6 +159,7 @@ def __init__(
         legacy_prepared_statements=None,
         roles=None,
         timezone=None,
+        defer_connect=False,
     ):
         # Automatically assign http_schema, port based on hostname
         parsed_host = urlparse(host, allow_fragments=False)
@@ -201,6 +204,31 @@ def __init__(
         self.legacy_primitive_types = legacy_primitive_types
         self.legacy_prepared_statements = legacy_prepared_statements
 
+        if not defer_connect:
+            self.connect()
+
+    def connect(self) -> None:
+        connection_test_request = trino.client.TrinoRequest(
+            self.host,
+            self.port,
+            self._client_session,
+            self._http_session,
+            self.http_scheme,
+            self.auth,
+            self.max_attempts,
+            self.request_timeout,
+            verify=self._http_session.verify,
+        )
+        try:
+            test_response = connection_test_request.post("<not-going-to-be-executed>")
+            response_content = test_response.content if test_response.content else ""
+            if not test_response.ok:
+                raise trino.exceptions.TrinoConnectionError(
+                    "error {}: {}".format(test_response.status_code, response_content))
+
+        except RequestException as e:
+            raise trino.exceptions.TrinoConnectionError("connection failed: {}".format(e))
+
     @property
     def isolation_level(self):
         return self._isolation_level