diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16ee054d..497d1ca5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,8 +11,7 @@ repos: - id: "mypy" name: "Python: types" additional_dependencies: - - "types-pytz" - - "types-requests" + - "types-all" - repo: https://github.com/pycqa/isort rev: 5.6.4 diff --git a/setup.cfg b/setup.cfg index 4ff502ba..d6106f2d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,8 +4,20 @@ test=pytest [flake8] max-line-length = 120 # W503 raises a warning when there is a line break before a binary operator. -# This is best practice according to PEP 8 and the rule should be ignored. +# This is best practice according to PEP 8 and the rule should be ignored. # # https://www.flake8rules.com/rules/W503.html # https://www.python.org/dev/peps/pep-0008/#should-a-line-break-before-or-after-a-binary-operator ignore = W503 + +[mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_untyped_calls = true +disallow_untyped_defs = true +ignore_missing_imports = true +no_implicit_optional = true +warn_unused_ignores = true + +[mypy-tests.*,trino.auth,trino.client,trino.dbapi,trino.sqlalchemy.*] +ignore_errors = true diff --git a/trino/client.py b/trino/client.py index b1c85e63..d023c72e 100644 --- a/trino/client.py +++ b/trino/client.py @@ -492,7 +492,7 @@ def statement_url(self) -> str: def next_uri(self) -> Optional[str]: return self._next_uri - def post(self, sql, additional_http_headers=None): + def post(self, sql: str, additional_http_headers: Optional[Dict[str, Any]] = None): data = sql.encode("utf-8") # Deep copy of the http_headers dict since they may be modified for this # request by the provided additional_http_headers @@ -524,7 +524,7 @@ def post(self, sql, additional_http_headers=None): ) return http_response - def get(self, url): + def get(self, url: str): return self._get( url, headers=self.http_headers, diff --git a/trino/exceptions.py b/trino/exceptions.py index 89082ba8..d48fc9ef 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -14,7 +14,7 @@ This module defines exceptions for Trino operations. It follows the structure defined in pep-0249. """ - +from typing import Any, Dict, Optional, Tuple import trino.logging @@ -72,44 +72,44 @@ class TrinoDataError(NotSupportedError): class TrinoQueryError(Error): - def __init__(self, error, query_id=None): + def __init__(self, error: Dict[str, Any], query_id: Optional[str] = None) -> None: self._error = error self._query_id = query_id @property - def error_code(self): + def error_code(self) -> Optional[int]: return self._error.get("errorCode", None) @property - def error_name(self): + def error_name(self) -> Optional[str]: return self._error.get("errorName", None) @property - def error_type(self): + def error_type(self) -> Optional[str]: return self._error.get("errorType", None) @property - def error_exception(self): + def error_exception(self) -> Optional[str]: return self.failure_info.get("type", None) if self.failure_info else None @property - def failure_info(self): + def failure_info(self) -> Optional[Dict[str, Any]]: return self._error.get("failureInfo", None) @property - def message(self): + def message(self) -> str: return self._error.get("message", "Trino did not return an error message") @property - def error_location(self): + def error_location(self) -> Tuple[int, int]: location = self._error["errorLocation"] return (location["lineNumber"], location["columnNumber"]) @property - def query_id(self): + def query_id(self) -> Optional[str]: return self._query_id - def __repr__(self): + def __repr__(self) -> str: return '{}(type={}, name={}, message="{}", query_id={})'.format( self.__class__.__name__, self.error_type, @@ -118,7 +118,7 @@ def __repr__(self): self.query_id, ) - def __str__(self): + def __str__(self) -> str: return repr(self) diff --git a/trino/logging.py b/trino/logging.py index dc99f5af..866c3e92 100644 --- a/trino/logging.py +++ b/trino/logging.py @@ -16,7 +16,7 @@ # TODO: provide interface to use ``logging.dictConfig`` -def get_logger(name, log_level=LEVEL): +def get_logger(name: str, log_level: int = LEVEL) -> logging.Logger: logger = logging.getLogger(name) logger.setLevel(log_level) return logger diff --git a/trino/transaction.py b/trino/transaction.py index c6e6257d..ebead938 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -50,19 +50,19 @@ def check(cls, level: int) -> int: class Transaction(object): - def __init__(self, request): + def __init__(self, request: trino.client.TrinoRequest) -> None: self._request = request self._id = NO_TRANSACTION @property - def id(self): + def id(self) -> str: return self._id @property - def request(self): + def request(self) -> trino.client.TrinoRequest: return self._request - def begin(self): + def begin(self) -> None: response = self._request.post(START_TRANSACTION) if not response.ok: raise trino.exceptions.DatabaseError( @@ -81,7 +81,7 @@ def begin(self): self._request.transaction_id = self._id logger.info("transaction started: %s", self._id) - def commit(self): + def commit(self) -> None: query = trino.client.TrinoQuery(self._request, COMMIT) try: list(query.execute()) @@ -92,7 +92,7 @@ def commit(self): self._id = NO_TRANSACTION self._request.transaction_id = self._id - def rollback(self): + def rollback(self) -> None: query = trino.client.TrinoQuery(self._request, ROLLBACK) try: list(query.execute())