Skip to content

Commit a02bc6d

Browse files
committed
Support decimal, date, time, timestamp with time zone and timestamp
1 parent 65506e8 commit a02bc6d

File tree

3 files changed

+175
-7
lines changed

3 files changed

+175
-7
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 130 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212
import math
13-
from datetime import datetime
13+
from datetime import datetime, time
14+
from decimal import Decimal
1415

1516
import pytest
1617
import pytz
@@ -123,22 +124,114 @@ def test_string_query_param(trino_connection):
123124
assert rows[0][0] == "six'"
124125

125126

127+
def test_decimal_query_param(trino_connection):
128+
cur = trino_connection.cursor()
129+
130+
cur.execute("SELECT ?", params=(Decimal('0.142857'),))
131+
rows = cur.fetchall()
132+
133+
assert rows[0][0] == Decimal('0.142857')
134+
135+
126136
def test_datetime_query_param(trino_connection):
127137
cur = trino_connection.cursor()
128138

129-
cur.execute("SELECT ?", params=(datetime(2020, 1, 1, 0, 0, 0),))
139+
params = datetime(2020, 1, 1, 16, 43, 22, 320000)
140+
141+
cur.execute("SELECT ?", params=(params,))
130142
rows = cur.fetchall()
131143

132-
assert rows[0][0] == "2020-01-01 00:00:00.000"
144+
assert rows[0][0] == params
145+
assert cur.description[0][1] == "timestamp"
146+
147+
148+
def test_datetime_with_time_zone_query_param(trino_connection):
149+
cur = trino_connection.cursor()
150+
151+
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('CET'))
133152

134153
cur.execute("SELECT ?",
135-
params=(datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),))
154+
params=(params,))
136155
rows = cur.fetchall()
137156

138-
assert rows[0][0] == "2020-01-01 00:00:00.000 UTC"
157+
assert rows[0][0] == params
139158
assert cur.description[0][1] == "timestamp with time zone"
140159

141160

161+
def test_datetime_with_time_zone_numeric_offset(trino_connection):
162+
cur = trino_connection.cursor()
163+
164+
cur.execute("SELECT TIMESTAMP '2001-08-22 03:04:05.321 -08:00'")
165+
rows = cur.fetchall()
166+
167+
assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z")
168+
169+
170+
def test_special_datetimes_query_param(trino_connection):
171+
cur = trino_connection.cursor()
172+
173+
for special_date in (
174+
datetime.fromtimestamp(1603589478, pytz.timezone('Europe/Warsaw')),
175+
):
176+
params = special_date
177+
178+
cur.execute("SELECT ?", params=(params,))
179+
rows = cur.fetchall()
180+
181+
assert rows[0][0] == params
182+
183+
184+
def test_date_query_param(trino_connection):
185+
cur = trino_connection.cursor()
186+
187+
params = datetime(2020, 1, 1, 0, 0, 0).date()
188+
189+
cur.execute("SELECT ?", params=(params,))
190+
rows = cur.fetchall()
191+
192+
assert rows[0][0] == params
193+
194+
195+
def test_special_dates_query_param(trino_connection):
196+
cur = trino_connection.cursor()
197+
198+
for params in (
199+
# datetime(-1, 1, 1, 0, 0, 0).date(),
200+
# datetime(0, 1, 1, 0, 0, 0).date(),
201+
datetime(1752, 9, 4, 0, 0, 0).date(),
202+
datetime(1970, 1, 1, 0, 0, 0).date(),
203+
):
204+
cur.execute("SELECT ?", params=(params,))
205+
rows = cur.fetchall()
206+
207+
assert rows[0][0] == params
208+
209+
210+
def test_time_query_param(trino_connection):
211+
cur = trino_connection.cursor()
212+
213+
params = time(12, 3, 44, 333000)
214+
215+
cur.execute("SELECT ?", params=(params,))
216+
rows = cur.fetchall()
217+
218+
assert rows[0][0] == params
219+
220+
221+
@pytest.mark.skip(reason="time with time zone currently not supported")
222+
def test_time_with_time_zone_query_param(trino_connection):
223+
cur = trino_connection.cursor()
224+
225+
params = time(16, 43, 22, 320000, tzinfo=pytz.timezone('CET'))
226+
227+
cur.execute("SELECT ?",
228+
params=(params,))
229+
rows = cur.fetchall()
230+
231+
assert rows[0][0] == params
232+
assert cur.description[0][1] == "time with time zone"
233+
234+
142235
def test_array_query_param(trino_connection):
143236
cur = trino_connection.cursor()
144237

@@ -158,6 +251,38 @@ def test_array_query_param(trino_connection):
158251
assert rows[0][0] == "array(integer)"
159252

160253

254+
def test_array_timestamp_query_param(trino_connection):
255+
cur = trino_connection.cursor()
256+
257+
params = [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)]
258+
259+
cur.execute("SELECT ?", params=(params,))
260+
rows = cur.fetchall()
261+
262+
assert rows[0][0] == params
263+
264+
cur.execute("SELECT TYPEOF(?)", params=(params,))
265+
rows = cur.fetchall()
266+
267+
assert rows[0][0] == "array(timestamp(6))"
268+
269+
270+
def test_array_timestamp_with_timezone_query_param(trino_connection):
271+
cur = trino_connection.cursor()
272+
273+
params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
274+
275+
cur.execute("SELECT ?", params=(params,))
276+
rows = cur.fetchall()
277+
278+
assert rows[0][0] == params
279+
280+
cur.execute("SELECT TYPEOF(?)", params=(params,))
281+
rows = cur.fetchall()
282+
283+
assert rows[0][0] == "array(timestamp(6) with time zone)"
284+
285+
161286
def test_dict_query_param(trino_connection):
162287
cur = trino_connection.cursor()
163288

trino/client.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import copy
3737
import os
3838
import re
39+
from decimal import Decimal
40+
from datetime import datetime
41+
import pytz
3942
from typing import Any, Dict, List, Optional, Tuple, Union
4043
import urllib.parse
4144

@@ -494,12 +497,41 @@ def __iter__(self):
494497
for row in rows:
495498
self._rownumber += 1
496499
logger.debug("row %s", row)
497-
yield row
500+
yield self._map_to_python_types(row, self._query.columns)
498501

499502
@property
500503
def response_headers(self):
501504
return self._query.response_headers
502505

506+
@classmethod
507+
def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
508+
(value, data_type) = item
509+
510+
raw_type = data_type["typeSignature"]["rawType"]
511+
if isinstance(value, list):
512+
raw_type = {
513+
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
514+
}
515+
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
516+
elif "decimal" in raw_type:
517+
return Decimal(value)
518+
elif raw_type == "date":
519+
return datetime.strptime(value, "%Y-%m-%d").date()
520+
elif raw_type == "timestamp with time zone":
521+
dt, tz = value.rsplit(' ', 1)
522+
if tz.startswith('+') or tz.startswith('-'):
523+
return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z")
524+
return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz))
525+
elif "timestamp" in raw_type:
526+
return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
527+
elif "time" in raw_type:
528+
return datetime.strptime(value, "%H:%M:%S.%f").time()
529+
else:
530+
return value
531+
532+
def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
533+
return list(map(self._map_to_python_type, zip(row, columns)))
534+
503535

504536
class TrinoQuery(object):
505537
"""Represent the execution of a SQL statement by Trino."""

trino/dbapi.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Fetch methods returns rows as a list of lists on purpose to let the caller
1818
decide to convert then to a list of tuples.
1919
"""
20-
20+
from decimal import Decimal
2121
from typing import Any, List, Optional # NOQA for mypy types
2222

2323
import copy
@@ -365,6 +365,14 @@ def _format_prepared_param(self, param):
365365
datetime_str = datetime_str.rstrip(" ")
366366
return "TIMESTAMP '%s'" % datetime_str
367367

368+
if isinstance(param, datetime.time):
369+
time_str = param.strftime("%H:%M:%S.%f")
370+
return "TIME '%s'" % time_str
371+
372+
if isinstance(param, datetime.date):
373+
date_str = param.strftime("%Y-%m-%d")
374+
return "DATE '%s'" % date_str
375+
368376
if isinstance(param, list):
369377
return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
370378

@@ -379,6 +387,9 @@ def _format_prepared_param(self, param):
379387
if isinstance(param, uuid.UUID):
380388
return "UUID '%s'" % param
381389

390+
if isinstance(param, Decimal):
391+
return "DECIMAL '%s'" % param
392+
382393
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
383394

384395
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):

0 commit comments

Comments
 (0)