Skip to content

Commit 935000d

Browse files
simozzy and rad-pat authored
Added new types (#66)
* Fix not fetching cursor on insert/update
* Fetch the cursor on insert/update/delete/copy into
* Enable CTE tests
* Enable further tests
* Support for table and column comments
* Include CTE test now bug is fixed
* Run against nightly
* Update pipenv
* Work in SQLAlchemy 1.4
* Handle JSON params
* Update Reserved Words
* Changes to File Formats
* Fix copyInto files clause
* Fix test - input bytes will differ, table data is random
* remove code for Geometry, Geography and structured types.
* Update test.yml
* reverted change to test file
* Added tests for TINYINT and BITMAP
* Added tests for DOUBLE
* Update test.yml
* Update test_sqlalchemy.py
* Ensure tests work on sqlalchemy versions 1.4.54 and 2.0 +
* Added types for GEOMETRY and GEOGRAPHY.
* Added Zip compression type
* Added code to enable geo tables in tests. Removed redundant code.
* move initialization of enable_geo_create_table
---------
Co-authored-by: Pat Buxton <[email protected]>
1 parent c5be5fb commit 935000d

File tree

6 files changed

+556
-66
lines changed

6 files changed

+556
-66
lines changed

databend_sqlalchemy/databend_dialect.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Table("some_table", metadata, ..., databend_transient=True|False)
2424
2525
"""
26-
2726
import decimal
2827
import re
2928
import operator
@@ -60,6 +59,17 @@
6059
CHAR,
6160
TIMESTAMP,
6261
)
62+
63+
import sqlalchemy
64+
from sqlalchemy import types as sqltypes
65+
from sqlalchemy.sql.base import Executable
66+
67+
# sqlalchemy.types.DOUBLE is only available on newer SQLAlchemy releases; on
# older ones fall back to the DOUBLE type shipped with this dialect.
# Feature-detect the name instead of string-matching the version so that a
# future major release does not silently pick the fallback.
try:
    from sqlalchemy.types import DOUBLE
except ImportError:
    from .types import DOUBLE
72+
6373
from sqlalchemy.engine import ExecutionContext, default
6474
from sqlalchemy.exc import DBAPIError, NoSuchTableError
6575

@@ -71,7 +81,7 @@
7181
AzureBlobStorage,
7282
AmazonS3,
7383
)
74-
from .types import INTERVAL
84+
from .types import INTERVAL, TINYINT, BITMAP, GEOMETRY, GEOGRAPHY
7585

7686
RESERVED_WORDS = {
7787
"Error",
@@ -693,6 +703,7 @@ def __init__(self, key_type, value_type):
693703
super(MAP, self).__init__()
694704

695705

706+
696707
class DatabendDate(sqltypes.DATE):
697708
__visit_name__ = "DATE"
698709

@@ -793,12 +804,26 @@ class DatabendInterval(INTERVAL):
793804
render_bind_cast = True
794805

795806

807+
class DatabendBitmap(BITMAP):
    """BITMAP variant used by the dialect, with ``render_bind_cast`` enabled."""

    render_bind_cast = True
809+
810+
811+
class DatabendTinyInt(TINYINT):
    """TINYINT variant used by the dialect, with ``render_bind_cast`` enabled."""

    render_bind_cast = True
813+
814+
815+
class DatabendGeometry(GEOMETRY):
    """GEOMETRY variant used by the dialect, with ``render_bind_cast`` enabled."""

    render_bind_cast = True
817+
818+
class DatabendGeography(GEOGRAPHY):
    """GEOGRAPHY variant used by the dialect, with ``render_bind_cast`` enabled."""

    render_bind_cast = True
820+
796821
# Type converters
797822
ischema_names = {
798823
"bigint": BIGINT,
799824
"int": INTEGER,
800825
"smallint": SMALLINT,
801-
"tinyint": SMALLINT,
826+
"tinyint": DatabendTinyInt,
802827
"int64": BIGINT,
803828
"int32": INTEGER,
804829
"int16": SMALLINT,
@@ -813,7 +838,7 @@ class DatabendInterval(INTERVAL):
813838
"datetime": DatabendDateTime,
814839
"timestamp": DatabendDateTime,
815840
"float": FLOAT,
816-
"double": FLOAT,
841+
"double": DOUBLE,
817842
"float64": FLOAT,
818843
"float32": FLOAT,
819844
"string": VARCHAR,
@@ -826,8 +851,13 @@ class DatabendInterval(INTERVAL):
826851
"binary": BINARY,
827852
"time": DatabendTime,
828853
"interval": DatabendInterval,
854+
"bitmap": DatabendBitmap,
855+
"geometry": DatabendGeometry,
856+
"geography": DatabendGeography
829857
}
830858

859+
860+
831861
# Column spec
832862
colspecs = {
833863
sqltypes.Interval: DatabendInterval,
@@ -1227,6 +1257,29 @@ def visit_TIME(self, type_, **kw):
12271257
def visit_INTERVAL(self, type, **kw):
    """Return the SQL type name emitted for an INTERVAL column."""
    # The parameter keeps its original name ``type`` so the signature is
    # unchanged for existing callers.
    return "INTERVAL"
12291259

1260+
def visit_DOUBLE(self, type_, **kw):
    """Return the SQL type name emitted for a DOUBLE column."""
    return "DOUBLE"
1262+
1263+
def visit_TINYINT(self, type_, **kw):
    """Return the SQL type name emitted for a TINYINT column."""
    return "TINYINT"
1265+
1266+
def visit_FLOAT(self, type_, **kw):
    """Return the SQL type name emitted for a FLOAT column."""
    return "FLOAT"
1268+
1269+
def visit_BITMAP(self, type_, **kw):
    """Return the SQL type name emitted for a BITMAP column."""
    return "BITMAP"
1271+
1272+
def visit_GEOMETRY(self, type_, **kw):
    """Return the SQL type name for a GEOMETRY column.

    When ``type_.srid`` is set (including ``0``) it is appended as
    ``GEOMETRY(SRID <srid>)``; otherwise plain ``GEOMETRY`` is returned.
    """
    srid = type_.srid
    if srid is None:
        return "GEOMETRY"
    return f"GEOMETRY(SRID {srid})"
1276+
1277+
def visit_GEOGRAPHY(self, type_, **kw):
    """Return the SQL type name for a GEOGRAPHY column.

    When ``type_.srid`` is set (including ``0``) it is appended as
    ``GEOGRAPHY(SRID <srid>)``; otherwise plain ``GEOGRAPHY`` is returned.
    """
    srid = type_.srid
    if srid is None:
        return "GEOGRAPHY"
    return f"GEOGRAPHY(SRID {srid})"
1281+
1282+
12301283

12311284
class DatabendDDLCompiler(compiler.DDLCompiler):
12321285
def visit_primary_key_constraint(self, constraint, **kw):

databend_sqlalchemy/dml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ class Compression(Enum):
251251
RAW_DEFLATE = "RAW_DEFLATE"
252252
XZ = "XZ"
253253
SNAPPY = "SNAPPY"
254+
ZIP = "ZIP"
254255

255256

256257
class CopyFormat(ClauseElement):

databend_sqlalchemy/types.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import datetime as dt
44
from typing import Optional, Type, Any
55

6+
from sqlalchemy import func
67
from sqlalchemy.engine.interfaces import Dialect
78
from sqlalchemy.sql import sqltypes
89
from sqlalchemy.sql import type_api
@@ -73,3 +74,80 @@ def process(value: dt.timedelta) -> str:
7374
return f"to_interval('{value.total_seconds()} seconds')"
7475

7576
return process
77+
78+
79+
class TINYINT(sqltypes.Integer):
    """Databend TINYINT column type, compiled through ``visit_TINYINT``."""

    __visit_name__ = "TINYINT"
    native = True
82+
83+
84+
class DOUBLE(sqltypes.Float):
    """Databend DOUBLE column type, compiled through ``visit_DOUBLE``."""

    __visit_name__ = "DOUBLE"
    native = True
87+
88+
89+
class FLOAT(sqltypes.Float):
    """Databend FLOAT column type, compiled through ``visit_FLOAT``."""

    __visit_name__ = "FLOAT"
    native = True
92+
93+
94+
# The “CamelCase” types are to the greatest degree possible database agnostic
95+
96+
# For these datatypes, specific SQLAlchemy dialects provide backend-specific “UPPERCASE” datatypes, for a SQL type that has no analogue on other backends
97+
98+
99+
class BITMAP(sqltypes.TypeEngine):
    """Databend BITMAP column type, exchanged with Python as a set of ints.

    Bound values are serialised to a comma-separated string and wrapped in
    ``to_bitmap(...)``; selected columns are wrapped in ``to_string(...)``
    and the resulting text is parsed back into a ``set`` of ``int``.
    """

    __visit_name__ = "BITMAP"
    render_bind_cast = True

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for call-site compatibility but are
        # not used.
        super().__init__()

    @staticmethod
    def _parse_bitmap(value):
        """Parse ``"1,2,3"`` into ``{1, 2, 3}``; ``None`` passes through.

        Empty items are skipped, so an empty string yields an empty set.
        """
        if value is None:
            return None
        return {int(item) for item in value.split(",") if item}

    def process_result_value(self, value, dialect):
        """Parse a raw result string into a set of ints.

        This is a ``TypeDecorator``-style hook and is not part of the
        ``TypeEngine`` result path (``result_processor`` is); it is kept so
        that any code calling it directly keeps working.
        """
        return self._parse_bitmap(value)

    def bind_expression(self, bindvalue):
        """Wrap the bound parameter in ``to_bitmap(...)``."""
        return func.to_bitmap(bindvalue, type_=self)

    def column_expression(self, col):
        """Wrap the selected column in ``to_string(...)``."""
        return func.to_string(col, type_=sqltypes.String)

    def bind_processor(self, dialect):
        """Return a callable that serialises a Python value for binding.

        Sets, frozensets, lists and tuples become a sorted comma-separated
        string; ``None`` stays ``None``; anything else is passed through
        ``str()``.
        """

        def process(value):
            if value is None:
                return None
            # Accept every common collection, not only ``set``: str() of a
            # list or frozenset would send its Python repr to to_bitmap().
            if isinstance(value, (set, frozenset, list, tuple)):
                return ",".join(str(item) for item in sorted(value))
            return str(value)

        return process

    def result_processor(self, dialect, coltype):
        """Return the callable that parses result strings into sets of ints."""
        return self._parse_bitmap
134+
135+
136+
class GEOMETRY(sqltypes.TypeEngine):
    """Databend GEOMETRY column type with an optional SRID.

    :param srid: spatial reference identifier, or ``None`` for no SRID.
    """

    __visit_name__ = "GEOMETRY"

    def __init__(self, srid=None):
        super().__init__()
        # Stored as given; ``None`` means the type carries no SRID.
        self.srid = srid
142+
143+
144+
145+
class GEOGRAPHY(sqltypes.TypeEngine):
    """Databend GEOGRAPHY column type with an optional SRID.

    :param srid: spatial reference identifier, or ``None`` for no SRID.
    """

    __visit_name__ = "GEOGRAPHY"
    native = True

    def __init__(self, srid=None):
        super().__init__()
        # Stored as given; ``None`` means the type carries no SRID.
        self.srid = srid
152+
153+

tests/conftest.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from sqlalchemy.dialects import registry
2-
from sqlalchemy import event, Engine, text
32
import pytest
43

54
registry.register("databend.databend", "databend_sqlalchemy.databend_dialect", "DatabendDialect")
@@ -9,9 +8,18 @@
98

109
from sqlalchemy.testing.plugin.pytestplugin import *
1110

11+
from packaging import version
12+
import sqlalchemy
13+
# The connect listener is only registered on SQLAlchemy 2.0 and later.
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
    from sqlalchemy import event, text
    from sqlalchemy import Engine

    @event.listens_for(Engine, "connect")
    def receive_engine_connect(conn, r):
        """Run the test suite's ``SET global`` statements on each new connection."""
        cur = conn.cursor()
        try:
            cur.execute('SET global format_null_as_str = 0')
            cur.execute('SET global enable_geo_create_table = 1')
        finally:
            # Close the cursor even if one of the SET statements fails.
            cur.close()
24+
1225

13-
@event.listens_for(Engine, "connect")
14-
def receive_engine_connect(conn, r):
15-
cur = conn.cursor()
16-
cur.execute('SET global format_null_as_str = 0')
17-
cur.close()

tests/test_copy_into.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
FileColumnClause,
2828
StageClause,
2929
)
30+
import sqlalchemy
31+
from packaging import version
3032

3133

3234
class CompileDatabendCopyIntoTableTest(fixtures.TestBase, AssertsCompiledSQL):
@@ -215,51 +217,52 @@ def define_tables(cls, metadata):
215217
Column("data", String(50)),
216218
)
217219

218-
def test_copy_into_stage_and_table(self, connection):
219-
# create stage
220-
connection.execute(text('CREATE OR REPLACE STAGE mystage'))
221-
# copy into stage from random table limiting 1000
222-
table = self.tables.random_data
223-
query = table.select().limit(1000)
220+
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
221+
def test_copy_into_stage_and_table(self, connection):
222+
# create stage
223+
connection.execute(text('CREATE OR REPLACE STAGE mystage'))
224+
# copy into stage from random table limiting 1000
225+
table = self.tables.random_data
226+
query = table.select().limit(1000)
224227

225-
copy_into = CopyIntoLocation(
226-
target=StageClause(
227-
name='mystage'
228-
),
229-
from_=query,
230-
file_format=ParquetFormat(),
231-
options=CopyIntoLocationOptions()
232-
)
233-
r = connection.execute(
234-
copy_into
235-
)
236-
eq_(r.rowcount, 1000)
237-
copy_into_results = r.context.copy_into_location_results()
238-
eq_(copy_into_results['rows_unloaded'], 1000)
239-
# eq_(copy_into_results['input_bytes'], 16250) # input bytes will differ, the table is random
240-
# eq_(copy_into_results['output_bytes'], 4701) # output bytes differs
228+
copy_into = CopyIntoLocation(
229+
target=StageClause(
230+
name='mystage'
231+
),
232+
from_=query,
233+
file_format=ParquetFormat(),
234+
options=CopyIntoLocationOptions()
235+
)
236+
r = connection.execute(
237+
copy_into
238+
)
239+
eq_(r.rowcount, 1000)
240+
copy_into_results = r.context.copy_into_location_results()
241+
eq_(copy_into_results['rows_unloaded'], 1000)
242+
# eq_(copy_into_results['input_bytes'], 16250) # input bytes will differ, the table is random
243+
# eq_(copy_into_results['output_bytes'], 4701) # output bytes differs
241244

242-
# now copy into table
245+
# now copy into table
243246

244-
copy_into_table = CopyIntoTable(
245-
target=self.tables.loaded,
246-
from_=StageClause(
247-
name='mystage'
248-
),
249-
file_format=ParquetFormat(),
250-
options=CopyIntoTableOptions()
251-
)
252-
r = connection.execute(
253-
copy_into_table
254-
)
255-
eq_(r.rowcount, 1000)
256-
copy_into_table_results = r.context.copy_into_table_results()
257-
assert len(copy_into_table_results) == 1
258-
result = copy_into_table_results[0]
259-
assert result['file'].endswith('.parquet')
260-
eq_(result['rows_loaded'], 1000)
261-
eq_(result['errors_seen'], 0)
262-
eq_(result['first_error'], None)
263-
eq_(result['first_error_line'], None)
247+
copy_into_table = CopyIntoTable(
248+
target=self.tables.loaded,
249+
from_=StageClause(
250+
name='mystage'
251+
),
252+
file_format=ParquetFormat(),
253+
options=CopyIntoTableOptions()
254+
)
255+
r = connection.execute(
256+
copy_into_table
257+
)
258+
eq_(r.rowcount, 1000)
259+
copy_into_table_results = r.context.copy_into_table_results()
260+
assert len(copy_into_table_results) == 1
261+
result = copy_into_table_results[0]
262+
assert result['file'].endswith('.parquet')
263+
eq_(result['rows_loaded'], 1000)
264+
eq_(result['errors_seen'], 0)
265+
eq_(result['first_error'], None)
266+
eq_(result['first_error_line'], None)
264267

265268

0 commit comments

Comments
 (0)