diff --git a/README.md b/README.md
index 5e4cd18..07bf2a9 100644
--- a/README.md
+++ b/README.md
@@ -111,6 +111,37 @@ The transaction is created when the first SQL statement is executed.
 exits the *with* context and the queries succeed, otherwise
 `prestodb.dbapi.Connection.rollback()' will be called.
 
+# Improved Python types
+
+If you enable the flag `experimental_python_types`, the client will convert the results of the query to the
+corresponding Python types. For example, if the query returns a `DECIMAL` column, the result will be a `Decimal` object.
+
+Limitations of the Python types are described in the
+[Python types documentation](https://docs.python.org/3/library/datatypes.html). If the query returns a value that
+cannot be converted to the corresponding Python type, the client raises a `prestodb.exceptions.DataError`
+exception.
+
+```python
+import prestodb
+import pytz
+from datetime import datetime
+
+conn = prestodb.dbapi.connect(
+    experimental_python_types=True
+    ...
+)
+
+cur = conn.cursor()
+
+params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
+
+cur.execute("SELECT ?", params=(params,))
+rows = cur.fetchall()
+
+assert rows[0][0] == params
+assert cur.description[0][1] == "timestamp with time zone"
+```
+
 # Running Tests
 
 There is a helper scripts, `run`, that provides commands to run tests.
diff --git a/prestodb/client.py b/prestodb/client.py index 1bf82fe..eac8016 100644 --- a/prestodb/client.py +++ b/prestodb/client.py @@ -37,6 +37,10 @@ import logging import os from typing import Any, Dict, List, Optional, Text, Tuple, Union # NOQA for mypy types +from datetime import datetime, timedelta, timezone +from decimal import Decimal +from typing import Any, Dict, List, Optional, Tuple, Union +import pytz import prestodb.redirect import requests @@ -457,10 +461,11 @@ class PrestoResult(object): https://docs.python.org/3/library/stdtypes.html#generator-types """ - def __init__(self, query, rows=None): + def __init__(self, query, rows=None, experimental_python_types = False): self._query = query self._rows = rows or [] self._rownumber = 0 + self._experimental_python_types = experimental_python_types @property def rownumber(self): @@ -471,7 +476,10 @@ def __iter__(self): # Initial fetch from the first POST request for row in self._rows: self._rownumber += 1 - yield row + if not self._experimental_python_types: + yield row + else: + yield self._map_to_python_types(row, self._query.columns) self._rows = None # Subsequent fetches from GET requests until next_uri is empty. 
@@ -479,7 +487,77 @@ def __iter__(self):
             rows = self._query.fetch()
             for row in rows:
                 self._rownumber += 1
-                yield row
+                if not self._experimental_python_types:
+                    yield row
+                else:
+                    yield self._map_to_python_types(row, self._query.columns)
+
+    @classmethod
+    def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
+        """Convert one raw result value into its Python equivalent.
+
+        ``item`` is a ``(value, column)`` pair whose column entry names the
+        raw type under ``["typeSignature"]["rawType"]``. ``None`` is returned
+        as is, a list is converted element by element using the first type
+        argument, and a raw type without a conversion rule yields the value
+        unchanged.
+
+        Raises ``prestodb.exceptions.DataError`` when the value cannot be
+        parsed as the Python type matching its raw type.
+        """
+        # Imported locally so the pattern match below does not depend on a
+        # module-level ``re`` import.
+        import re
+
+        (value, data_type) = item
+
+        if value is None:
+            return None
+
+        raw_type = data_type["typeSignature"]["rawType"]
+
+        try:
+            if isinstance(value, list):
+                element_type = {
+                    "typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
+                }
+                return [cls._map_to_python_type((array_item, element_type)) for array_item in value]
+            elif "decimal" in raw_type:
+                return Decimal(value)
+            elif raw_type == "date":
+                return datetime.strptime(value, "%Y-%m-%d").date()
+            elif raw_type == "timestamp with time zone":
+                dt, tz = value.rsplit(' ', 1)
+                if tz.startswith(('+', '-')):
+                    return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f %z")
+                # NOTE(review): pytz documents tz.localize(dt) for attaching a
+                # named zone; replace(tzinfo=...) is kept so that parameters
+                # built with tzinfo=pytz.timezone(...) still compare equal to
+                # the value read back -- confirm before changing.
+                return datetime.strptime(dt, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=pytz.timezone(tz))
+            elif "timestamp" in raw_type:
+                return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
+            elif "time with time zone" in raw_type:
+                matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
+                if matches is None:
+                    # A ValueError (unlike an assert, which -O strips) is
+                    # reported as a DataError by the handler below.
+                    raise ValueError("no UTC offset found in time value")
+                offset = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
+                if matches.group(2) == '-':
+                    offset = -offset
+                return datetime.strptime(matches.group(1), "%H:%M:%S.%f").time().replace(tzinfo=timezone(offset))
+            elif "time" in raw_type:
+                return datetime.strptime(value, "%H:%M:%S.%f").time()
+            else:
+                return value
+        except (ValueError, ArithmeticError) as e:
+            # ArithmeticError covers decimal.InvalidOperation from Decimal().
+            error_str = f"Could not convert '{value}' into the associated python type for '{raw_type}'"
+            raise prestodb.exceptions.DataError(error_str) from e
+
+    def _map_to_python_types(self, row: List[Any], columns: List[Dict[str, Any]]) -> List[Any]:
+        return
list(map(self._map_to_python_type, zip(row, columns))) class PrestoQuery(object): @@ -489,6 +546,7 @@ def __init__( self, request, # type: PrestoRequest sql, # type: Text + experimental_python_types = False, ): # type: (...) -> None self.auth_req = request.auth_req # type: Optional[Request] @@ -502,7 +560,8 @@ def __init__( self._cancelled = False self._request = request self._sql = sql - self._result = PrestoResult(self) + self._result = PrestoResult(self, experimental_python_types=experimental_python_types) + self._experimental_python_types = experimental_python_types @property def columns(self): @@ -543,7 +602,7 @@ def execute(self): self._warnings = getattr(status, "warnings", []) if status.next_uri is None: self._finished = True - self._result = PrestoResult(self, status.rows) + self._result = PrestoResult(self, status.rows, self._experimental_python_types) while ( not self._finished and not self._cancelled ): diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py index cc60bf7..bd84c0c 100644 --- a/prestodb/dbapi.py +++ b/prestodb/dbapi.py @@ -21,6 +21,7 @@ import binascii import datetime +from decimal import Decimal import logging import uuid from typing import Any, List, Optional # NOQA for mypy types @@ -76,6 +77,7 @@ def __init__( max_attempts=constants.DEFAULT_MAX_ATTEMPTS, request_timeout=constants.DEFAULT_REQUEST_TIMEOUT, isolation_level=IsolationLevel.AUTOCOMMIT, + experimental_python_types=False, **kwargs, ): self.host = host @@ -107,6 +109,8 @@ def __init__( self._request = None self._transaction = None + self.experimental_python_types = experimental_python_types + @property def isolation_level(self): return self._isolation_level @@ -171,7 +175,7 @@ def cursor(self): request = self.transaction._request else: request = self._create_request() - return Cursor(self, request) + return Cursor(self, request, self.experimental_python_types) class Cursor(object): @@ -182,7 +186,7 @@ class Cursor(object): """ - def __init__(self, connection, request): + def 
__init__(self, connection, request, experimental_python_types = False): if not isinstance(connection, Connection): raise ValueError( "connection must be a Connection object: {}".format(type(connection)) @@ -193,6 +197,7 @@ def __init__(self, connection, request): self.arraysize = 1 self._iterator = None self._query = None + self._experimental_python_types = experimental_python_types def __iter__(self): return self._iterator @@ -263,7 +268,7 @@ def execute(self, operation, params=None): # TODO: Consider caching prepared statements if requested by caller self._deallocate_prepared_statement(statement_name) else: - self._query = prestodb.client.PrestoQuery(self._request, sql=operation) + self._query = prestodb.client.PrestoQuery(self._request, sql=operation, experimental_python_types=self._experimental_python_types) self._iterator = iter(self._query.execute()) return self @@ -272,7 +277,7 @@ def _generate_unique_statement_name(self): def _prepare_statement(self, statement: str, name: str) -> None: sql = f"PREPARE {name} FROM {statement}" - query = prestodb.client.PrestoQuery(self._request, sql=sql) + query = prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types) query.execute() def _execute_prepared_statement(self, statement_name, params): @@ -282,11 +287,11 @@ def _execute_prepared_statement(self, statement_name, params): + " USING " + ",".join(map(self._format_prepared_param, params)) ) - return prestodb.client.PrestoQuery(self._request, sql=sql) + return prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types) def _deallocate_prepared_statement(self, statement_name: str) -> None: sql = "DEALLOCATE PREPARE " + statement_name - query = prestodb.client.PrestoQuery(self._request, sql=sql) + query = prestodb.client.PrestoQuery(self._request, sql=sql, experimental_python_types=self._experimental_python_types) query.execute() def _format_prepared_param(self, 
param): @@ -323,6 +328,9 @@ def _format_prepared_param(self, param): if isinstance(param, datetime.datetime) and param.tzinfo is not None: datetime_str = param.strftime("%Y-%m-%d %H:%M:%S.%f") + # named timezones + if hasattr(param.tzinfo, 'zone'): + return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.zone) # offset-based timezones return "TIMESTAMP '%s %s'" % (datetime_str, param.tzinfo.tzname(param)) @@ -356,6 +364,10 @@ def _format_prepared_param(self, param): if isinstance(param, uuid.UUID): return "UUID '%s'" % param + if isinstance(param, Decimal): + # "f" formatting keeps plain digits; str() may switch to exponent notation (e.g. 1E-7) + return "DECIMAL '%s'" % format(param, "f") + if isinstance(param, (bytes, bytearray)): return "X'%s'" % binascii.hexlify(param).decode("utf-8") diff --git a/setup.py b/setup.py index 8bda465..d396179 100644 --- a/setup.py +++ b/setup.py @@ -24,12 +24,13 @@ ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) ) +require = ["pytz"] kerberos_require = ["requests_kerberos"] google_auth_require = ["google_auth"] -all_require = [kerberos_require, google_auth_require] +all_require = [require, kerberos_require, google_auth_require] tests_require = all_require + ["httpretty", "pytest", "pytest-runner"]