Skip to content

Commit cffd2b2

Browse files
lpoulainhashhar
authored and committed
Optimize experimental_python_types and add type-mapping tests
Instead of checking the type for each row, check the type once for each fetch() call and compute a list of lambdas which are to be applied to the values from each row. A new RowMapperFactory class is created to wrap this behavior. The experimental_python_types flag is now processed in the TrinoQuery class instead of the TrinoResult class. Type mapping tests for each lambda which maps rows to Python types are added.
1 parent f4487f5 commit cffd2b2

File tree

2 files changed

+420
-88
lines changed

2 files changed

+420
-88
lines changed
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
import math
2+
import pytest
3+
from decimal import Decimal
4+
import trino
5+
6+
7+
@pytest.fixture
def trino_connection(run_trino):
    """Yield a DB-API connection built from the ``run_trino`` fixture value.

    The second and third items of the ``run_trino`` value are used as the
    host and port; the first item is ignored.
    """
    _, host, port = run_trino

    connection = trino.dbapi.Connection(
        host=host,
        port=port,
        user="test",
        source="test",
        max_attempts=1,
    )
    yield connection
14+
15+
16+
def test_boolean(trino_connection):
    """BOOLEAN expressions are expected as ``bool``; SQL NULL as ``None``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS BOOLEAN)", python=None)
        .add_field(sql="false", python=False)
        .add_field(sql="true", python=True)
        .execute()
    )
22+
23+
24+
def test_tinyint(trino_connection):
    """TINYINT values, including both range limits, are expected as ``int``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS TINYINT)", python=None)
        .add_field(sql="CAST(-128 AS TINYINT)", python=-128)
        .add_field(sql="CAST(42 AS TINYINT)", python=42)
        .add_field(sql="CAST(127 AS TINYINT)", python=127)
        .execute()
    )
31+
32+
33+
def test_smallint(trino_connection):
    """SMALLINT values, including both range limits, are expected as ``int``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS SMALLINT)", python=None)
        .add_field(sql="CAST(-32768 AS SMALLINT)", python=-32768)
        .add_field(sql="CAST(42 AS SMALLINT)", python=42)
        .add_field(sql="CAST(32767 AS SMALLINT)", python=32767)
        .execute()
    )
40+
41+
42+
def test_int(trino_connection):
    """INTEGER values, including both range limits, are expected as ``int``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS INTEGER)", python=None)
        .add_field(sql="CAST(-2147483648 AS INTEGER)", python=-2147483648)
        .add_field(sql="CAST(83648 AS INTEGER)", python=83648)
        .add_field(sql="CAST(2147483647 AS INTEGER)", python=2147483647)
        .execute()
    )
49+
50+
51+
def test_bigint(trino_connection):
    """BIGINT values, including both range limits, are expected as ``int``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS BIGINT)", python=None)
        .add_field(sql="CAST(-9223372036854775808 AS BIGINT)", python=-9223372036854775808)
        .add_field(sql="CAST(9223 AS BIGINT)", python=9223)
        .add_field(sql="CAST(9223372036854775807 AS BIGINT)", python=9223372036854775807)
        .execute()
    )
58+
59+
60+
def test_real(trino_connection):
    """REAL values, including NaN and both infinities, are expected as ``float``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS REAL)", python=None)
        .add_field(sql="CAST('NaN' AS REAL)", python=math.nan)
        .add_field(sql="CAST('-Infinity' AS REAL)", python=-math.inf)
        .add_field(sql="CAST(3.4028235E38 AS REAL)", python=3.4028235e+38)
        .add_field(sql="CAST(1.4E-45 AS REAL)", python=1.4e-45)
        .add_field(sql="CAST('Infinity' AS REAL)", python=math.inf)
        .execute()
    )
69+
70+
71+
def test_double(trino_connection):
    """DOUBLE values, including NaN and both infinities, are expected as ``float``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS DOUBLE)", python=None)
        .add_field(sql="CAST('NaN' AS DOUBLE)", python=math.nan)
        .add_field(sql="CAST('-Infinity' AS DOUBLE)", python=-math.inf)
        .add_field(sql="CAST(1.7976931348623157E308 AS DOUBLE)", python=1.7976931348623157e+308)
        .add_field(sql="CAST(4.9E-324 AS DOUBLE)", python=5e-324)
        .add_field(sql="CAST('Infinity' AS DOUBLE)", python=math.inf)
        .execute()
    )
80+
81+
82+
def test_decimal(trino_connection):
    """DECIMAL values of several precisions/scales are expected as ``Decimal``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS DECIMAL)", python=None)
        .add_field(sql="CAST(null AS DECIMAL(38,0))", python=None)
        .add_field(sql="DECIMAL '10.3'", python=Decimal('10.3'))
        .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,18))", python=Decimal('0.123456789123456789'))
        .add_field(sql="CAST(null AS DECIMAL(18,18))", python=None)
        .add_field(sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal('234.1235'))
        .add_field(sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal('10.3'))
        .add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,2))", python=Decimal('0.12'))
        .add_field(sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal('0.3123'))
        .execute()
    )
94+
95+
96+
def test_varchar(trino_connection):
    """VARCHAR values are expected as ``str``; SQL NULL as ``None``.

    Covers a plain literal, a Unicode-escaped ``U&'...'`` literal, and a
    cast to a bounded VARCHAR(1).
    """
    (
        SqlTest(trino_connection)
        .add_field(sql="'aaa'", python='aaa')
        # Raw string on purpose: in a normal literal Python reads "\260" as an
        # octal escape ('°') followed by "3", so the backslash escape would
        # never reach the server and the Unicode literal would go untested.
        .add_field(sql=r"U&'Hello winter \2603 !'", python='Hello winter \u2603 !')
        .add_field(sql="CAST(null AS VARCHAR)", python=None)
        .add_field(sql="CAST('bbb' AS VARCHAR(1))", python='b')
        .add_field(sql="CAST(null AS VARCHAR(1))", python=None)
        .execute()
    )
104+
105+
106+
def test_char(trino_connection):
    """CHAR values are expected as ``str``; SQL NULL as ``None``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST('ccc' AS CHAR)", python='c')
        .add_field(sql="CAST(null AS CHAR)", python=None)
        .add_field(sql="CAST('ddd' AS CHAR(1))", python='d')
        .add_field(sql="CAST('😂' AS CHAR(1))", python='😂')
        .add_field(sql="CAST(null AS CHAR(1))", python=None)
        .execute()
    )
114+
115+
116+
def test_varbinary(trino_connection):
    """VARBINARY expressions are compared against string expectations.

    For example ``X'65683F'`` is expected back as the text ``'ZWg/'``.
    """
    (
        SqlTest(trino_connection)
        .add_field(sql="X'65683F'", python='ZWg/')
        .add_field(sql="X''", python='')
        .add_field(sql="CAST('' AS VARBINARY)", python='')
        .add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂')
        .add_field(sql="CAST(null AS VARBINARY)", python=None)
        .execute()
    )
124+
125+
126+
def test_varbinary_failure(trino_connection):
    """Selecting ``CAST(42 AS VARBINARY)`` is expected to raise an error."""
    failing_query = SqlExpectFailureTest(trino_connection)
    failing_query.execute("CAST(42 AS VARBINARY)")
129+
130+
131+
def test_json(trino_connection):
    """JSON values are expected back as quoted JSON text; SQL NULL as ``None``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST('{}' AS JSON)", python='"{}"')
        .add_field(sql="CAST('null' AS JSON)", python='"null"')
        .add_field(sql="CAST(null AS JSON)", python=None)
        .add_field(sql="CAST('3.14' AS JSON)", python='"3.14"')
        .add_field(sql="CAST('a string' AS JSON)", python='"a string"')
        .add_field(sql="CAST('a \" complex '' string :' AS JSON)", python='"a \\" complex \' string :"')
        .add_field(sql="CAST('[]' AS JSON)", python='"[]"')
        .execute()
    )
141+
142+
143+
def test_interval(trino_connection):
    """INTERVAL values are compared against their string representations."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None)
        .add_field(sql="CAST(null AS INTERVAL DAY TO SECOND)", python=None)
        .add_field(sql="INTERVAL '3' MONTH", python='0-3')
        .add_field(sql="INTERVAL '2' DAY", python='2 00:00:00.000')
        .add_field(sql="INTERVAL '-2' DAY", python='-2 00:00:00.000')
        .execute()
    )
151+
152+
153+
def test_array(trino_connection):
    """ARRAY values are expected as ``list``, with NULL elements as ``None``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None)
        .add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None])
        .execute()
    )
158+
159+
160+
def test_map(trino_connection):
    """MAP values are expected as ``dict``, with NULL values as ``None``."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None)
        .add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None})
        .execute()
    )
165+
166+
167+
def test_row(trino_connection):
    """ROW values are compared against a ``tuple`` of their field values."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None)
        .add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0))
        .execute()
    )
172+
173+
174+
def test_ipaddress(trino_connection):
    """IPADDRESS values are compared against their textual form."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS IPADDRESS)", python=None)
        .add_field(sql="IPADDRESS '2001:db8::1'", python='2001:db8::1')
        .execute()
    )
179+
180+
181+
def test_uuid(trino_connection):
    """UUID values are compared against their textual form."""
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS UUID)", python=None)
        .add_field(sql="UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59'", python='12151fd2-7586-11e9-8f9e-2a86e4085a59')
        .execute()
    )
186+
187+
188+
def test_digest(trino_connection):
    """Digest/sketch types are compared against fixed string expectations.

    Covers HyperLogLog, P4HyperLogLog, SetDigest, QDigest and TDigest, both
    as SQL NULL and (for most of them) as a value built from the input ``1``.
    """
    (
        SqlTest(trino_connection)
        .add_field(sql="CAST(null AS HyperLogLog)", python=None)
        .add_field(sql="CAST(null AS P4HyperLogLog)", python=None)
        .add_field(sql="CAST(null AS SetDigest)", python=None)
        .add_field(sql="CAST(null AS QDigest(BIGINT))", python=None)
        .add_field(sql="CAST(null AS TDigest)", python=None)
        .add_field(sql="approx_set(1)", python='AgwBAIADRAA=')
        .add_field(sql="CAST(approx_set(1) AS P4HyperLogLog)", python='AwwAAAAg' + 'A' * 2730 + '==')
        .add_field(sql="make_set_digest(1)", python='AQgAAAACCwEAgANEAAAgAAABAAAASsQF+7cDRAABAA==')
        .add_field(
            sql="tdigest_agg(1)",
            python='AAAAAAAAAPA/AAAAAAAA8D8AAAAAAABZQAAAAAAAAPA/AQAAAAAAAAAAAPA/AAAAAAAA8D8=',
        )
        .execute()
    )
201+
202+
203+
class SqlTest:
    """Fluent helper that checks how SQL expressions map to Python values.

    Each ``add_field`` call registers one SQL expression together with the
    Python value it is expected to produce. ``execute`` selects every
    registered expression in a single ``SELECT`` statement and compares the
    first returned row with the expectations, column by column.
    """

    def __init__(self, trino_connection):
        """Open a cursor with ``experimental_python_types=True`` on the connection."""
        self.cur = trino_connection.cursor(experimental_python_types=True)
        self.sql_args = []
        self.expected_result = []

    def add_field(self, sql, python):
        """Register a SQL expression and its expected Python value.

        Returns ``self`` so that calls can be chained.
        """
        self.sql_args.append(sql)
        self.expected_result.append(python)
        return self

    def execute(self):
        """Run the combined ``SELECT`` and compare its first row to the expectations.

        Raises ``AssertionError`` when no row comes back or when any column
        differs from its expected value.
        """
        sql = 'SELECT ' + ',\n'.join(self.sql_args)

        self.cur.execute(sql)
        actual_result = self.cur.fetchall()
        # Fail with a readable message rather than an IndexError on [0].
        assert actual_result, "Query returned no rows:\n%s" % sql
        self._compare_results(actual_result[0], self.expected_result)

    def _compare_results(self, actual, expected):
        """Assert that ``actual`` and ``expected`` match element by element.

        NaN never compares equal to itself, so a pair of float NaNs is
        treated as a match explicitly.
        """
        assert len(actual) == len(expected), \
            "Expected %d columns, got %d" % (len(expected), len(actual))

        for idx, (actual_val, expected_val) in enumerate(zip(actual, expected)):
            both_nan = (
                isinstance(actual_val, float) and math.isnan(actual_val)
                and isinstance(expected_val, float) and math.isnan(expected_val)
            )
            if both_nan:
                continue

            assert actual_val == expected_val, \
                "Column %d: expected %r, got %r" % (idx, expected_val, actual_val)
231+
232+
233+
class SqlExpectFailureTest:
    """Helper asserting that selecting a given SQL expression raises an error."""

    def __init__(self, trino_connection):
        """Open a cursor with ``experimental_python_types=True`` on the connection."""
        self.cur = trino_connection.cursor(experimental_python_types=True)

    def execute(self, field):
        """Select ``field`` and assert that executing or fetching raises.

        Any ``Exception`` raised by ``execute`` or ``fetchall`` counts as the
        expected failure; a query that completes trips the assertion.
        """
        statement = 'SELECT ' + field

        try:
            self.cur.execute(statement)
            self.cur.fetchall()
        except Exception:
            query_failed = True
        else:
            query_failed = False

        assert query_failed, "Test not expected to succeed"

0 commit comments

Comments
 (0)