Skip to content

Add type hints #487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import sys
import time
import traceback
import typing
import warnings

from . import compat
Expand All @@ -25,6 +26,7 @@
from . import protocol
from . import serverversion
from . import transaction
from . import types
from . import utils


Expand Down Expand Up @@ -179,11 +181,11 @@ def remove_log_listener(self, callback):
"""
self._log_listeners.discard(callback)

def get_server_pid(self):
def get_server_pid(self) -> int:
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()

def get_server_version(self):
def get_server_version(self) -> types.ServerVersion:
"""Return the version of the connected PostgreSQL server.

The returned value is a named tuple similar to that in
Expand All @@ -199,15 +201,15 @@ def get_server_version(self):
"""
return self._server_version

def get_settings(self):
def get_settings(self) -> protocol.ConnectionSettings:
"""Return connection settings.

:return: :class:`~asyncpg.ConnectionSettings`.
"""
return self._protocol.get_settings()

def transaction(self, *, isolation='read_committed', readonly=False,
deferrable=False):
deferrable=False) -> transaction.Transaction:
"""Create a :class:`~transaction.Transaction` object.

Refer to `PostgreSQL documentation`_ on the meaning of transaction
Expand All @@ -230,7 +232,7 @@ def transaction(self, *, isolation='read_committed', readonly=False,
self._check_open()
return transaction.Transaction(self, isolation, readonly, deferrable)

def is_in_transaction(self):
def is_in_transaction(self) -> bool:
"""Return True if Connection is currently inside a transaction.

:return bool: True if inside transaction, False otherwise.
Expand Down Expand Up @@ -275,7 +277,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
_, status, _ = await self._execute(query, args, 0, timeout, True)
return status.decode()

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None) \
-> None:
"""Execute an SQL *command* for each sequence of arguments in *args*.

Example:
Expand Down Expand Up @@ -378,7 +381,8 @@ async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)

def cursor(self, query, *args, prefetch=None, timeout=None):
def cursor(self, query, *args, prefetch=None, timeout=None) \
-> cursor.CursorFactory:
"""Return a *cursor factory* for the specified query.

:param args: Query arguments.
Expand All @@ -392,7 +396,8 @@ def cursor(self, query, *args, prefetch=None, timeout=None):
return cursor.CursorFactory(self, query, None, args,
prefetch, timeout)

async def prepare(self, query, *, timeout=None):
async def prepare(self, query, *, timeout=None) \
-> prepared_stmt.PreparedStatement:
"""Create a *prepared statement* for the specified query.

:param str query: Text of the query to create a prepared statement for.
Expand All @@ -408,7 +413,8 @@ async def _prepare(self, query, *, timeout=None, use_cache: bool=False):
use_cache=use_cache)
return prepared_stmt.PreparedStatement(self, query, stmt)

async def fetch(self, query, *args, timeout=None) -> list:
async def fetch(self, query, *args, timeout=None) \
-> typing.List[protocol.Record]:
"""Run a query and return the results as a list of :class:`Record`.

:param str query: Query text.
Expand All @@ -420,7 +426,8 @@ async def fetch(self, query, *args, timeout=None) -> list:
self._check_open()
return await self._execute(query, args, 0, timeout)

async def fetchval(self, query, *args, column=0, timeout=None):
async def fetchval(self, query, *args, column=0, timeout=None) \
-> typing.Any:
"""Run a query and return a value in the first row.

:param str query: Query text.
Expand All @@ -441,7 +448,8 @@ async def fetchval(self, query, *args, column=0, timeout=None):
return None
return data[0][column]

async def fetchrow(self, query, *args, timeout=None):
async def fetchrow(self, query, *args, timeout=None) \
-> typing.Optional[protocol.Record]:
"""Run a query and return the first row.

:param str query: Query text
Expand All @@ -461,7 +469,8 @@ async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
null=None, header=None, quote=None,
escape=None, force_quote=None, encoding=None):
escape=None, force_quote=None, encoding=None) \
-> str:
"""Copy table contents to a file or file-like object.

:param str table_name:
Expand Down Expand Up @@ -533,7 +542,7 @@ async def copy_from_query(self, query, *args, output,
timeout=None, format=None, oids=None,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
encoding=None):
encoding=None) -> str:
"""Copy the results of a query to a file or file-like object.

:param str query:
Expand Down Expand Up @@ -597,7 +606,7 @@ async def copy_to_table(self, table_name, *, source,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
force_not_null=None, force_null=None,
encoding=None):
encoding=None) -> str:
"""Copy data to the specified table.

:param str table_name:
Expand Down Expand Up @@ -668,7 +677,7 @@ async def copy_to_table(self, table_name, *, source,

async def copy_records_to_table(self, table_name, *, records,
columns=None, schema_name=None,
timeout=None):
timeout=None) -> str:
"""Copy a list of records to the specified table using binary COPY.

:param str table_name:
Expand Down Expand Up @@ -1060,7 +1069,7 @@ async def set_builtin_type_codec(self, typename, *,
# Statement cache is no longer valid due to codec changes.
self._drop_local_statement_cache()

def is_closed(self):
def is_closed(self) -> bool:
"""Return ``True`` if the connection is closed, ``False`` otherwise.

:return bool: ``True`` if the connection is closed, ``False``
Expand Down Expand Up @@ -1503,7 +1512,7 @@ async def connect(dsn=None, *,
command_timeout=None,
ssl=None,
connection_class=Connection,
server_settings=None):
server_settings=None) -> Connection:
r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
Expand Down
14 changes: 8 additions & 6 deletions asyncpg/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@


import collections
import typing

from . import compat
from . import connresource
from . import exceptions
from . import protocol


class CursorFactory(connresource.ConnectionResource):
Expand All @@ -33,15 +35,15 @@ def __init__(self, connection, query, state, args, prefetch, timeout):

@compat.aiter_compat
@connresource.guarded
def __aiter__(self):
def __aiter__(self) -> 'CursorIterator':
prefetch = 50 if self._prefetch is None else self._prefetch
return CursorIterator(self._connection,
self._query, self._state,
self._args, prefetch,
self._timeout)

@connresource.guarded
def __await__(self):
def __await__(self) -> 'Cursor':
if self._prefetch is not None:
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
Expand Down Expand Up @@ -164,11 +166,11 @@ def __init__(self, connection, query, state, args, prefetch, timeout):

@compat.aiter_compat
@connresource.guarded
def __aiter__(self):
def __aiter__(self) -> 'CursorIterator':
return self

@connresource.guarded
async def __anext__(self):
async def __anext__(self) -> protocol.Record:
if self._state is None:
self._state = await self._connection._get_statement(
self._query, self._timeout, named=True)
Expand Down Expand Up @@ -203,7 +205,7 @@ async def _init(self, timeout):
return self

@connresource.guarded
async def fetch(self, n, *, timeout=None):
async def fetch(self, n, *, timeout=None) -> typing.List[protocol.Record]:
r"""Return the next *n* rows as a list of :class:`Record` objects.

:param float timeout: Optional timeout value in seconds.
Expand All @@ -221,7 +223,7 @@ async def fetch(self, n, *, timeout=None):
return recs

@connresource.guarded
async def fetchrow(self, *, timeout=None):
async def fetchrow(self, *, timeout=None) -> protocol.Record:
r"""Return the next row.

:param float timeout: Optional timeout value in seconds.
Expand Down
18 changes: 12 additions & 6 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import inspect
import logging
import time
import typing
import warnings

from . import connection
from . import connect_utils
from . import exceptions
from . import protocol


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -508,7 +510,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
async with self.acquire() as con:
return await con.execute(query, *args, timeout=timeout)

async def executemany(self, command: str, args, *, timeout: float=None):
async def executemany(self, command: str, args, *, timeout: float=None) \
-> None:
"""Execute an SQL *command* for each sequence of arguments in *args*.

Pool performs this operation using one of its connections. Other than
Expand All @@ -520,7 +523,8 @@ async def executemany(self, command: str, args, *, timeout: float=None):
async with self.acquire() as con:
return await con.executemany(command, args, timeout=timeout)

async def fetch(self, query, *args, timeout=None) -> list:
async def fetch(self, query, *args, timeout=None) \
-> typing.List[protocol.Record]:
"""Run a query and return the results as a list of :class:`Record`.

Pool performs this operation using one of its connections. Other than
Expand All @@ -532,7 +536,8 @@ async def fetch(self, query, *args, timeout=None) -> list:
async with self.acquire() as con:
return await con.fetch(query, *args, timeout=timeout)

async def fetchval(self, query, *args, column=0, timeout=None):
async def fetchval(self, query, *args, column=0, timeout=None) \
-> typing.Any:
"""Run a query and return a value in the first row.

Pool performs this operation using one of its connections. Other than
Expand All @@ -545,7 +550,8 @@ async def fetchval(self, query, *args, column=0, timeout=None):
return await con.fetchval(
query, *args, column=column, timeout=timeout)

async def fetchrow(self, query, *args, timeout=None):
async def fetchrow(self, query, *args, timeout=None) \
-> typing.Optional[protocol.Record]:
"""Run a query and return the first row.

Pool performs this operation using one of its connections. Other than
Expand All @@ -557,7 +563,7 @@ async def fetchrow(self, query, *args, timeout=None):
async with self.acquire() as con:
return await con.fetchrow(query, *args, timeout=timeout)

def acquire(self, *, timeout=None):
def acquire(self, *, timeout=None) -> connection.Connection:
"""Acquire a database connection from the pool.

:param float timeout: A timeout for acquiring a Connection.
Expand Down Expand Up @@ -784,7 +790,7 @@ def create_pool(dsn=None, *,
init=None,
loop=None,
connection_class=connection.Connection,
**connect_kwargs):
**connect_kwargs) -> Pool:
r"""Create a connection pool.

Can be used either with an ``async with`` block:
Expand Down
16 changes: 10 additions & 6 deletions asyncpg/prepared_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@


import json
import typing

from . import connresource
from . import cursor
from . import exceptions
from . import protocol
from . import types


class PreparedStatement(connresource.ConnectionResource):
Expand Down Expand Up @@ -50,7 +53,7 @@ def get_statusmsg(self) -> str:
return self._last_status.decode()

@connresource.guarded
def get_parameters(self):
def get_parameters(self) -> typing.Tuple[types.Type, ...]:
"""Return a description of statement parameters types.

:return: A tuple of :class:`asyncpg.types.Type`.
Expand All @@ -67,7 +70,7 @@ def get_parameters(self):
return self._state._get_parameters()

@connresource.guarded
def get_attributes(self):
def get_attributes(self) -> typing.Tuple[types.Attribute, ...]:
"""Return a description of relation attributes (columns).

:return: A tuple of :class:`asyncpg.types.Attribute`.
Expand Down Expand Up @@ -108,7 +111,7 @@ def cursor(self, *args, prefetch=None,
timeout)

@connresource.guarded
async def explain(self, *args, analyze=False):
async def explain(self, *args, analyze=False) -> typing.Dict:
"""Return the execution plan of the statement.

:param args: Query arguments.
Expand Down Expand Up @@ -150,7 +153,7 @@ async def explain(self, *args, analyze=False):
return json.loads(data)

@connresource.guarded
async def fetch(self, *args, timeout=None):
async def fetch(self, *args, timeout=None) -> typing.List[protocol.Record]:
r"""Execute the statement and return a list of :class:`Record` objects.

:param str query: Query text
Expand All @@ -163,7 +166,7 @@ async def fetch(self, *args, timeout=None):
return data

@connresource.guarded
async def fetchval(self, *args, column=0, timeout=None):
async def fetchval(self, *args, column=0, timeout=None) -> typing.Any:
"""Execute the statement and return a value in the first row.

:param args: Query arguments.
Expand All @@ -182,7 +185,8 @@ async def fetchval(self, *args, column=0, timeout=None):
return data[0][column]

@connresource.guarded
async def fetchrow(self, *args, timeout=None):
async def fetchrow(self, *args, timeout=None) \
-> typing.Optional[protocol.Record]:
"""Execute the statement and return the first row.

:param str query: Query text
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
from .protocol import Protocol, Record, ConnectionSettings, NO_TIMEOUT # NOQA