Skip to content

Move Black to pre-commit #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -18,6 +18,13 @@ repos:
name: "Python: imports"
args: ["--py39-plus"]

- repo: "https://github.com/psf/black-pre-commit-mirror"
rev: "24.10.0"
hooks:
- id: "black"
name: "Python: format"
args: ["--target-version", "py39", "--line-length", "120"]

- repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: "v5.0.0"
hooks:
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -8,7 +8,9 @@ max-line-length = 120
#
# https://www.flake8rules.com/rules/W503.html
# https://www.python.org/dev/peps/pep-0008/#should-a-line-break-before-or-after-a-binary-operator
ignore = W503
#
# E203 and E701 are Black-specific exclusions.
ignore = E203,E701,W503

[mypy]
check_untyped_defs = true
33 changes: 18 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
@@ -27,27 +27,30 @@
readme = f.read()

kerberos_require = ["requests_kerberos"]
gssapi_require = [""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the purpose of "" here. Perhaps it should be removed?

"requests_gssapi",
# PyPy compatibility issue https://github.com/jborean93/pykrb5/issues/49
"krb5 == 0.5.1"]
gssapi_require = [
"" "requests_gssapi",
# PyPy compatibility issue https://github.com/jborean93/pykrb5/issues/49
"krb5 == 0.5.1",
]
sqlalchemy_require = ["sqlalchemy >= 1.3"]
external_authentication_token_cache_require = ["keyring"]

# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
all_require = kerberos_require + sqlalchemy_require

tests_require = all_require + gssapi_require + [
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
# https://github.com/gabrielfalcao/HTTPretty/issues/425
"httpretty < 1.1",
"pytest",
"pytest-runner",
"pre-commit",
"black",
"isort",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both isort and black are here but were never enforced. Neither were mentioned in the development documentation as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to look at #227 for context.

The TL;DR version was:

  • While standard code format is nice black often makes changes to the detriment of readability
  • If we ever do this we need to figure out a way to be history preserving (probably via blame-ignore config - jvanzyl/provisio@a4cdb8d)
  • black and isort running in pre-commit makes iteration slower sometimes since you are prevented from committing or running ci just because of style issues

"keyring"
]
tests_require = (
all_require
+ gssapi_require
+ [
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
# https://github.com/gabrielfalcao/HTTPretty/issues/425
"httpretty < 1.1",
"pytest",
"pytest-runner",
"pre-commit",
"keyring",
]
)

setup(
name=about["__title__"],
12 changes: 2 additions & 10 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -72,13 +72,7 @@ def start_trino(image_tag=None):


def wait_for_trino_workers(host, port, timeout=180):
request = TrinoRequest(
host=host,
port=port,
client_session=ClientSession(
user="test_fixture"
)
)
request = TrinoRequest(host=host, port=port, client_session=ClientSession(user="test_fixture"))
sql = "SELECT state FROM system.runtime.nodes"
t0 = time.time()
while True:
@@ -124,9 +118,7 @@ def start_trino_and_wait(image_tag=None):
if host:
port = os.environ.get("TRINO_RUNNING_PORT", DEFAULT_PORT)
else:
container_id, proc, host, port = start_local_trino_server(
image_tag
)
container_id, proc, host, port = start_local_trino_server(image_tag)

print("trino.server.hostname {}".format(host))
print("trino.server.port {}".format(port))
895 changes: 504 additions & 391 deletions tests/integration/test_dbapi_integration.py

Large diffs are not rendered by default.

447 changes: 219 additions & 228 deletions tests/integration/test_sqlalchemy_integration.py

Large diffs are not rendered by default.

1,128 changes: 458 additions & 670 deletions tests/integration/test_types_integration.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -218,8 +218,7 @@ def sample_get_error_response_data():
"errorType": "USER_ERROR",
"failureInfo": {
"errorLocation": {"columnNumber": 15, "lineNumber": 1},
"message": "line 1:15: Schema must be specified "
"when session schema is not set",
"message": "line 1:15: Schema must be specified " "when session schema is not set",
"stack": [
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)",
"io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)",
@@ -298,4 +297,5 @@ def mock_get_and_post():

def sqlalchemy_version() -> str:
import sqlalchemy

return sqlalchemy.__version__
69 changes: 39 additions & 30 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
@@ -53,25 +53,28 @@ def __call__(self, request, uri, response_headers):
if authorization and authorization.replace("Bearer ", "") in self.tokens:
return [200, response_headers, json.dumps(self.sample_post_response_data)]
elif self.redirect_server is None and self.token_server is not None:
return [401,
{
'Www-Authenticate': (
'Bearer realm="Trino", token_type="JWT", '
f'Bearer x_token_server="{self.token_server}"'
),
'Basic realm': '"Trino"'
},
""]
return [401,
return [
401,
{
'Www-Authenticate': (
'Bearer realm="Trino", token_type="JWT", '
f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"'
"Www-Authenticate": (
'Bearer realm="Trino", token_type="JWT", ' f'Bearer x_token_server="{self.token_server}"'
),
'Basic realm': '"Trino"'
"Basic realm": '"Trino"',
},
""]
"",
]
return [
401,
{
"Www-Authenticate": (
'Bearer realm="Trino", token_type="JWT", '
f'Bearer x_redirect_server="{self.redirect_server}", '
f'x_token_server="{self.token_server}"'
),
"Basic realm": '"Trino"',
},
"",
]


class GetTokenCallback:
@@ -90,19 +93,19 @@ def __call__(self, request, uri, response_headers):


def _get_token_requests(challenge_id):
return list(filter(
lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}",
httpretty.latest_requests()))
return list(
filter(lambda r: r.method == "GET" and r.path == f"/{TOKEN_PATH}/{challenge_id}", httpretty.latest_requests())
)


def _post_statement_requests():
return list(filter(
lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH,
httpretty.latest_requests()))
return list(
filter(lambda r: r.method == "POST" and r.path == constants.URL_STATEMENT_PATH, httpretty.latest_requests())
)


class MultithreadedTokenServer:
Challenge = namedtuple('Challenge', ['token', 'attempts'])
Challenge = namedtuple("Challenge", ["token", "attempts"])

def __init__(self, sample_post_response_data, attempts=1):
self.tokens = set()
@@ -114,13 +117,13 @@ def __init__(self, sample_post_response_data, attempts=1):
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}{constants.URL_STATEMENT_PATH}",
body=self.post_statement_callback)
body=self.post_statement_callback,
)

# bind get token
httpretty.register_uri(
method=httpretty.GET,
uri=re.compile(rf"{TOKEN_RESOURCE}/.*"),
body=self.get_token_callback)
method=httpretty.GET, uri=re.compile(rf"{TOKEN_RESOURCE}/.*"), body=self.get_token_callback
)

# noinspection PyUnusedLocal
def post_statement_callback(self, request, uri, response_headers):
@@ -135,9 +138,15 @@ def post_statement_callback(self, request, uri, response_headers):
self.challenges[challenge_id] = MultithreadedTokenServer.Challenge(token, self.attempts)
redirect_server = f"{REDIRECT_RESOURCE}/{challenge_id}"
token_server = f"{TOKEN_RESOURCE}/{challenge_id}"
return [401, {'Www-Authenticate': f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
'Basic realm': '"Trino"'}, ""]
return [
401,
{
"Www-Authenticate": f'Bearer x_redirect_server="{redirect_server}", '
f'x_token_server="{token_server}"',
"Basic realm": '"Trino"',
},
"",
]

# noinspection PyUnusedLocal
def get_token_callback(self, request, uri, response_headers):
2 changes: 1 addition & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ def _assert_sqltype(this: SQLType, that: SQLType):
_assert_sqltype(this.value_type, that.value_type)
elif isinstance(this, ROW):
assert len(this.attr_types) == len(that.attr_types)
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
for this_attr, that_attr in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

98 changes: 41 additions & 57 deletions tests/unit/sqlalchemy/test_compiler.py
Original file line number Diff line number Diff line change
@@ -27,18 +27,12 @@

metadata = MetaData()
table_without_catalog = Table(
'table',
"table",
metadata,
Column('id', Integer),
Column('name', String),
)
table_with_catalog = Table(
'table',
metadata,
Column('id', Integer),
schema='default',
trino_catalog='other'
Column("id", Integer),
Column("name", String),
)
table_with_catalog = Table("table", metadata, Column("id", Integer), schema="default", trino_catalog="other")


@pytest.fixture
@@ -47,8 +41,7 @@ def dialect():


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_limit_offset(dialect):
statement = select(table_without_catalog).limit(10).offset(0)
@@ -57,8 +50,7 @@ def test_limit_offset(dialect):


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_limit(dialect):
statement = select(table_without_catalog).limit(10)
@@ -67,8 +59,7 @@ def test_limit(dialect):


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_offset(dialect):
statement = select(table_without_catalog).offset(0)
@@ -77,24 +68,23 @@ def test_offset(dialect):


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_cte_insert_order(dialect):
cte = select(table_without_catalog).cte('cte')
cte = select(table_without_catalog).cte("cte")
statement = insert(table_without_catalog).from_select(table_without_catalog.columns, cte)
query = statement.compile(dialect=dialect)
assert str(query) == \
'INSERT INTO "table" (id, name) WITH cte AS \n'\
'(SELECT "table".id AS id, "table".name AS name \n'\
'FROM "table")\n'\
' SELECT cte.id, cte.name \n'\
'FROM cte'
assert (
str(query) == 'INSERT INTO "table" (id, name) WITH cte AS \n'
'(SELECT "table".id AS id, "table".name AS name \n'
'FROM "table")\n'
" SELECT cte.id, cte.name \n"
"FROM cte"
)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_catalogs_argument(dialect):
statement = select(table_with_catalog)
@@ -105,68 +95,62 @@ def test_catalogs_argument(dialect):
def test_catalogs_create_table(dialect):
statement = CreateTable(table_with_catalog)
query = statement.compile(dialect=dialect)
assert str(query) == \
'\n'\
'CREATE TABLE "other".default."table" (\n'\
'\tid INTEGER\n'\
')\n'\
'\n'
assert str(query) == "\n" 'CREATE TABLE "other".default."table" (\n' "\tid INTEGER\n" ")\n" "\n"


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
def test_table_clause(dialect):
statement = select(table("user", column("id"), column("name"), column("description")))
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT user.id, user.name, user.description \nFROM user'
assert str(query) == "SELECT user.id, user.name, user.description \nFROM user"


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
sqlalchemy_version() < "1.4", reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
'function,element',
"function,element",
[
('first_value', func.first_value),
('last_value', func.last_value),
('nth_value', func.nth_value),
('lead', func.lead),
('lag', func.lag),
]
("first_value", func.first_value),
("last_value", func.last_value),
("nth_value", func.nth_value),
("lead", func.lead),
("lag", func.lag),
],
)
def test_ignore_nulls(dialect, function, element):
statement = select(
element(
table_without_catalog.c.id,
ignore_nulls=True,
).over(partition_by=table_without_catalog.c.name).label('window')
)
.over(partition_by=table_without_catalog.c.name)
.label("window")
)
query = statement.compile(dialect=dialect)
assert str(query) == \
f'SELECT {function}("table".id) IGNORE NULLS OVER (PARTITION BY "table".name) AS window '\
f'\nFROM "table"'
assert (
str(query) == f'SELECT {function}("table".id) IGNORE NULLS OVER (PARTITION BY "table".name) AS window '
f'\nFROM "table"'
)

statement = select(
element(
table_without_catalog.c.id,
ignore_nulls=False,
).over(partition_by=table_without_catalog.c.name).label('window')
)
.over(partition_by=table_without_catalog.c.name)
.label("window")
)
query = statement.compile(dialect=dialect)
assert str(query) == \
f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' \
f'\nFROM "table"'
assert str(query) == f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' f'\nFROM "table"'


@pytest.mark.skipif(
sqlalchemy_version() < "2.0",
reason="ImportError: cannot import name 'try_cast' from 'sqlalchemy'"
)
@pytest.mark.skipif(sqlalchemy_version() < "2.0", reason="ImportError: cannot import name 'try_cast' from 'sqlalchemy'")
def test_try_cast(dialect):
from sqlalchemy import try_cast

statement = select(try_cast(table_without_catalog.c.id, String))
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT try_cast("table".id as VARCHAR) AS id \nFROM "table"'
4 changes: 2 additions & 2 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
@@ -69,7 +69,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
"CHAR(10)": CHAR(10),
"VARCHAR(10)": VARCHAR(10),
"DECIMAL(20)": DECIMAL(20),
"DECIMAL(20, 3)": DECIMAL(20, 3)
"DECIMAL(20, 3)": DECIMAL(20, 3),
}


@@ -188,7 +188,7 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
"timestamp(3)": TIMESTAMP(3, timezone=False),
"timestamp(6)": TIMESTAMP(6),
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
"timestamp with time zone": TIMESTAMP(timezone=True)
"timestamp with time zone": TIMESTAMP(timezone=True),
}


286 changes: 150 additions & 136 deletions tests/unit/sqlalchemy/test_dialect.py

Large diffs are not rendered by default.

259 changes: 123 additions & 136 deletions tests/unit/test_client.py

Large diffs are not rendered by default.

99 changes: 33 additions & 66 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
@@ -70,49 +70,43 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl

# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)
method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback
)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)
body=get_statement_callback,
)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)
httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)

redirect_handler = RedirectHandler()

with connect(
"coordinator",
user="test",
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
http_scheme=constants.HTTPS
"coordinator",
user="test",
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
http_scheme=constants.HTTPS,
) as conn:
conn.cursor().execute("SELECT 1")
conn.cursor().execute("SELECT 2")
conn.cursor().execute("SELECT 3")

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)
httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)

redirect_handler = RedirectHandler()

with connect(
"coordinator",
user="test",
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
http_scheme=constants.HTTPS
"coordinator",
user="test",
auth=OAuth2Authentication(redirect_auth_url_handler=redirect_handler),
http_scheme=constants.HTTPS,
) as conn2:
conn2.cursor().execute("SELECT 1")
conn2.cursor().execute("SELECT 2")
@@ -122,8 +116,9 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl


@httprettified
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
sample_get_response_data):
def test_token_retrieved_once_when_authentication_instance_is_shared(
sample_post_response_data, sample_get_response_data
):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

@@ -135,50 +130,34 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post

# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)
method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback
)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)
body=get_statement_callback,
)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)
httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)

redirect_handler = RedirectHandler()

authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler)

with connect(
"coordinator",
user="test",
auth=authentication,
http_scheme=constants.HTTPS
) as conn:
with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn:
conn.cursor().execute("SELECT 1")
conn.cursor().execute("SELECT 2")
conn.cursor().execute("SELECT 3")

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)
httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)

with connect(
"coordinator",
user="test",
auth=authentication,
http_scheme=constants.HTTPS
) as conn2:
with connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS) as conn2:
conn2.cursor().execute("SELECT 1")
conn2.cursor().execute("SELECT 2")
conn2.cursor().execute("SELECT 3")
@@ -200,33 +179,25 @@ def test_token_retrieved_once_when_multithreaded(sample_post_response_data, samp

# bind post statement to submit query
httpretty.register_uri(
method=httpretty.POST,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}",
body=post_statement_callback)
method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", body=post_statement_callback
)

# bind get statement for result retrieval
httpretty.register_uri(
method=httpretty.GET,
uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1",
body=get_statement_callback)
body=get_statement_callback,
)

# bind get token
get_token_callback = GetTokenCallback(token_server, token)
httpretty.register_uri(
method=httpretty.GET,
uri=token_server,
body=get_token_callback)
httpretty.register_uri(method=httpretty.GET, uri=token_server, body=get_token_callback)

redirect_handler = RedirectHandler()

authentication = OAuth2Authentication(redirect_auth_url_handler=redirect_handler)

conn = connect(
"coordinator",
user="test",
auth=authentication,
http_scheme=constants.HTTPS
)
conn = connect("coordinator", user="test", auth=authentication, http_scheme=constants.HTTPS)

class RunningThread(threading.Thread):
lock = threading.Lock()
@@ -238,11 +209,7 @@ def run(self) -> None:
with RunningThread.lock:
conn.cursor().execute("SELECT 1")

threads = [
RunningThread(),
RunningThread(),
RunningThread()
]
threads = [RunningThread(), RunningThread(), RunningThread()]

# run and join all threads
for thread in threads:
@@ -313,4 +280,4 @@ def test_hostname_parsing():
def test_description_is_none_when_cursor_is_not_executed():
connection = Connection("sample_trino_cluster:443")
cursor = connection.cursor()
assert hasattr(cursor, 'description')
assert hasattr(cursor, "description")
4 changes: 1 addition & 3 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
@@ -27,9 +27,7 @@ def test_isolation_level_levels() -> None:


def test_isolation_level_values() -> None:
values = {
0, 1, 2, 3, 4
}
values = {0, 1, 2, 3, 4}

assert IsolationLevel.values() == values

108 changes: 60 additions & 48 deletions trino/auth.py
Original file line number Diff line number Diff line change
@@ -98,22 +98,24 @@ def get_exceptions(self) -> Tuple[Any, ...]:
try:
from requests_kerberos.exceptions import KerberosExchangeError

return KerberosExchangeError,
return (KerberosExchangeError,)
except ImportError:
raise RuntimeError("unable to import requests_kerberos")

def __eq__(self, other: object) -> bool:
if not isinstance(other, KerberosAuthentication):
return False
return (self._config == other._config
and self._service_name == other._service_name
and self._mutual_authentication == other._mutual_authentication
and self._force_preemptive == other._force_preemptive
and self._hostname_override == other._hostname_override
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
and self._principal == other._principal
and self._delegate == other._delegate
and self._ca_bundle == other._ca_bundle)
return (
self._config == other._config
and self._service_name == other._service_name
and self._mutual_authentication == other._mutual_authentication
and self._force_preemptive == other._force_preemptive
and self._hostname_override == other._hostname_override
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
and self._principal == other._principal
and self._delegate == other._delegate
and self._ca_bundle == other._ca_bundle
)


class GSSAPIAuthentication(Authentication):
@@ -173,9 +175,9 @@ def _get_credentials(self, principal: Optional[str] = None) -> Any:
return None

def _get_target_name(
self,
hostname_override: Optional[str] = None,
service_name: Optional[str] = None,
self,
hostname_override: Optional[str] = None,
service_name: Optional[str] = None,
) -> Any:
if service_name is not None:
try:
@@ -195,22 +197,24 @@ def get_exceptions(self) -> Tuple[Any, ...]:
try:
from requests_gssapi.exceptions import SPNEGOExchangeError

return SPNEGOExchangeError,
return (SPNEGOExchangeError,)
except ImportError:
raise RuntimeError("unable to import requests_kerberos")

def __eq__(self, other: object) -> bool:
if not isinstance(other, GSSAPIAuthentication):
return False
return (self._config == other._config
and self._service_name == other._service_name
and self._mutual_authentication == other._mutual_authentication
and self._force_preemptive == other._force_preemptive
and self._hostname_override == other._hostname_override
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
and self._principal == other._principal
and self._delegate == other._delegate
and self._ca_bundle == other._ca_bundle)
return (
self._config == other._config
and self._service_name == other._service_name
and self._mutual_authentication == other._mutual_authentication
and self._force_preemptive == other._force_preemptive
and self._hostname_override == other._hostname_override
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
and self._principal == other._principal
and self._delegate == other._delegate
and self._ca_bundle == other._ca_bundle
)


class BasicAuthentication(Authentication):
@@ -353,8 +357,9 @@ def __init__(self) -> None:
logger.info("keyring module not found. OAuth2 token will not be stored in keyring.")

def is_keyring_available(self) -> bool:
return self._keyring is not None \
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring)
return self._keyring is not None and not isinstance(
self._keyring.get_keyring(), self._keyring.backends.fail.Keyring
)

def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
password = self._keyring.get_password(key, "token")
@@ -370,9 +375,11 @@ def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
password += str(self._keyring.get_password(key, f"token__{i}"))

except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information.") from e
raise trino.exceptions.NotSupportedError(
"Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information."
) from e
except ValueError:
pass

@@ -388,7 +395,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
logger.debug(f"password is {len(token)} characters, sharding it.")

password_shards = [
token[i: i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE)
token[i : i + MAX_NT_PASSWORD_SIZE] for i in range(0, len(token), MAX_NT_PASSWORD_SIZE)
]
shard_info = {
"sharded_password": True,
@@ -401,15 +408,18 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
for i, s in enumerate(password_shards):
self._keyring.set_password(key, f"token__{i}", s)
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information.") from e
raise trino.exceptions.NotSupportedError(
"Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information."
) from e


class _OAuth2TokenBearer(AuthBase):
"""
Custom implementation of Trino OAuth2 based authentication to get the token
"""

MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)

@@ -428,9 +438,9 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest:
token = self._get_token_from_cache(key)

if token is not None:
r.headers['Authorization'] = "Bearer " + token
r.headers["Authorization"] = "Bearer " + token

r.register_hook('response', self._authenticate)
r.register_hook("response", self._authenticate)

return r

@@ -455,7 +465,7 @@ def _authenticate(self, response: Response, **kwargs: Any) -> Optional[Response]

def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
# we have to handle the authentication, may be token the token expired, or it wasn't there at all
auth_info = response.headers.get('WWW-Authenticate')
auth_info = response.headers.get("WWW-Authenticate")
if not auth_info:
raise exceptions.TrinoAuthError("Error: header WWW-Authenticate not available in the response.")

@@ -468,8 +478,8 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:
# x_token_server="https://trino.com/oauth2/token/uuid4"'
auth_info_headers = self._parse_authenticate_header(auth_info)

auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server'))
token_server = auth_info_headers.get('bearer x_token_server', auth_info_headers.get('x_token_server'))
auth_server = auth_info_headers.get("bearer x_redirect_server", auth_info_headers.get("x_redirect_server"))
token_server = auth_info_headers.get("bearer x_token_server", auth_info_headers.get("x_token_server"))
if token_server is None:
raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server")

@@ -500,7 +510,7 @@ def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)
if token is not None:
request.headers['Authorization'] = "Bearer " + token
request.headers["Authorization"] = "Bearer " + token
retry_response = response.connection.send(request, **kwargs)
retry_response.history.append(response)
retry_response.request = request
@@ -510,24 +520,24 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st
attempts = 0
while attempts < self.MAX_OAUTH_ATTEMPTS:
attempts += 1
with response.connection.send(Request(
method='GET', url=token_server).prepare(), **kwargs) as response:
with response.connection.send(Request(method="GET", url=token_server).prepare(), **kwargs) as response:
if response.status_code == 200:
token_response = json.loads(response.text)
token = token_response.get('token')
token = token_response.get("token")
if token:
return token
error = token_response.get('error')
error = token_response.get("error")
if error:
raise exceptions.TrinoAuthError(f"Error while getting the token: {error}")
else:
token_server = token_response.get('nextUri')
token_server = token_response.get("nextUri")
logger.debug(f"nextURi auth token server: {token_server}")
else:
raise exceptions.TrinoAuthError(
f"Error while getting the token response "
f"status code: {response.status_code}, "
f"body: {response.text}")
f"body: {response.text}"
)

raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")

@@ -571,10 +581,12 @@ def _parse_authenticate_header(header: str) -> Dict[str, str]:


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
WebBrowserRedirectHandler(),
ConsoleRedirectHandler()
])):
def __init__(
self,
redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler(
[WebBrowserRedirectHandler(), ConsoleRedirectHandler()]
),
):
self._redirect_auth_url = redirect_auth_url_handler
self._bearer = _OAuth2TokenBearer(self._redirect_auth_url)

85 changes: 32 additions & 53 deletions trino/client.py
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@
else:
PROXIES = {}

_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r"^\S[^\s=]*$")

ROLE_PATTERN = re.compile(r"^ROLE\{(.*)\}$")

@@ -250,10 +250,12 @@ def _format_roles(self, roles):
is_legacy_role_pattern = ROLE_PATTERN.match(role) is not None
if role in ("NONE", "ALL") or is_legacy_role_pattern:
if is_legacy_role_pattern:
warnings.warn(f"A role '{role}' is provided using a legacy format. "
"Please remove the ROLE{} wrapping. Support for the legacy format might be "
"removed in a future release.",
DeprecationWarning)
warnings.warn(
f"A role '{role}' is provided using a legacy format. "
"Please remove the ROLE{} wrapping. Support for the legacy format might be "
"removed in a future release.",
DeprecationWarning,
)
formatted_roles[catalog] = role
else:
formatted_roles[catalog] = f"ROLE{{{role}}}"
@@ -275,26 +277,17 @@ def get_header_values(headers, header):

def get_session_property_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]
return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)]


def get_prepared_statement_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]
return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)]


def get_roles_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs if kv)
]
return [(k.strip(), urllib.parse.unquote_plus(v.strip())) for k, v in (kv.split("=", 1) for kv in kvs if kv)]


@dataclass
@@ -324,26 +317,22 @@ def __repr__(self):


class _DelayExponential:
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min
):
def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=1800): # 100ms # 30 min
self._base = base
self._exponent = exponent
self._jitter = jitter
self._max_delay = max_delay

def __call__(self, attempt):
delay = float(self._base) * (self._exponent ** attempt)
delay = float(self._base) * (self._exponent**attempt)
if self._jitter:
delay *= random.random()
delay = min(float(self._max_delay), delay)
return delay


class _RetryWithExponentialBackoff:
def __init__(
self, base=0.1, exponent=2, jitter=True, max_delay=1800 # 100ms # 30 min
):
def __init__(self, base=0.1, exponent=2, jitter=True, max_delay=1800): # 100ms # 30 min
self._get_delay = _DelayExponential(base, exponent, jitter, max_delay)

def retry(self, func, args, kwargs, err, attempt):
@@ -469,7 +458,7 @@ def http_headers(self) -> Dict[str, str]:
headers[constants.HEADER_USER] = self._client_session.user
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
headers[constants.HEADER_CLIENT_CAPABILITIES] = "PARAMETRIC_DATETIME"
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
if len(self._client_session.roles.values()):
headers[constants.HEADER_ROLE] = ",".join(
@@ -502,19 +491,17 @@ def http_headers(self) -> Dict[str, str]:
transaction_id = self._client_session.transaction_id
headers[constants.HEADER_TRANSACTION] = transaction_id

if self._client_session.extra_credential is not None and \
len(self._client_session.extra_credential) > 0:
if self._client_session.extra_credential is not None and len(self._client_session.extra_credential) > 0:

for tup in self._client_session.extra_credential:
self._verify_extra_credential(tup)

# HTTP 1.1 section 4.2 combine multiple extra credentials into a
# comma-separated value
# extra credential value is encoded per spec (application/x-www-form-urlencoded MIME format)
headers[constants.HEADER_EXTRA_CREDENTIAL] = \
", ".join(
[f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}"
for tup in self._client_session.extra_credential])
headers[constants.HEADER_EXTRA_CREDENTIAL] = ", ".join(
[f"{tup[0]}={urllib.parse.quote_plus(str(tup[1]))}" for tup in self._client_session.extra_credential]
)

return headers

@@ -623,15 +610,11 @@ def process(self, http_response) -> TrinoStatus:
raise self._process_error(response["error"], response.get("id"))

if constants.HEADER_CLEAR_SESSION in http_response.headers:
for prop in get_header_values(
http_response.headers, constants.HEADER_CLEAR_SESSION
):
for prop in get_header_values(http_response.headers, constants.HEADER_CLEAR_SESSION):
self._client_session.properties.pop(prop, None)

if constants.HEADER_SET_SESSION in http_response.headers:
for key, value in get_session_property_values(
http_response.headers, constants.HEADER_SET_SESSION
):
for key, value in get_session_property_values(http_response.headers, constants.HEADER_SET_SESSION):
self._client_session.properties[key] = value

if constants.HEADER_SET_CATALOG in http_response.headers:
@@ -641,21 +624,15 @@ def process(self, http_response) -> TrinoStatus:
self._client_session.schema = http_response.headers[constants.HEADER_SET_SCHEMA]

if constants.HEADER_SET_ROLE in http_response.headers:
for key, value in get_roles_values(
http_response.headers, constants.HEADER_SET_ROLE
):
for key, value in get_roles_values(http_response.headers, constants.HEADER_SET_ROLE):
self._client_session.roles[key] = value

if constants.HEADER_ADDED_PREPARE in http_response.headers:
for name, statement in get_prepared_statement_values(
http_response.headers, constants.HEADER_ADDED_PREPARE
):
for name, statement in get_prepared_statement_values(http_response.headers, constants.HEADER_ADDED_PREPARE):
self._client_session.prepared_statements[name] = statement

if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
for name in get_header_values(
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
):
for name in get_header_values(http_response.headers, constants.HEADER_DEALLOCATED_PREPARE):
self._client_session.prepared_statements.pop(name, None)

if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers:
@@ -690,7 +667,7 @@ def _verify_extra_credential(self, header):
raise ValueError(f"whitespace or '=' are disallowed in extra credential '{key}'")

try:
key.encode().decode('ascii')
key.encode().decode("ascii")
except UnicodeDecodeError:
raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")

@@ -737,10 +714,10 @@ class TrinoQuery:
"""Represent the execution of a SQL statement by Trino."""

def __init__(
self,
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
self,
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
@@ -837,8 +814,9 @@ def _update_state(self, status):
self._update_count = status.update_count
self._next_uri = status.next_uri
if not self._row_mapper and status.columns:
self._row_mapper = RowMapperFactory().create(columns=status.columns,
legacy_primitive_types=self._legacy_primitive_types)
self._row_mapper = RowMapperFactory().create(
columns=status.columns, legacy_primitive_types=self._legacy_primitive_types
)
if status.columns:
self._columns = status.columns

@@ -877,6 +855,7 @@ def cancel(self) -> None:

def is_finished(self) -> bool:
import warnings

warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)
return self.finished

6 changes: 3 additions & 3 deletions trino/constants.py
Original file line number Diff line number Diff line change
@@ -48,9 +48,9 @@
HEADER_STARTED_TRANSACTION = "X-Trino-Started-Transaction-Id"
HEADER_TRANSACTION = "X-Trino-Transaction-Id"

HEADER_PREPARED_STATEMENT = 'X-Trino-Prepared-Statement'
HEADER_ADDED_PREPARE = 'X-Trino-Added-Prepare'
HEADER_DEALLOCATED_PREPARE = 'X-Trino-Deallocated-Prepare'
HEADER_PREPARED_STATEMENT = "X-Trino-Prepared-Statement"
HEADER_ADDED_PREPARE = "X-Trino-Added-Prepare"
HEADER_DEALLOCATED_PREPARE = "X-Trino-Deallocated-Prepare"

HEADER_SET_SCHEMA = "X-Trino-Set-Schema"
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"
96 changes: 42 additions & 54 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,7 @@ class TimeBoundLRUCache:
"""A bounded LRU cache which expires entries after a configured number of seconds.
Note that expired entries will be evicted only on an attempted access (or through
the LRU policy)."""

def __init__(self, capacity: int, ttl_seconds: int):
self.capacity = capacity
self.ttl_seconds = ttl_seconds
@@ -268,7 +269,7 @@ def cursor(self, legacy_primitive_types: bool = None):
self,
request,
# if legacy params are not explicitly set in Cursor, take them from Connection
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types,
)

def _use_legacy_prepared_statements(self):
@@ -278,15 +279,16 @@ def _use_legacy_prepared_statements(self):
value = must_use_legacy_prepared_statements.get((self.host, self.port))
if value is None:
try:
query = trino.client.TrinoQuery(
self._create_request(),
query="EXECUTE IMMEDIATE 'SELECT 1'")
query = trino.client.TrinoQuery(self._create_request(), query="EXECUTE IMMEDIATE 'SELECT 1'")
query.execute()
value = False
except Exception as e:
logger.warning(
"EXECUTE IMMEDIATE not available for %s:%s; defaulting to legacy prepared statements (%s)",
self.host, self.port, e)
self.host,
self.port,
e,
)
value = True
must_use_legacy_prepared_statements.put((self.host, self.port), value)
return value
@@ -327,7 +329,7 @@ def from_column(cls, column: Dict[str, Any]):
arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size
arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision
arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale
None # null_ok
None, # null_ok
)


@@ -339,15 +341,9 @@ class Cursor:
"""

def __init__(
self,
connection,
request,
legacy_primitive_types: bool = False):
def __init__(self, connection, request, legacy_primitive_types: bool = False):
if not isinstance(connection, Connection):
raise ValueError(
"connection must be a Connection object: {}".format(type(connection))
)
raise ValueError("connection must be a Connection object: {}".format(type(connection)))
self._connection = connection
self._request = request

@@ -381,9 +377,7 @@ def description(self) -> List[ColumnDescription]:
return None

# [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ]
return [
ColumnDescription.from_column(col) for col in self._query.columns
]
return [ColumnDescription.from_column(col) for col in self._query.columns]

@property
def rowcount(self):
@@ -442,16 +436,13 @@ def _prepare_statement(self, statement: str, name: str) -> None:
:param name: name that will be assigned to the prepared statement.
"""
sql = f"PREPARE {name} FROM {statement}"
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
query = trino.client.TrinoQuery(
self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types
)
query.execute()

def _execute_prepared_statement(
self,
statement_name,
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))
def _execute_prepared_statement(self, statement_name, params):
sql = "EXECUTE " + statement_name + " USING " + ",".join(map(self._format_prepared_param, params))
return trino.client.TrinoQuery(self._request, query=sql, legacy_primitive_types=self._legacy_primitive_types)

def _execute_immediate_statement(self, statement: str, params):
@@ -461,10 +452,15 @@ def _execute_immediate_statement(self, statement: str, params):
:param statement: sql to be executed.
:param params: parameters to be bound.
"""
sql = "EXECUTE IMMEDIATE '" + statement.replace("'", "''") + \
"' USING " + ",".join(map(self._format_prepared_param, params))
sql = (
"EXECUTE IMMEDIATE '"
+ statement.replace("'", "''")
+ "' USING "
+ ",".join(map(self._format_prepared_param, params))
)
return trino.client.TrinoQuery(
self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types)
self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types
)

def _format_prepared_param(self, param):
"""
@@ -491,7 +487,7 @@ def _format_prepared_param(self, param):
return "DOUBLE '%s'" % param

if isinstance(param, str):
return ("'%s'" % param.replace("'", "''"))
return "'%s'" % param.replace("'", "''")

if isinstance(param, (bytes, bytearray)):
return "X'%s'" % param.hex()
@@ -517,28 +513,25 @@ def _format_prepared_param(self, param):
time_str = param.strftime("%H:%M:%S.%f")
# named timezones
if isinstance(param.tzinfo, ZoneInfo):
utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime('%z')
utc_offset = datetime.datetime.now(tz=param.tzinfo).strftime("%z")
return "TIME '%s %s:%s'" % (time_str, utc_offset[:3], utc_offset[3:])
# offset-based timezones
return "TIME '%s %s'" % (time_str, param.strftime('%Z')[3:])
return "TIME '%s %s'" % (time_str, param.strftime("%Z")[3:])

if isinstance(param, datetime.date):
date_str = param.strftime("%Y-%m-%d")
return "DATE '%s'" % date_str

if isinstance(param, list):
return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
return "ARRAY[%s]" % ",".join(map(self._format_prepared_param, param))

if isinstance(param, tuple):
return "ROW(%s)" % ','.join(map(self._format_prepared_param, param))
return "ROW(%s)" % ",".join(map(self._format_prepared_param, param))

if isinstance(param, dict):
keys = list(param.keys())
values = [param[key] for key in keys]
return "MAP({}, {})".format(
self._format_prepared_param(keys),
self._format_prepared_param(values)
)
return "MAP({}, {})".format(self._format_prepared_param(keys), self._format_prepared_param(values))

if isinstance(param, uuid.UUID):
return "UUID '%s'" % param
@@ -549,19 +542,19 @@ def _format_prepared_param(self, param):
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))

def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = 'DEALLOCATE PREPARE ' + statement_name
query = trino.client.TrinoQuery(self.connection._create_request(), query=sql,
legacy_primitive_types=self._legacy_primitive_types)
sql = "DEALLOCATE PREPARE " + statement_name
query = trino.client.TrinoQuery(
self.connection._create_request(), query=sql, legacy_primitive_types=self._legacy_primitive_types
)
query.execute()

def _generate_unique_statement_name(self):
return 'st_' + uuid.uuid4().hex.replace('-', '')
return "st_" + uuid.uuid4().hex.replace("-", "")

def execute(self, operation, params=None):
if params:
assert isinstance(params, (list, tuple)), (
'params must be a list or tuple containing the query '
'parameter values'
"params must be a list or tuple containing the query " "parameter values"
)

if self.connection._use_legacy_prepared_statements():
@@ -571,9 +564,7 @@ def execute(self, operation, params=None):
try:
# Send execute statement and assign the return value to `results`
# as it will be returned by the function
self._query = self._execute_prepared_statement(
statement_name, params
)
self._query = self._execute_prepared_statement(statement_name, params)
self._iterator = iter(self._query.execute())
finally:
# Send deallocate statement
@@ -586,8 +577,9 @@ def execute(self, operation, params=None):
self._iterator = iter(self._query.execute())

else:
self._query = trino.client.TrinoQuery(self._request, query=operation,
legacy_primitive_types=self._legacy_primitive_types)
self._query = trino.client.TrinoQuery(
self._request, query=operation, legacy_primitive_types=self._legacy_primitive_types
)
self._iterator = iter(self._query.execute())
return self

@@ -726,13 +718,9 @@ def __eq__(self, other):

STRING = DBAPITypeObject("VARCHAR", "CHAR", "VARBINARY", "JSON", "IPADDRESS")

BINARY = DBAPITypeObject(
"ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest"
)
BINARY = DBAPITypeObject("ARRAY", "MAP", "ROW", "HyperLogLog", "P4HyperLogLog", "QDigest")

NUMBER = DBAPITypeObject(
"BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL"
)
NUMBER = DBAPITypeObject("BOOLEAN", "TINYINT", "SMALLINT", "INTEGER", "BIGINT", "REAL", "DOUBLE", "DECIMAL")

DATETIME = DBAPITypeObject(
"DATE",
2 changes: 1 addition & 1 deletion trino/logging.py
Original file line number Diff line number Diff line change
@@ -27,4 +27,4 @@ def get_logger(name: str, log_level: Optional[int] = None) -> logging.Logger:


# set default log level to LEVEL
trino_root_logger = get_logger('trino', LEVEL)
trino_root_logger = get_logger("trino", LEVEL)
162 changes: 88 additions & 74 deletions trino/mapper.py
Original file line number Diff line number Diff line change
@@ -44,9 +44,9 @@ def map(self, value: Any) -> Optional[bool]:
return None
if isinstance(value, bool):
return value
if str(value).lower() == 'true':
if str(value).lower() == "true":
return True
if str(value).lower() == 'false':
if str(value).lower() == "false":
return False
raise ValueError(f"Server sent unexpected value {value} of type {type(value)} for boolean")

@@ -65,11 +65,11 @@ class DoubleValueMapper(ValueMapper[float]):
def map(self, value: Any) -> Optional[float]:
if value is None:
return None
if value == 'Infinity':
if value == "Infinity":
return float("inf")
if value == '-Infinity':
if value == "-Infinity":
return float("-inf")
if value == 'NaN':
if value == "NaN":
return float("nan")
return float(value)

@@ -110,12 +110,13 @@ def __init__(self, precision: int):
def map(self, value: Any) -> Optional[time]:
if value is None:
return None
whole_python_temporal_value = value[:self.time_default_size]
remaining_fractional_seconds = value[self.time_default_size + 1:]
return Time(
time.fromisoformat(whole_python_temporal_value),
_fraction_to_decimal(remaining_fractional_seconds)
).round_to(self.precision).to_python_type()
whole_python_temporal_value = value[: self.time_default_size]
remaining_fractional_seconds = value[self.time_default_size + 1 :]
return (
Time(time.fromisoformat(whole_python_temporal_value), _fraction_to_decimal(remaining_fractional_seconds))
.round_to(self.precision)
.to_python_type()
)

def _add_second(self, time_value: time) -> time:
return (datetime.combine(datetime(1, 1, 1), time_value) + timedelta(seconds=1)).time()
@@ -125,13 +126,17 @@ class TimeWithTimeZoneValueMapper(TimeValueMapper):
def map(self, value: Any) -> Optional[time]:
if value is None:
return None
whole_python_temporal_value = value[:self.time_default_size]
remaining_fractional_seconds = value[self.time_default_size + 1:len(value) - 6]
timezone_part = value[len(value) - 6:]
return TimeWithTimeZone(
time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)),
_fraction_to_decimal(remaining_fractional_seconds),
).round_to(self.precision).to_python_type()
whole_python_temporal_value = value[: self.time_default_size]
remaining_fractional_seconds = value[self.time_default_size + 1 : len(value) - 6]
timezone_part = value[len(value) - 6 :]
return (
TimeWithTimeZone(
time.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)),
_fraction_to_decimal(remaining_fractional_seconds),
)
.round_to(self.precision)
.to_python_type()
)


class TimestampValueMapper(ValueMapper[datetime]):
@@ -142,25 +147,33 @@ def __init__(self, precision: int):
def map(self, value: Any) -> Optional[datetime]:
if value is None:
return None
whole_python_temporal_value = value[:self.datetime_default_size]
remaining_fractional_seconds = value[self.datetime_default_size + 1:]
return Timestamp(
datetime.fromisoformat(whole_python_temporal_value),
_fraction_to_decimal(remaining_fractional_seconds),
).round_to(self.precision).to_python_type()
whole_python_temporal_value = value[: self.datetime_default_size]
remaining_fractional_seconds = value[self.datetime_default_size + 1 :]
return (
Timestamp(
datetime.fromisoformat(whole_python_temporal_value),
_fraction_to_decimal(remaining_fractional_seconds),
)
.round_to(self.precision)
.to_python_type()
)


class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
def map(self, value: Any) -> Optional[datetime]:
if value is None:
return None
datetime_with_fraction, timezone_part = value.rsplit(' ', 1)
whole_python_temporal_value = datetime_with_fraction[:self.datetime_default_size]
remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1:]
return TimestampWithTimeZone(
datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)),
_fraction_to_decimal(remaining_fractional_seconds),
).round_to(self.precision).to_python_type()
datetime_with_fraction, timezone_part = value.rsplit(" ", 1)
whole_python_temporal_value = datetime_with_fraction[: self.datetime_default_size]
remaining_fractional_seconds = datetime_with_fraction[self.datetime_default_size + 1 :]
return (
TimestampWithTimeZone(
datetime.fromisoformat(whole_python_temporal_value).replace(tzinfo=_create_tzinfo(timezone_part)),
_fraction_to_decimal(remaining_fractional_seconds),
)
.round_to(self.precision)
.to_python_type()
)


def _create_tzinfo(timezone_str: str) -> tzinfo:
@@ -183,7 +196,7 @@ def map(self, value: Any) -> Optional[relativedelta]:
if value is None:
return None
is_negative = value[0] == "-"
years, months = (value[1:] if is_negative else value).split('-')
years, months = (value[1:] if is_negative else value).split("-")
years, months = int(years), int(months)
if is_negative:
years, months = -years, -months
@@ -195,11 +208,16 @@ def map(self, value: Any) -> Optional[timedelta]:
if value is None:
return None
is_negative = value[0] == "-"
days, time = (value[1:] if is_negative else value).split(' ')
hours, minutes, seconds_milliseconds = time.split(':')
seconds, milliseconds = seconds_milliseconds.split('.')
days, hours, minutes, seconds, milliseconds = (int(days), int(hours), int(minutes), int(seconds),
int(milliseconds))
days, time = (value[1:] if is_negative else value).split(" ")
hours, minutes, seconds_milliseconds = time.split(":")
seconds, milliseconds = seconds_milliseconds.split(".")
days, hours, minutes, seconds, milliseconds = (
int(days),
int(hours),
int(minutes),
int(seconds),
int(milliseconds),
)
if is_negative:
days, hours, minutes, seconds, milliseconds = -days, -hours, -minutes, -seconds, -milliseconds
try:
@@ -230,9 +248,7 @@ def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any])
def map(self, value: Any) -> Optional[Dict[Any, Optional[Any]]]:
if value is None:
return None
return {
self.key_mapper.map(k): self.value_mapper.map(v) for k, v in value.items()
}
return {self.key_mapper.map(k): self.value_mapper.map(v) for k, v in value.items()}


class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
@@ -244,11 +260,7 @@ def __init__(self, mappers: List[ValueMapper[Any]], names: List[Optional[str]],
def map(self, value: Optional[List[Any]]) -> Optional[Tuple[Optional[Any], ...]]:
if value is None:
return None
return NamedRowTuple(
list(self.mappers[i].map(v) for i, v in enumerate(value)),
self.names,
self.types
)
return NamedRowTuple(list(self.mappers[i].map(v) for i, v in enumerate(value)), self.names, self.types)


class UuidValueMapper(ValueMapper[uuid.UUID]):
@@ -279,82 +291,84 @@ class RowMapperFactory:
lambda functions (one for each column) which will process a data value
and returns a RowMapper instance which will process rows of data
"""

NO_OP_ROW_MAPPER = NoOpRowMapper()

def create(self, columns: List[Any], legacy_primitive_types: bool) -> RowMapper | NoOpRowMapper:
assert columns is not None

if not legacy_primitive_types:
return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns])
return RowMapper([self._create_value_mapper(column["typeSignature"]) for column in columns])
return RowMapperFactory.NO_OP_ROW_MAPPER

def _create_value_mapper(self, column: Dict[str, Any]) -> ValueMapper[Any]:
col_type = column['rawType']
col_type = column["rawType"]

# primitive types
if col_type == 'boolean':
if col_type == "boolean":
return BooleanValueMapper()
if col_type in {'tinyint', 'smallint', 'integer', 'bigint'}:
if col_type in {"tinyint", "smallint", "integer", "bigint"}:
return IntegerValueMapper()
if col_type in {'double', 'real'}:
if col_type in {"double", "real"}:
return DoubleValueMapper()
if col_type == 'decimal':
if col_type == "decimal":
return DecimalValueMapper()
if col_type in {'varchar', 'char'}:
if col_type in {"varchar", "char"}:
return StringValueMapper()
if col_type == 'varbinary':
if col_type == "varbinary":
return BinaryValueMapper()
if col_type == 'json':
if col_type == "json":
return StringValueMapper()
if col_type == 'date':
if col_type == "date":
return DateValueMapper()
if col_type == 'time':
if col_type == "time":
return TimeValueMapper(self._get_precision(column))
if col_type == 'time with time zone':
if col_type == "time with time zone":
return TimeWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'timestamp':
if col_type == "timestamp":
return TimestampValueMapper(self._get_precision(column))
if col_type == 'timestamp with time zone':
if col_type == "timestamp with time zone":
return TimestampWithTimeZoneValueMapper(self._get_precision(column))
if col_type == 'interval year to month':
if col_type == "interval year to month":
return IntervalYearToMonthMapper()
if col_type == 'interval day to second':
if col_type == "interval day to second":
return IntervalDayToSecondMapper()

# structural types
if col_type == 'array':
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
if col_type == "array":
value_mapper = self._create_value_mapper(column["arguments"][0]["value"])
return ArrayValueMapper(value_mapper)
if col_type == 'map':
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
if col_type == "map":
key_mapper = self._create_value_mapper(column["arguments"][0]["value"])
value_mapper = self._create_value_mapper(column["arguments"][1]["value"])
return MapValueMapper(key_mapper, value_mapper)
if col_type == 'row':
if col_type == "row":
mappers: List[ValueMapper[Any]] = []
names: List[Optional[str]] = []
types: List[str] = []
for arg in column['arguments']:
mappers.append(self._create_value_mapper(arg['value']['typeSignature']))
names.append(arg['value']['fieldName']['name'] if "fieldName" in arg['value'] else None)
types.append(arg['value']['typeSignature']['rawType'])
for arg in column["arguments"]:
mappers.append(self._create_value_mapper(arg["value"]["typeSignature"]))
names.append(arg["value"]["fieldName"]["name"] if "fieldName" in arg["value"] else None)
types.append(arg["value"]["typeSignature"]["rawType"])
return RowValueMapper(mappers, names, types)

# others
if col_type == 'uuid':
if col_type == "uuid":
return UuidValueMapper()
return NoOpValueMapper()

def _get_precision(self, column: Dict[str, Any]) -> int:
args = column['arguments']
args = column["arguments"]
if len(args) == 0:
return 3
return args[0]['value']
return args[0]["value"]


class RowMapper:
"""
Maps a row of data given a list of mapping functions
"""

def __init__(self, columns: List[ValueMapper[Any]]):
self.columns = columns

36 changes: 15 additions & 21 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -106,25 +106,19 @@ def limit_clause(self, select, **kw):
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
return text

def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
fromhints=None, use_schema=True, **kwargs):
sql = super(TrinoSQLCompiler, self).visit_table(
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
)
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, fromhints=None, use_schema=True, **kwargs):
sql = super(TrinoSQLCompiler, self).visit_table(table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs)
return self.add_catalog(sql, table)

@staticmethod
def add_catalog(sql, table):
if table is None or not isinstance(table, DialectKWArgs):
return sql

if (
'trino' not in table.dialect_options
or 'catalog' not in table.dialect_options['trino']
):
if "trino" not in table.dialect_options or "catalog" not in table.dialect_options["trino"]:
return sql

catalog = table.dialect_options['trino']['catalog']
catalog = table.dialect_options["trino"]["catalog"]
sql = f'"{catalog}".{sql}'
return sql

@@ -146,23 +140,23 @@ class GenericIgnoreNulls(GenericFunction):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if kwargs.get('ignore_nulls'):
if kwargs.get("ignore_nulls"):
self.ignore_nulls = True

class FirstValue(GenericIgnoreNulls):
name = 'first_value'
name = "first_value"

class LastValue(GenericIgnoreNulls):
name = 'last_value'
name = "last_value"

class NthValue(GenericIgnoreNulls):
name = 'nth_value'
name = "nth_value"

class Lead(GenericIgnoreNulls):
name = 'lead'
name = "lead"

class Lag(GenericIgnoreNulls):
name = 'lag'
name = "lag"

@staticmethod
@compiles(FirstValue)
@@ -171,9 +165,9 @@ class Lag(GenericIgnoreNulls):
@compiles(Lead)
@compiles(Lag)
def compile_ignore_nulls(element, compiler, **kwargs):
compiled = f'{element.name}({compiler.process(element.clauses)})'
compiled = f"{element.name}({compiler.process(element.clauses)})"
if element.ignore_nulls:
compiled += ' IGNORE NULLS'
compiled += " IGNORE NULLS"
return compiled

def visit_try_cast(self, element, **kw):
@@ -248,17 +242,17 @@ def visit_TIME(self, type_, **kw):
return datatype

def visit_JSON(self, type_, **kw):
return 'JSON'
return "JSON"

def visit_MAP(self, type_, **kw):
# the key and value types themselves need to be processed otherwise sqltypes.MAP(Float, Float) will get
# rendered as MAP(FLOAT, FLOAT) instead of MAP(REAL, REAL) or MAP(DOUBLE, DOUBLE)
key_type = self.process(type_.key_type, **kw)
value_type = self.process(type_.value_type, **kw)
return f'MAP({key_type}, {value_type})'
return f"MAP({key_type}, {value_type})"

def visit_ARRAY(self, type_, **kw):
return f'ARRAY({self.process(type_.item_type, **kw)})'
return f"ARRAY({self.process(type_.item_type, **kw)})"

def visit_ROW(self, type_, **kw):
return f'ROW({", ".join(f"{name} {self.process(attr_type, **kw)}" for name, attr_type in type_.attr_types)})'
8 changes: 3 additions & 5 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
@@ -119,13 +119,11 @@ def process(value):
class _JSONFormatter:
@staticmethod
def format_index(value):
return "$[\"%s\"]" % value
return '$["%s"]' % value

@staticmethod
def format_path(value):
return "$%s" % (
"".join(["[\"%s\"]" % elem for elem in value])
)
return "$%s" % ("".join(['["%s"]' % elem for elem in value]))


class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
@@ -238,7 +236,7 @@ def aware_split(
elif character == close_bracket:
parens -= 1
elif character == quote:
if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote:
if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote:
quotes = False
elif not quotes:
quotes = True
25 changes: 6 additions & 19 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
@@ -52,10 +52,7 @@


class TrinoDialect(DefaultDialect):
def __init__(self,
json_serializer=None,
json_deserializer=None,
**kwargs):
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
DefaultDialect.__init__(self, **kwargs)
self._json_serializer = json_serializer
self._json_deserializer = json_deserializer
@@ -142,7 +139,7 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any

if "cert" in url.query and "key" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query['cert']), unquote_plus(url.query['key']))
kwargs["auth"] = CertificateAuthentication(unquote_plus(url.query["cert"]), unquote_plus(url.query["key"]))

if "externalAuthentication" in url.query:
kwargs["http_scheme"] = "https"
@@ -214,10 +211,7 @@ def _get_columns(self, connection: Connection, table_name: str, schema: str = No
return columns

def _get_partitions(
self,
connection: Connection,
table_name: str,
schema: str = None
self, connection: Connection, table_name: str, schema: str = None
) -> List[Dict[str, List[Any]]]:
schema = schema or self._get_default_schema_name(connection)
query = dedent(
@@ -330,11 +324,7 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non
logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e)
if not partitioned_columns:
return []
partition_index = dict(
name="partition",
column_names=partitioned_columns,
unique=False
)
partition_index = dict(name="partition", column_names=partitioned_columns, unique=False)
return [partition_index]

def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
@@ -371,14 +361,11 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
).strip()
try:
res = connection.execute(
sql.text(query),
{"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name}
sql.text(query), {"catalog_name": catalog_name, "schema_name": schema_name, "table_name": table_name}
)
return dict(text=res.scalar())
except error.TrinoQueryError as e:
if e.error_name in (
error.PERMISSION_DENIED,
):
if e.error_name in (error.PERMISSION_DENIED,):
return dict(text=None)
raise

2 changes: 1 addition & 1 deletion trino/sqlalchemy/util.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ def _url(
cert: Optional[str] = None,
key: Optional[str] = None,
verify: Optional[bool] = None,
roles: Optional[Dict[str, str]] = None
roles: Optional[Dict[str, str]] = None,
) -> str:
"""
Composes a SQLAlchemy connection string from the given database connection
12 changes: 3 additions & 9 deletions trino/transaction.py
Original file line number Diff line number Diff line change
@@ -66,9 +66,7 @@ def request(self) -> trino.client.TrinoRequest:
def begin(self) -> None:
response = self._request.post(START_TRANSACTION)
if not response.ok:
raise trino.exceptions.DatabaseError(
"failed to start transaction: {}".format(response.status_code)
)
raise trino.exceptions.DatabaseError("failed to start transaction: {}".format(response.status_code))
transaction_id = response.headers.get(constants.HEADER_STARTED_TRANSACTION)
if transaction_id and transaction_id != NO_TRANSACTION:
self._id = response.headers[constants.HEADER_STARTED_TRANSACTION]
@@ -87,9 +85,7 @@ def commit(self) -> None:
try:
list(query.execute())
except Exception as err:
raise trino.exceptions.DatabaseError(
"failed to commit transaction {}: {}".format(self._id, err)
)
raise trino.exceptions.DatabaseError("failed to commit transaction {}: {}".format(self._id, err))
self._id = NO_TRANSACTION
self._request.transaction_id = self._id

@@ -98,8 +94,6 @@ def rollback(self) -> None:
try:
list(query.execute())
except Exception as err:
raise trino.exceptions.DatabaseError(
"failed to rollback transaction {}: {}".format(self._id, err)
)
raise trino.exceptions.DatabaseError("failed to rollback transaction {}: {}".format(self._id, err))
self._id = NO_TRANSACTION
self._request.transaction_id = self._id
9 changes: 5 additions & 4 deletions trino/types.py
Original file line number Diff line number Diff line change
@@ -36,9 +36,9 @@ def to_python_type(self) -> PythonTemporalType:

def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
"""
Python datetime and time only support up to microsecond precision
In case the supplied value exceeds the specified precision,
the value needs to be rounded.
Python datetime and time only support up to microsecond precision
In case the supplied value exceeds the specified precision,
the value needs to be rounded.
"""
precision = min(precision, MAX_PYTHON_TEMPORAL_PRECISION_POWER)
remaining_fractional_seconds = self._remaining_fractional_seconds
@@ -53,7 +53,7 @@ def round_to(self, precision: int) -> TemporalType[PythonTemporalType]:
@abc.abstractmethod
def add_time_delta(self, time_delta: timedelta) -> PythonTemporalType:
"""
This method shall be overriden to implement fraction arithmetics.
This method shall be overriden to implement fraction arithmetics.
"""
pass

@@ -99,6 +99,7 @@ def new_instance(self, value: datetime, fraction: Decimal) -> TimestampWithTimeZ

class NamedRowTuple(Tuple[Any, ...]):
"""Custom tuple class as namedtuple doesn't support missing or duplicate names"""

def __new__(cls, values: List[Any], names: List[str], types: List[str]) -> NamedRowTuple:
return cast(NamedRowTuple, super().__new__(cls, values))