diff --git a/asyncpg/pool.py b/asyncpg/pool.py index e3898d53..2e4a7b4f 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -4,12 +4,16 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from __future__ import annotations import asyncio +from collections.abc import Awaitable, Callable import functools import inspect import logging import time +from types import TracebackType +from typing import Any, Optional, Type import warnings from . import compat @@ -23,7 +27,14 @@ class PoolConnectionProxyMeta(type): - def __new__(mcls, name, bases, dct, *, wrap=False): + def __new__( + mcls, + name: str, + bases: tuple[Type[Any], ...], + dct: dict[str, Any], + *, + wrap: bool = False, + ) -> PoolConnectionProxyMeta: if wrap: for attrname in dir(connection.Connection): if attrname.startswith('_') or attrname in dct: @@ -44,8 +55,10 @@ def __new__(mcls, name, bases, dct, *, wrap=False): return super().__new__(mcls, name, bases, dct) @staticmethod - def _wrap_connection_method(meth_name, iscoroutine): - def call_con_method(self, *args, **kwargs): + def _wrap_connection_method( + meth_name: str, iscoroutine: bool + ) -> Callable[..., Any]: + def call_con_method(self: Any, *args: Any, **kwargs: Any) -> Any: # This method will be owned by PoolConnectionProxy class. if self._con is None: raise exceptions.InterfaceError( @@ -68,17 +81,18 @@ class PoolConnectionProxy(connection._ConnectionProxy, __slots__ = ('_con', '_holder') - def __init__(self, holder: 'PoolConnectionHolder', - con: connection.Connection): + def __init__( + self, holder: PoolConnectionHolder, con: connection.Connection + ) -> None: self._con = con self._holder = holder con._set_proxy(self) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: # Proxy all unresolved attributes to the wrapped Connection object. return getattr(self._con, attr) - def _detach(self) -> connection.Connection: + def _detach(self) -> Optional[connection.Connection]: if self._con is None: return @@ -86,7 +100,7 @@ def _detach(self) -> connection.Connection: con._set_proxy(None) return con - def __repr__(self): + def __repr__(self) -> str: if self._con is None: return '<{classname} [released] {id:#x}>'.format( classname=self.__class__.__name__, id=id(self)) @@ -103,27 +117,34 @@ class PoolConnectionHolder: '_inactive_callback', '_timeout', '_generation') - def __init__(self, pool, *, max_queries, setup, max_inactive_time): + def __init__( + self, + pool: "Pool", + *, + max_queries: float, + setup: Optional[Callable[[PoolConnectionProxy], Awaitable[None]]], + max_inactive_time: float, + ) -> None: self._pool = pool - self._con = None - self._proxy = None + self._con: Optional[connection.Connection] = None + self._proxy: Optional[PoolConnectionProxy] = None self._max_queries = max_queries self._max_inactive_time = max_inactive_time self._setup = setup - self._inactive_callback = None - self._in_use = None # type: asyncio.Future - self._timeout = None - self._generation = None + self._inactive_callback: Optional[Callable] = None + self._in_use: Optional[asyncio.Future] = None + self._timeout: Optional[float] = None + self._generation: Optional[int] = None - def is_connected(self): + def is_connected(self) -> bool: return self._con is not None and not self._con.is_closed() - def is_idle(self): + def is_idle(self) -> bool: return not self._in_use - async def connect(self): + async def connect(self) -> None: if self._con is not None: raise exceptions.InternalClientError( 'PoolConnectionHolder.connect() called while another ' @@ -171,7 +192,7 @@ async def acquire(self) -> PoolConnectionProxy: return proxy - async def release(self, timeout): + async def release(self, timeout: Optional[float]) -> None: if self._in_use is None: raise exceptions.InternalClientError( 'PoolConnectionHolder.release() called on ' @@ -234,25 +255,25 @@ async def release(self, timeout): # Rearm the connection inactivity timer. self._setup_inactive_callback() - async def wait_until_released(self): + async def wait_until_released(self) -> None: if self._in_use is None: return else: await self._in_use - async def close(self): + async def close(self) -> None: if self._con is not None: # Connection.close() will call _release_on_close() to # finish holder cleanup. await self._con.close() - def terminate(self): + def terminate(self) -> None: if self._con is not None: # Connection.terminate() will call _release_on_close() to # finish holder cleanup. self._con.terminate() - def _setup_inactive_callback(self): + def _setup_inactive_callback(self) -> None: if self._inactive_callback is not None: raise exceptions.InternalClientError( 'pool connection inactivity timer already exists') @@ -261,12 +282,12 @@ def _setup_inactive_callback(self): self._inactive_callback = self._pool._loop.call_later( self._max_inactive_time, self._deactivate_inactive_connection) - def _maybe_cancel_inactive_callback(self): + def _maybe_cancel_inactive_callback(self) -> None: if self._inactive_callback is not None: self._inactive_callback.cancel() self._inactive_callback = None - def _deactivate_inactive_connection(self): + def _deactivate_inactive_connection(self) -> None: if self._in_use is not None: raise exceptions.InternalClientError( 'attempting to deactivate an acquired connection') @@ -280,12 +301,12 @@ def _deactivate_inactive_connection(self): # so terminate() above will not call the below. self._release_on_close() - def _release_on_close(self): + def _release_on_close(self) -> None: self._maybe_cancel_inactive_callback() self._release() self._con = None - def _release(self): + def _release(self) -> None: """Release this connection holder.""" if self._in_use is None: # The holder is not checked out. @@ -1012,7 +1033,7 @@ class PoolAcquireContext: __slots__ = ('timeout', 'connection', 'done', 'pool') - def __init__(self, pool, timeout): + def __init__(self, pool: Pool, timeout: Optional[float]) -> None: self.pool = pool self.timeout = timeout self.connection = None @@ -1024,7 +1045,12 @@ async def __aenter__(self): self.connection = await self.pool._acquire(self.timeout) return self.connection - async def __aexit__(self, *exc): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: self.done = True con = self.connection self.connection = None