Skip to content

Commit a01486d

Browse files
committed
added type hints
1 parent 89815ea commit a01486d

File tree

5 files changed

+45
-36
lines changed

5 files changed

+45
-36
lines changed

asyncpg/connection.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import sys
1515
import time
1616
import traceback
17+
import typing
1718
import warnings
1819

1920
from . import compat
@@ -25,6 +26,7 @@
2526
from . import protocol
2627
from . import serverversion
2728
from . import transaction
29+
from . import types
2830
from . import utils
2931

3032

@@ -179,11 +181,11 @@ def remove_log_listener(self, callback):
179181
"""
180182
self._log_listeners.discard(callback)
181183

182-
def get_server_pid(self):
184+
def get_server_pid(self) -> int:
183185
"""Return the PID of the Postgres server the connection is bound to."""
184186
return self._protocol.get_server_pid()
185187

186-
def get_server_version(self):
188+
def get_server_version(self) -> types.ServerVersion:
187189
"""Return the version of the connected PostgreSQL server.
188190
189191
The returned value is a named tuple similar to that in
@@ -199,15 +201,15 @@ def get_server_version(self):
199201
"""
200202
return self._server_version
201203

202-
def get_settings(self):
204+
def get_settings(self) -> protocol.ConnectionSettings:
203205
"""Return connection settings.
204206
205207
:return: :class:`~asyncpg.ConnectionSettings`.
206208
"""
207209
return self._protocol.get_settings()
208210

209211
def transaction(self, *, isolation='read_committed', readonly=False,
210-
deferrable=False):
212+
deferrable=False) -> transaction.Transaction:
211213
"""Create a :class:`~transaction.Transaction` object.
212214
213215
Refer to `PostgreSQL documentation`_ on the meaning of transaction
@@ -230,7 +232,7 @@ def transaction(self, *, isolation='read_committed', readonly=False,
230232
self._check_open()
231233
return transaction.Transaction(self, isolation, readonly, deferrable)
232234

233-
def is_in_transaction(self):
235+
def is_in_transaction(self) -> bool:
234236
"""Return True if Connection is currently inside a transaction.
235237
236238
:return bool: True if inside transaction, False otherwise.
@@ -275,7 +277,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
275277
_, status, _ = await self._execute(query, args, 0, timeout, True)
276278
return status.decode()
277279

278-
async def executemany(self, command: str, args, *, timeout: float=None):
280+
async def executemany(self, command: str, args, *, timeout: float=None) -> None:
279281
"""Execute an SQL *command* for each sequence of arguments in *args*.
280282
281283
Example:
@@ -378,7 +380,7 @@ async def _introspect_types(self, typeoids, timeout):
378380
return await self.__execute(
379381
self._intro_query, (list(typeoids),), 0, timeout)
380382

381-
def cursor(self, query, *args, prefetch=None, timeout=None):
383+
def cursor(self, query, *args, prefetch=None, timeout=None) -> cursor.CursorFactory:
382384
"""Return a *cursor factory* for the specified query.
383385
384386
:param args: Query arguments.
@@ -392,7 +394,7 @@ def cursor(self, query, *args, prefetch=None, timeout=None):
392394
return cursor.CursorFactory(self, query, None, args,
393395
prefetch, timeout)
394396

395-
async def prepare(self, query, *, timeout=None):
397+
async def prepare(self, query, *, timeout=None) -> prepared_stmt.PreparedStatement:
396398
"""Create a *prepared statement* for the specified query.
397399
398400
:param str query: Text of the query to create a prepared statement for.
@@ -408,7 +410,7 @@ async def _prepare(self, query, *, timeout=None, use_cache: bool=False):
408410
use_cache=use_cache)
409411
return prepared_stmt.PreparedStatement(self, query, stmt)
410412

411-
async def fetch(self, query, *args, timeout=None) -> list:
413+
async def fetch(self, query, *args, timeout=None) -> typing.List[protocol.Record]:
412414
"""Run a query and return the results as a list of :class:`Record`.
413415
414416
:param str query: Query text.
@@ -420,7 +422,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
420422
self._check_open()
421423
return await self._execute(query, args, 0, timeout)
422424

423-
async def fetchval(self, query, *args, column=0, timeout=None):
425+
async def fetchval(self, query, *args, column=0, timeout=None) -> typing.Any:
424426
"""Run a query and return a value in the first row.
425427
426428
:param str query: Query text.
@@ -441,7 +443,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
441443
return None
442444
return data[0][column]
443445

444-
async def fetchrow(self, query, *args, timeout=None):
446+
async def fetchrow(self, query, *args, timeout=None) -> typing.Optional[protocol.Record]:
445447
"""Run a query and return the first row.
446448
447449
:param str query: Query text
@@ -461,7 +463,7 @@ async def copy_from_table(self, table_name, *, output,
461463
columns=None, schema_name=None, timeout=None,
462464
format=None, oids=None, delimiter=None,
463465
null=None, header=None, quote=None,
464-
escape=None, force_quote=None, encoding=None):
466+
escape=None, force_quote=None, encoding=None) -> str:
465467
"""Copy table contents to a file or file-like object.
466468
467469
:param str table_name:
@@ -533,7 +535,7 @@ async def copy_from_query(self, query, *args, output,
533535
timeout=None, format=None, oids=None,
534536
delimiter=None, null=None, header=None,
535537
quote=None, escape=None, force_quote=None,
536-
encoding=None):
538+
encoding=None) -> str:
537539
"""Copy the results of a query to a file or file-like object.
538540
539541
:param str query:
@@ -597,7 +599,7 @@ async def copy_to_table(self, table_name, *, source,
597599
delimiter=None, null=None, header=None,
598600
quote=None, escape=None, force_quote=None,
599601
force_not_null=None, force_null=None,
600-
encoding=None):
602+
encoding=None) -> str:
601603
"""Copy data to the specified table.
602604
603605
:param str table_name:
@@ -668,7 +670,7 @@ async def copy_to_table(self, table_name, *, source,
668670

669671
async def copy_records_to_table(self, table_name, *, records,
670672
columns=None, schema_name=None,
671-
timeout=None):
673+
timeout=None) -> str:
672674
"""Copy a list of records to the specified table using binary COPY.
673675
674676
:param str table_name:
@@ -1060,7 +1062,7 @@ async def set_builtin_type_codec(self, typename, *,
10601062
# Statement cache is no longer valid due to codec changes.
10611063
self._drop_local_statement_cache()
10621064

1063-
def is_closed(self):
1065+
def is_closed(self) -> bool:
10641066
"""Return ``True`` if the connection is closed, ``False`` otherwise.
10651067
10661068
:return bool: ``True`` if the connection is closed, ``False``
@@ -1503,7 +1505,7 @@ async def connect(dsn=None, *,
15031505
command_timeout=None,
15041506
ssl=None,
15051507
connection_class=Connection,
1506-
server_settings=None):
1508+
server_settings=None) -> Connection:
15071509
r"""A coroutine to establish a connection to a PostgreSQL server.
15081510
15091511
The connection parameters may be specified either as a connection

asyncpg/cursor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77

88
import collections
9+
import typing
910

1011
from . import compat
1112
from . import connresource
1213
from . import exceptions
14+
from . import protocol
1315

1416

1517
class CursorFactory(connresource.ConnectionResource):
@@ -33,15 +35,15 @@ def __init__(self, connection, query, state, args, prefetch, timeout):
3335

3436
@compat.aiter_compat
3537
@connresource.guarded
36-
def __aiter__(self):
38+
def __aiter__(self) -> 'CursorIterator':
3739
prefetch = 50 if self._prefetch is None else self._prefetch
3840
return CursorIterator(self._connection,
3941
self._query, self._state,
4042
self._args, prefetch,
4143
self._timeout)
4244

4345
@connresource.guarded
44-
def __await__(self):
46+
def __await__(self) -> 'Cursor':
4547
if self._prefetch is not None:
4648
raise exceptions.InterfaceError(
4749
'prefetch argument can only be specified for iterable cursor')
@@ -164,11 +166,11 @@ def __init__(self, connection, query, state, args, prefetch, timeout):
164166

165167
@compat.aiter_compat
166168
@connresource.guarded
167-
def __aiter__(self):
169+
def __aiter__(self) -> 'CursorIterator':
168170
return self
169171

170172
@connresource.guarded
171-
async def __anext__(self):
173+
async def __anext__(self) -> protocol.Record:
172174
if self._state is None:
173175
self._state = await self._connection._get_statement(
174176
self._query, self._timeout, named=True)
@@ -203,7 +205,7 @@ async def _init(self, timeout):
203205
return self
204206

205207
@connresource.guarded
206-
async def fetch(self, n, *, timeout=None):
208+
async def fetch(self, n, *, timeout=None) -> typing.List[protocol.Record]:
207209
r"""Return the next *n* rows as a list of :class:`Record` objects.
208210
209211
:param float timeout: Optional timeout value in seconds.
@@ -221,7 +223,7 @@ async def fetch(self, n, *, timeout=None):
221223
return recs
222224

223225
@connresource.guarded
224-
async def fetchrow(self, *, timeout=None):
226+
async def fetchrow(self, *, timeout=None) -> protocol.Record:
225227
r"""Return the next row.
226228
227229
:param float timeout: Optional timeout value in seconds.

asyncpg/pool.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
import inspect
1111
import logging
1212
import time
13+
import typing
1314
import warnings
1415

1516
from . import connection
1617
from . import connect_utils
1718
from . import exceptions
19+
from . import protocol
1820

1921

2022
logger = logging.getLogger(__name__)
@@ -508,7 +510,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
508510
async with self.acquire() as con:
509511
return await con.execute(query, *args, timeout=timeout)
510512

511-
async def executemany(self, command: str, args, *, timeout: float=None):
513+
async def executemany(self, command: str, args, *, timeout: float=None) -> None:
512514
"""Execute an SQL *command* for each sequence of arguments in *args*.
513515
514516
Pool performs this operation using one of its connections. Other than
@@ -520,7 +522,7 @@ async def executemany(self, command: str, args, *, timeout: float=None):
520522
async with self.acquire() as con:
521523
return await con.executemany(command, args, timeout=timeout)
522524

523-
async def fetch(self, query, *args, timeout=None) -> list:
525+
async def fetch(self, query, *args, timeout=None) -> typing.List[protocol.Record]:
524526
"""Run a query and return the results as a list of :class:`Record`.
525527
526528
Pool performs this operation using one of its connections. Other than
@@ -532,7 +534,7 @@ async def fetch(self, query, *args, timeout=None) -> list:
532534
async with self.acquire() as con:
533535
return await con.fetch(query, *args, timeout=timeout)
534536

535-
async def fetchval(self, query, *args, column=0, timeout=None):
537+
async def fetchval(self, query, *args, column=0, timeout=None) -> typing.Any:
536538
"""Run a query and return a value in the first row.
537539
538540
Pool performs this operation using one of its connections. Other than
@@ -545,7 +547,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
545547
return await con.fetchval(
546548
query, *args, column=column, timeout=timeout)
547549

548-
async def fetchrow(self, query, *args, timeout=None):
550+
async def fetchrow(self, query, *args, timeout=None) -> typing.Optional[protocol.Record]:
549551
"""Run a query and return the first row.
550552
551553
Pool performs this operation using one of its connections. Other than
@@ -557,7 +559,7 @@ async def fetchrow(self, query, *args, timeout=None):
557559
async with self.acquire() as con:
558560
return await con.fetchrow(query, *args, timeout=timeout)
559561

560-
def acquire(self, *, timeout=None):
562+
def acquire(self, *, timeout=None) -> connection.Connection:
561563
"""Acquire a database connection from the pool.
562564
563565
:param float timeout: A timeout for acquiring a Connection.
@@ -784,7 +786,7 @@ def create_pool(dsn=None, *,
784786
init=None,
785787
loop=None,
786788
connection_class=connection.Connection,
787-
**connect_kwargs):
789+
**connect_kwargs) -> Pool:
788790
r"""Create a connection pool.
789791
790792
Can be used either with an ``async with`` block:

asyncpg/prepared_stmt.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66

77

88
import json
9+
import typing
910

1011
from . import connresource
1112
from . import cursor
1213
from . import exceptions
14+
from . import protocol
15+
from . import types
1316

1417

1518
class PreparedStatement(connresource.ConnectionResource):
@@ -50,7 +53,7 @@ def get_statusmsg(self) -> str:
5053
return self._last_status.decode()
5154

5255
@connresource.guarded
53-
def get_parameters(self):
56+
def get_parameters(self) -> typing.Tuple[types.Type, ...]:
5457
"""Return a description of statement parameters types.
5558
5659
:return: A tuple of :class:`asyncpg.types.Type`.
@@ -67,7 +70,7 @@ def get_parameters(self):
6770
return self._state._get_parameters()
6871

6972
@connresource.guarded
70-
def get_attributes(self):
73+
def get_attributes(self) -> typing.Tuple[types.Attribute, ...]:
7174
"""Return a description of relation attributes (columns).
7275
7376
:return: A tuple of :class:`asyncpg.types.Attribute`.
@@ -108,7 +111,7 @@ def cursor(self, *args, prefetch=None,
108111
timeout)
109112

110113
@connresource.guarded
111-
async def explain(self, *args, analyze=False):
114+
async def explain(self, *args, analyze=False) -> typing.Dict:
112115
"""Return the execution plan of the statement.
113116
114117
:param args: Query arguments.
@@ -150,7 +153,7 @@ async def explain(self, *args, analyze=False):
150153
return json.loads(data)
151154

152155
@connresource.guarded
153-
async def fetch(self, *args, timeout=None):
156+
async def fetch(self, *args, timeout=None) -> typing.List[protocol.Record]:
154157
r"""Execute the statement and return a list of :class:`Record` objects.
155158
156159
:param str query: Query text
@@ -163,7 +166,7 @@ async def fetch(self, *args, timeout=None):
163166
return data
164167

165168
@connresource.guarded
166-
async def fetchval(self, *args, column=0, timeout=None):
169+
async def fetchval(self, *args, column=0, timeout=None) -> typing.Any:
167170
"""Execute the statement and return a value in the first row.
168171
169172
:param args: Query arguments.
@@ -182,7 +185,7 @@ async def fetchval(self, *args, column=0, timeout=None):
182185
return data[0][column]
183186

184187
@connresource.guarded
185-
async def fetchrow(self, *args, timeout=None):
188+
async def fetchrow(self, *args, timeout=None) -> typing.Optional[protocol.Record]:
186189
"""Execute the statement and return the first row.
187190
188191
:param str query: Query text

asyncpg/protocol/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
from .protocol import Protocol, Record, NO_TIMEOUT # NOQA
8+
from .protocol import Protocol, Record, ConnectionSettings, NO_TIMEOUT # NOQA

0 commit comments

Comments
 (0)