Commit da1441f

Add SQLAlchemy dialect for Trino

dungdm93 authored and hashhar committed
1 parent: f9e68da

12 files changed: +1109 −2 lines
setup.py (8 additions, 2 deletions)

@@ -26,10 +26,10 @@
 assert trino_version is not None
 version = str(ast.literal_eval(trino_version.group(1)))
 
-
 kerberos_require = ["requests_kerberos"]
+sqlalchemy_require = ["sqlalchemy~=1.3"]
 
-all_require = kerberos_require + []
+all_require = kerberos_require + sqlalchemy_require
 
 tests_require = all_require + [
     # httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`

@@ -80,6 +80,12 @@
     extras_require={
         "all": all_require,
         "kerberos": kerberos_require,
+        "sqlalchemy": sqlalchemy_require,
         "tests": tests_require,
     },
+    entry_points={
+        "sqlalchemy.dialects": [
+            "trino = trino.sqlalchemy.dialect:TrinoDialect",
+        ]
+    },
 )
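
With the "sqlalchemy.dialects" entry point registered above, SQLAlchemy can resolve trino:// URLs to TrinoDialect once the package is installed with the new sqlalchemy extra (pip install trino[sqlalchemy]). A minimal usage sketch follows; the coordinator host, port, and catalog are placeholder values, not part of this commit:

from sqlalchemy import create_engine
from sqlalchemy.sql import text

# The "trino" scheme is resolved through the entry point added in setup.py.
# Placeholder connection details; point this at any running Trino coordinator.
engine = create_engine("trino://user@localhost:8080/tpch")

with engine.connect() as connection:
    # A simple round trip confirms the dialect loaded and a query executes.
    rows = connection.execute(text("SELECT 1")).fetchall()
    print(rows)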

tests/__init__.py (whitespace-only changes)

tests/unit/sqlalchemy/__init__.py (whitespace-only changes)

tests/unit/sqlalchemy/conftest.py (46 additions, 0 deletions)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType


@pytest.fixture(scope="session")
def assert_sqltype():
    def _assert_sqltype(this: SQLType, that: SQLType):
        if isinstance(this, type):
            this = this()
        if isinstance(that, type):
            that = that()

        assert type(this) == type(that)

        if isinstance(this, ARRAY):
            _assert_sqltype(this.item_type, that.item_type)
            if this.dimensions is None or this.dimensions == 1:
                # ARRAY(dimensions=None) == ARRAY(dimensions=1)
                assert that.dimensions is None or that.dimensions == 1
            else:
                assert that.dimensions == this.dimensions
        elif isinstance(this, MAP):
            _assert_sqltype(this.key_type, that.key_type)
            _assert_sqltype(this.value_type, that.value_type)
        elif isinstance(this, ROW):
            assert len(this.attr_types) == len(that.attr_types)
            for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
                assert this_attr[0] == that_attr[0]
                _assert_sqltype(this_attr[1], that_attr[1])
        else:
            assert str(this) == str(that)

    return _assert_sqltype
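
For illustration only (not part of this commit), a hypothetical test using the fixture would lean on the structural rules above, such as ARRAY(dimensions=None) comparing equal to ARRAY(dimensions=1) and containers being compared recursively:

from sqlalchemy.sql.sqltypes import ARRAY, VARCHAR

from trino.sqlalchemy.datatype import MAP


def test_assert_sqltype_example(assert_sqltype):
    # dimensions=None and dimensions=1 are treated as equivalent by the fixture
    assert_sqltype(ARRAY(VARCHAR(10)), ARRAY(VARCHAR(10), dimensions=1))

    # nested container types are compared recursively (key type, then value type)
    assert_sqltype(MAP(VARCHAR(), ARRAY(VARCHAR(10))), MAP(VARCHAR(), ARRAY(VARCHAR(10))))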
Lines changed: 192 additions & 0 deletions

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from sqlalchemy.sql.sqltypes import (
    CHAR,
    VARCHAR,
    ARRAY,
    INTEGER,
    DECIMAL,
    DATE,
    TIME,
    TIMESTAMP,
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW


@pytest.mark.parametrize(
    "type_str, sql_type",
    datatype._type_map.items(),
    ids=datatype._type_map.keys(),
)
def test_parse_simple_type(type_str: str, sql_type: TypeEngine, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    if not isinstance(actual_type, type):
        actual_type = type(actual_type)
    assert_sqltype(actual_type, sql_type)


parse_cases_testcases = {
    "char(10)": CHAR(10),
    "Char(10)": CHAR(10),
    "char": CHAR(),
    "cHaR": CHAR(),
    "VARCHAR(10)": VARCHAR(10),
    "varCHAR(10)": VARCHAR(10),
    "VARchar(10)": VARCHAR(10),
    "VARCHAR": VARCHAR(),
    "VaRchAr": VARCHAR(),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_cases_testcases.items(),
    ids=parse_cases_testcases.keys(),
)
def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


parse_type_options_testcases = {
    "CHAR(10)": CHAR(10),
    "VARCHAR(10)": VARCHAR(10),
    "DECIMAL(20)": DECIMAL(20),
    "DECIMAL(20, 3)": DECIMAL(20, 3),
    # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_type_options_testcases.items(),
    ids=parse_type_options_testcases.keys(),
)
def test_parse_type_options(type_str: str, sql_type: TypeEngine, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


parse_array_testcases = {
    "array(integer)": ARRAY(INTEGER()),
    "array(varchar(10))": ARRAY(VARCHAR(10)),
    "array(decimal(20,3))": ARRAY(DECIMAL(20, 3)),
    "array(array(varchar(10)))": ARRAY(VARCHAR(10), dimensions=2),
    "array(map(char, integer))": ARRAY(MAP(CHAR(), INTEGER())),
    "array(row(a integer, b varchar))": ARRAY(ROW([("a", INTEGER()), ("b", VARCHAR())])),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_array_testcases.items(),
    ids=parse_array_testcases.keys(),
)
def test_parse_array(type_str: str, sql_type: ARRAY, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


parse_map_testcases = {
    "map(char, integer)": MAP(CHAR(), INTEGER()),
    "map(varchar(10), varchar(10))": MAP(VARCHAR(10), VARCHAR(10)),
    "map(varchar(10), decimal(20,3))": MAP(VARCHAR(10), DECIMAL(20, 3)),
    "map(char, array(varchar(10)))": MAP(CHAR(), ARRAY(VARCHAR(10))),
    "map(varchar(10), array(varchar(10)))": MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
    "map(varchar(10), array(array(varchar(10))))": MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_map_testcases.items(),
    ids=parse_map_testcases.keys(),
)
def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


parse_row_testcases = {
    "row(a integer, b varchar)": ROW(
        attr_types=[
            ("a", INTEGER()),
            ("b", VARCHAR()),
        ]
    ),
    "row(a varchar(20), b decimal(20,3))": ROW(
        attr_types=[
            ("a", VARCHAR(20)),
            ("b", DECIMAL(20, 3)),
        ]
    ),
    "row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))": ROW(
        attr_types=[
            ("x", ARRAY(VARCHAR(10))),
            ("y", ARRAY(VARCHAR(10), dimensions=2)),
            ("z", DECIMAL(20, 3)),
        ]
    ),
    "row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
        attr_types=[
            ("min", TIMESTAMP(timezone=True)),
            ("max", TIMESTAMP(timezone=True)),
        ]
    ),
    'row("first name" varchar, "last name" varchar)': ROW(
        attr_types=[
            ("first name", VARCHAR()),
            ("last name", VARCHAR()),
        ]
    ),
    'row("foo,bar" varchar, "foo(bar)" varchar, "foo\\"bar" varchar)': ROW(
        attr_types=[
            (r"foo,bar", VARCHAR()),
            (r"foo(bar)", VARCHAR()),
            (r'foo"bar', VARCHAR()),
        ]
    ),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_row_testcases.items(),
    ids=parse_row_testcases.keys(),
)
def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)


parse_datetime_testcases = {
    # TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
    "date": DATE(),
    "time": TIME(),
    "time with time zone": TIME(timezone=True),
    "timestamp": TIMESTAMP(),
    "timestamp with time zone": TIMESTAMP(timezone=True),
}


@pytest.mark.parametrize(
    "type_str, sql_type",
    parse_datetime_testcases.items(),
    ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
    actual_type = datatype.parse_sqltype(type_str)
    assert_sqltype(actual_type, sql_type)
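
Taken together, these cases show datatype.parse_sqltype mapping a Trino type string to a nested SQLAlchemy type object. A small illustrative sketch with inputs analogous to the cases above; the expected outputs are inferred from these tests, not additional fixtures from the commit:

from trino.sqlalchemy import datatype

# Per the map/array cases above, this should come back as
# MAP(VARCHAR(10), ARRAY(INTEGER())).
parsed = datatype.parse_sqltype("map(varchar(10), array(integer))")
print(parsed)

# Quoted attribute names keep embedded spaces, as in the
# 'row("first name" varchar, ...)' case above.
row_type = datatype.parse_sqltype('row("first name" varchar, "last name" varchar)')
print(row_type.attr_types)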
Lines changed: 102 additions & 0 deletions

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import pytest

from trino.sqlalchemy import datatype

split_string_testcases = {
    "10": ["10"],
    "10,3": ["10", "3"],
    '"a,b",c': ['"a,b"', "c"],
    '"a,b","c,d"': ['"a,b"', '"c,d"'],
    r'"a,\"b\",c",d': [r'"a,\"b\",c"', "d"],
    r'"foo(bar,\"baz\")",quiz': [r'"foo(bar,\"baz\")"', "quiz"],
    "varchar": ["varchar"],
    "varchar,int": ["varchar", "int"],
    "varchar,int,float": ["varchar", "int", "float"],
    "array(varchar)": ["array(varchar)"],
    "array(varchar),int": ["array(varchar)", "int"],
    "array(varchar(20))": ["array(varchar(20))"],
    "array(varchar(20)),int": ["array(varchar(20))", "int"],
    "array(varchar(20)),array(varchar(20))": [
        "array(varchar(20))",
        "array(varchar(20))",
    ],
    "map(varchar, integer),int": ["map(varchar, integer)", "int"],
    "map(varchar(20), integer),int": ["map(varchar(20), integer)", "int"],
    "map(varchar(20), varchar(20)),int": ["map(varchar(20), varchar(20))", "int"],
    "map(varchar(20), varchar(20)),array(varchar)": [
        "map(varchar(20), varchar(20))",
        "array(varchar)",
    ],
    "row(first_name varchar(20), last_name varchar(20)),int": [
        "row(first_name varchar(20), last_name varchar(20))",
        "int",
    ],
    'row("first name" varchar(20), "last name" varchar(20)),int': [
        'row("first name" varchar(20), "last name" varchar(20))',
        "int",
    ],
}


@pytest.mark.parametrize(
    "input_string, output_strings",
    split_string_testcases.items(),
    ids=split_string_testcases.keys(),
)
def test_split_string(input_string: str, output_strings: List[str]):
    actual = list(datatype.aware_split(input_string))
    assert actual == output_strings


split_delimiter_testcases = [
    ("first,second", ",", ["first", "second"]),
    ("first second", " ", ["first", "second"]),
    ("first|second", "|", ["first", "second"]),
    ("first,second third", ",", ["first", "second third"]),
    ("first,second third", " ", ["first,second", "third"]),
]


@pytest.mark.parametrize(
    "input_string, delimiter, output_strings",
    split_delimiter_testcases,
)
def test_split_delimiter(input_string: str, delimiter: str, output_strings: List[str]):
    actual = list(datatype.aware_split(input_string, delimiter=delimiter))
    assert actual == output_strings


split_maxsplit_testcases = [
    ("one,two,three", -1, ["one", "two", "three"]),
    ("one,two,three", 0, ["one,two,three"]),
    ("one,two,three", 1, ["one", "two,three"]),
    ("one,two,three", 2, ["one", "two", "three"]),
    ("one,two,three", 3, ["one", "two", "three"]),
    ("one,two,three", 10, ["one", "two", "three"]),
    (",one,two,three", 0, [",one,two,three"]),
    (",one,two,three", 1, ["", "one,two,three"]),
    ("one,two,three,", 2, ["one", "two", "three,"]),
    ("one,two,three,", 3, ["one", "two", "three", ""]),
]


@pytest.mark.parametrize(
    "input_string, maxsplit, output_strings",
    split_maxsplit_testcases,
)
def test_split_maxsplit(input_string: str, maxsplit: int, output_strings: List[str]):
    actual = list(datatype.aware_split(input_string, maxsplit=maxsplit))
    assert actual == output_strings
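
These expectations describe a splitter that ignores delimiters nested inside parentheses and double-quoted identifiers (with backslash-escaped quotes) and treats maxsplit=0 as "no splits". The standalone sketch below is consistent with the cases above; it is an illustration of that behaviour under those assumptions, not the aware_split implementation shipped in trino.sqlalchemy.datatype:

from typing import Iterator


def aware_split_sketch(string: str, delimiter: str = ",", maxsplit: int = -1) -> Iterator[str]:
    """Split on `delimiter`, skipping delimiters nested in parens or double quotes."""
    parens = 0        # current parenthesis nesting depth
    quoted = False    # inside a double-quoted identifier
    escaped = False   # previous character was a backslash inside quotes
    splits = 0
    start = 0
    for i, ch in enumerate(string):
        if escaped:
            escaped = False
        elif quoted:
            if ch == "\\":
                escaped = True
            elif ch == '"':
                quoted = False
        elif ch == '"':
            quoted = True
        elif ch == "(":
            parens += 1
        elif ch == ")":
            parens -= 1
        elif ch == delimiter and parens == 0 and (maxsplit < 0 or splits < maxsplit):
            yield string[start:i]
            start = i + 1
            splits += 1
    yield string[start:]


# For example: list(aware_split_sketch('"a,b",c')) == ['"a,b"', "c"]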
