Skip to content

Commit d2ef96c

Browse files
committed
Support decimal, date, time, timestamp with time zone and timestamp
1 parent 65506e8 commit d2ef96c

File tree

3 files changed

+80
-4
lines changed

3 files changed

+80
-4
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# limitations under the License.
1212
import math
1313
from datetime import datetime
14+
from decimal import Decimal
1415

1516
import pytest
1617
import pytz
@@ -123,19 +124,28 @@ def test_string_query_param(trino_connection):
123124
assert rows[0][0] == "six'"
124125

125126

127+
def test_decimal_query_param(trino_connection):
128+
cur = trino_connection.cursor()
129+
130+
cur.execute("SELECT ?", params=(Decimal('0.142857'),))
131+
rows = cur.fetchall()
132+
133+
assert rows[0][0] == Decimal('0.142857')
134+
135+
126136
def test_datetime_query_param(trino_connection):
127137
cur = trino_connection.cursor()
128138

129139
cur.execute("SELECT ?", params=(datetime(2020, 1, 1, 0, 0, 0),))
130140
rows = cur.fetchall()
131141

132-
assert rows[0][0] == "2020-01-01 00:00:00.000"
142+
assert rows[0][0] == datetime(2020, 1, 1, 0, 0, 0)
133143

134144
cur.execute("SELECT ?",
135145
params=(datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),))
136146
rows = cur.fetchall()
137147

138-
assert rows[0][0] == "2020-01-01 00:00:00.000 UTC"
148+
assert rows[0][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
139149
assert cur.description[0][1] == "timestamp with time zone"
140150

141151

@@ -158,6 +168,38 @@ def test_array_query_param(trino_connection):
158168
assert rows[0][0] == "array(integer)"
159169

160170

171+
def test_array_timestamp_query_param(trino_connection):
172+
cur = trino_connection.cursor()
173+
174+
params = [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)]
175+
176+
cur.execute("SELECT ?", params=(params,))
177+
rows = cur.fetchall()
178+
179+
assert rows[0][0] == params
180+
181+
cur.execute("SELECT TYPEOF(?)", params=(params,))
182+
rows = cur.fetchall()
183+
184+
assert rows[0][0] == "array(timestamp(6))"
185+
186+
187+
def test_array_timestamp_with_timezone_query_param(trino_connection):
188+
cur = trino_connection.cursor()
189+
190+
params = [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
191+
192+
cur.execute("SELECT ?", params=(params,))
193+
rows = cur.fetchall()
194+
195+
assert rows[0][0] == params
196+
197+
cur.execute("SELECT TYPEOF(?)", params=(params,))
198+
rows = cur.fetchall()
199+
200+
assert rows[0][0] == "array(timestamp(6) with time zone)"
201+
202+
161203
def test_dict_query_param(trino_connection):
162204
cur = trino_connection.cursor()
163205

trino/client.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
import copy
3737
import os
3838
import re
39+
from decimal import Decimal
40+
from datetime import datetime
41+
import pytz
42+
from time import strptime
3943
from typing import Any, Dict, List, Optional, Tuple, Union
4044
import urllib.parse
4145

@@ -494,12 +498,39 @@ def __iter__(self):
494498
for row in rows:
495499
self._rownumber += 1
496500
logger.debug("row %s", row)
497-
yield row
501+
yield self._map_to_python_types(row, self._query.columns)
498502

499503
@property
500504
def response_headers(self):
501505
return self._query.response_headers
502506

507+
@classmethod
508+
def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
509+
(value, data_type) = item
510+
511+
raw_type = data_type["typeSignature"]["rawType"]
512+
if isinstance(value, list):
513+
raw_type = {
514+
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
515+
}
516+
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
517+
elif "decimal" in raw_type:
518+
return Decimal(value)
519+
elif raw_type == "date":
520+
return strptime(value, "%Y-%m-%d")
521+
elif raw_type == "time":
522+
return strptime(value, "%H:%M:%S")
523+
elif "timestamp with time zone" in raw_type:
524+
time_str = strptime(value, "%Y-%m-%d %H:%M:%S.%f %Z")
525+
return datetime(*time_str[:3], tzinfo=pytz.timezone(time_str.tm_zone))
526+
elif "timestamp" in raw_type:
527+
return datetime(*strptime(value, "%Y-%m-%d %H:%M:%S.%f")[:3])
528+
else:
529+
return value
530+
531+
def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
532+
return list(map(self._map_to_python_type, zip(row, columns)))
533+
503534

504535
class TrinoQuery(object):
505536
"""Represent the execution of a SQL statement by Trino."""

trino/dbapi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Fetch methods returns rows as a list of lists on purpose to let the caller
1818
decide to convert then to a list of tuples.
1919
"""
20-
20+
from decimal import Decimal
2121
from typing import Any, List, Optional # NOQA for mypy types
2222

2323
import copy
@@ -379,6 +379,9 @@ def _format_prepared_param(self, param):
379379
if isinstance(param, uuid.UUID):
380380
return "UUID '%s'" % param
381381

382+
if isinstance(param, Decimal):
383+
return "DECIMAL '%s'" % param
384+
382385
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
383386

384387
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):

0 commit comments

Comments
 (0)