Skip to content

Commit 8c316b4

Browse files
committed
Support decimal, date, time, timestamp with time zone and timestamp
1 parent 65506e8 commit 8c316b4

File tree

3 files changed

+84
-4
lines changed

3 files changed

+84
-4
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# limitations under the License.
1212
import math
1313
from datetime import datetime
14+
from decimal import Decimal
1415

1516
import pytest
1617
import pytz
@@ -123,19 +124,37 @@ def test_string_query_param(trino_connection):
123124
assert rows[0][0] == "six'"
124125

125126

127+
def test_float_query_param(trino_connection):
128+
cur = trino_connection.cursor()
129+
130+
cur.execute("SELECT ?", params=(1.23,))
131+
rows = cur.fetchall()
132+
133+
assert rows[0][0] == 1.23
134+
135+
136+
def test_decimal_query_param(trino_connection):
137+
cur = trino_connection.cursor()
138+
139+
cur.execute("SELECT ?", params=(Decimal('0.142857'),))
140+
rows = cur.fetchall()
141+
142+
assert rows[0][0] == Decimal('0.142857')
143+
144+
126145
def test_datetime_query_param(trino_connection):
127146
cur = trino_connection.cursor()
128147

129148
cur.execute("SELECT ?", params=(datetime(2020, 1, 1, 0, 0, 0),))
130149
rows = cur.fetchall()
131150

132-
assert rows[0][0] == "2020-01-01 00:00:00.000"
151+
assert rows[0][0] == datetime(2020, 1, 1, 0, 0, 0)
133152

134153
cur.execute("SELECT ?",
135154
params=(datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc),))
136155
rows = cur.fetchall()
137156

138-
assert rows[0][0] == "2020-01-01 00:00:00.000 UTC"
157+
assert rows[0][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc)
139158
assert cur.description[0][1] == "timestamp with time zone"
140159

141160

@@ -158,6 +177,33 @@ def test_array_query_param(trino_connection):
158177
assert rows[0][0] == "array(integer)"
159178

160179

180+
def test_array_timestamp_query_param(trino_connection):
181+
cur = trino_connection.cursor()
182+
cur.execute("SELECT ?", params=([datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)],))
183+
rows = cur.fetchall()
184+
185+
assert rows[0][0] == [datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)]
186+
187+
cur.execute("SELECT TYPEOF(?)", params=([datetime(2020, 1, 1, 0, 0, 0), datetime(2020, 1, 2, 0, 0, 0)],))
188+
rows = cur.fetchall()
189+
190+
assert rows[0][0] == "array(timestamp(6))"
191+
192+
193+
def test_array_timestamp_with_timezone_query_param(trino_connection):
194+
cur = trino_connection.cursor()
195+
196+
cur.execute("SELECT ?", params=([datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)],))
197+
rows = cur.fetchall()
198+
199+
assert rows[0][0] == [datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)]
200+
201+
cur.execute("SELECT TYPEOF(?)", params=([datetime(2020, 1, 1, 0, 0, 0, tzinfo=pytz.utc), datetime(2020, 1, 2, 0, 0, 0, tzinfo=pytz.utc)],))
202+
rows = cur.fetchall()
203+
204+
assert rows[0][0] == "array(timestamp(6) with time zone)"
205+
206+
161207
def test_dict_query_param(trino_connection):
162208
cur = trino_connection.cursor()
163209

trino/client.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
import copy
3737
import os
3838
import re
39+
from decimal import Decimal
40+
from datetime import datetime
41+
import pytz
42+
from time import strptime
3943
from typing import Any, Dict, List, Optional, Tuple, Union
4044
import urllib.parse
4145

@@ -494,12 +498,39 @@ def __iter__(self):
494498
for row in rows:
495499
self._rownumber += 1
496500
logger.debug("row %s", row)
497-
yield row
501+
yield self._map_to_python_types(row, self._query.columns)
498502

499503
@property
500504
def response_headers(self):
501505
return self._query.response_headers
502506

507+
@classmethod
508+
def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
509+
(value, data_type) = item
510+
511+
raw_type = data_type["typeSignature"]["rawType"]
512+
if isinstance(value, list):
513+
raw_type = {
514+
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
515+
}
516+
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
517+
elif "decimal" in raw_type:
518+
return Decimal(value)
519+
elif raw_type == "date":
520+
return strptime(value, "%Y-%m-%d")
521+
elif raw_type == "time":
522+
return strptime(value, "%H:%M:%S")
523+
elif "timestamp with time zone" in raw_type:
524+
time_str = strptime(value, "%Y-%m-%d %H:%M:%S.%f %Z")
525+
return datetime(*time_str[:3], tzinfo=pytz.timezone(time_str.tm_zone))
526+
elif "timestamp" in raw_type:
527+
return datetime(*strptime(value, "%Y-%m-%d %H:%M:%S.%f")[:3])
528+
else:
529+
return value
530+
531+
def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
532+
return list(map(self._map_to_python_type, zip(row, columns)))
533+
503534

504535
class TrinoQuery(object):
505536
"""Represent the execution of a SQL statement by Trino."""

trino/dbapi.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Fetch methods returns rows as a list of lists on purpose to let the caller
1818
decide to convert then to a list of tuples.
1919
"""
20-
20+
from decimal import Decimal
2121
from typing import Any, List, Optional # NOQA for mypy types
2222

2323
import copy
@@ -379,6 +379,9 @@ def _format_prepared_param(self, param):
379379
if isinstance(param, uuid.UUID):
380380
return "UUID '%s'" % param
381381

382+
if isinstance(param, Decimal):
383+
return "DECIMAL '%s'" % param
384+
382385
raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))
383386

384387
def _deallocate_prepare_statement(self, added_prepare_header, statement_name):

0 commit comments

Comments
 (0)