Skip to content

Commit 292290d

Browse files
committed
Clean up types
1 parent 73c50c4 commit 292290d

File tree

8 files changed

+179
-75
lines changed

8 files changed

+179
-75
lines changed

asyncpg/cluster.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ class ClusterError(Exception):
6868

6969

7070
class Cluster:
71+
_data_dir: str
72+
_pg_config_path: typing.Optional[str]
73+
_pg_bin_dir: typing.Optional[str]
74+
_pg_ctl: typing.Optional[str]
75+
_daemon_pid: typing.Optional[int]
76+
_daemon_process: typing.Optional['subprocess.Popen[bytes]']
77+
_connection_addr: typing.Optional[_ConnectionSpec]
78+
_connection_spec_override: typing.Optional[_ConnectionSpec]
79+
7180
def __init__(self, data_dir: str, *,
7281
pg_config_path: typing.Optional[str] = None) -> None:
7382
self._data_dir = data_dir
@@ -76,11 +85,11 @@ def __init__(self, data_dir: str, *,
7685
os.environ.get('PGINSTALLATION')
7786
or os.environ.get('PGBIN')
7887
)
79-
self._pg_ctl: typing.Optional[str] = None
80-
self._daemon_pid: typing.Optional[int] = None
81-
self._daemon_process: typing.Optional[subprocess.Popen[bytes]] = None
82-
self._connection_addr: typing.Optional[_ConnectionSpec] = None
83-
self._connection_spec_override: typing.Optional[_ConnectionSpec] = None
88+
self._pg_ctl = None
89+
self._daemon_pid = None
90+
self._daemon_process = None
91+
self._connection_addr = None
92+
self._connection_spec_override = None
8493

8594
def get_pg_version(self) -> 'types.ServerVersion':
8695
return self._pg_version
@@ -653,6 +662,9 @@ def __init__(self, *,
653662

654663

655664
class HotStandbyCluster(TempCluster):
665+
_master: _ConnectionSpec
666+
_repl_user: str
667+
656668
def __init__(self, *,
657669
master: _ConnectionSpec, replication_user: str,
658670
data_dir_suffix: typing.Optional[str] = None,
@@ -739,16 +751,16 @@ def get_status(self) -> str:
739751
return 'running'
740752

741753
def init(self, **settings: str) -> str:
742-
pass
754+
...
743755

744756
def start(self, wait: int = 60, **settings: typing.Any) -> None:
745-
pass
757+
...
746758

747759
def stop(self, wait: int = 60) -> None:
748-
pass
760+
...
749761

750762
def destroy(self) -> None:
751-
pass
763+
...
752764

753765
def reset_hba(self) -> None:
754766
raise ClusterError('cannot modify HBA records of unmanaged cluster')

asyncpg/connection.py

+55-24
Original file line numberDiff line numberDiff line change
@@ -157,16 +157,41 @@ class Connection(typing.Generic[_Record], metaclass=ConnectionMeta):
157157
'_log_listeners', '_termination_listeners', '_cancellations',
158158
'_source_traceback', '__weakref__')
159159

160+
_protocol: '_cprotocol.BaseProtocol[_Record]'
161+
_transport: typing.Any
162+
_loop: asyncio.AbstractEventLoop
163+
_top_xact: typing.Optional[transaction.Transaction]
164+
_aborted: bool
165+
_pool_release_ctr: int
166+
_stmt_cache: '_StatementCache'
167+
_stmts_to_close: typing.Set[
168+
'_cprotocol.PreparedStatementState[typing.Any]'
169+
]
170+
_listeners: typing.Dict[str, typing.Set['_Callback']]
171+
_server_version: types.ServerVersion
172+
_server_caps: 'ServerCapabilities'
173+
_intro_query: str
174+
_reset_query: typing.Optional[str]
175+
_proxy: typing.Optional['_pool.PoolConnectionProxy[typing.Any]']
176+
_stmt_exclusive_section: '_Atomic'
177+
_config: connect_utils._ClientConfiguration
178+
_params: connect_utils._ConnectionParameters
179+
_addr: typing.Union[typing.Tuple[str, int], str]
180+
_log_listeners: typing.Set['_Callback']
181+
_termination_listeners: typing.Set['_Callback']
182+
_cancellations: typing.Set['asyncio.Task[typing.Any]']
183+
_source_traceback: typing.Optional[str]
184+
160185
def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
161186
transport: typing.Any,
162187
loop: asyncio.AbstractEventLoop,
163188
addr: typing.Union[typing.Tuple[str, int], str],
164189
config: connect_utils._ClientConfiguration,
165190
params: connect_utils._ConnectionParameters) -> None:
166-
self._protocol: '_cprotocol.BaseProtocol[_Record]' = protocol
191+
self._protocol = protocol
167192
self._transport = transport
168193
self._loop = loop
169-
self._top_xact: typing.Optional[transaction.Transaction] = None
194+
self._top_xact = None
170195
self._aborted = False
171196
# Incremented every time the connection is released back to a pool.
172197
# Used to catch invalid references to connection-related resources
@@ -184,14 +209,12 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
184209
_weak_maybe_gc_stmt, weakref.ref(self)),
185210
max_lifetime=config.max_cached_statement_lifetime)
186211

187-
self._stmts_to_close: typing.Set[
188-
'_cprotocol.PreparedStatementState[typing.Any]'
189-
] = set()
212+
self._stmts_to_close = set()
190213

191-
self._listeners: typing.Dict[str, typing.Set[_Callback]] = {}
192-
self._log_listeners: typing.Set[_Callback] = set()
193-
self._cancellations: typing.Set[asyncio.Task[typing.Any]] = set()
194-
self._termination_listeners: typing.Set[_Callback] = set()
214+
self._listeners = {}
215+
self._log_listeners = set()
216+
self._cancellations = set()
217+
self._termination_listeners = set()
195218

196219
settings = self._protocol.get_settings()
197220
ver_string = settings.server_version
@@ -206,10 +229,8 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
206229
else:
207230
self._intro_query = introspection.INTRO_LOOKUP_TYPES
208231

209-
self._reset_query: typing.Optional[str] = None
210-
self._proxy: typing.Optional[
211-
'_pool.PoolConnectionProxy[typing.Any]'
212-
] = None
232+
self._reset_query = None
233+
self._proxy = None
213234

214235
# Used to serialize operations that might involve anonymous
215236
# statements. Specifically, we want to make the following
@@ -221,7 +242,7 @@ def __init__(self, protocol: '_cprotocol.BaseProtocol[_Record]',
221242
self._stmt_exclusive_section = _Atomic()
222243

223244
if loop.get_debug():
224-
self._source_traceback: typing.Optional[str] = _extract_stack()
245+
self._source_traceback = _extract_stack()
225246
else:
226247
self._source_traceback = None
227248

@@ -2007,7 +2028,7 @@ def _set_proxy(
20072028
self._proxy = proxy
20082029

20092030
def _check_listeners(self,
2010-
listeners: 'typing.Sized',
2031+
listeners: typing.Sized,
20112032
listener_type: str) -> None:
20122033
if listeners:
20132034
count = len(listeners)
@@ -2927,6 +2948,11 @@ class _StatementCacheEntry(typing.Generic[_Record]):
29272948

29282949
__slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')
29292950

2951+
_query: _StatementCacheKey[_Record]
2952+
_statement: '_cprotocol.PreparedStatementState[_Record]'
2953+
_cache: '_StatementCache'
2954+
_cleanup_cb: typing.Optional[asyncio.TimerHandle]
2955+
29302956
def __init__(
29312957
self,
29322958
cache: '_StatementCache',
@@ -2936,21 +2962,27 @@ def __init__(
29362962
self._cache = cache
29372963
self._query = query
29382964
self._statement = statement
2939-
self._cleanup_cb: typing.Optional[asyncio.TimerHandle] = None
2965+
self._cleanup_cb = None
29402966

29412967

29422968
class _StatementCache:
29432969

29442970
__slots__ = ('_loop', '_entries', '_max_size', '_on_remove',
29452971
'_max_lifetime')
29462972

2973+
_loop: asyncio.AbstractEventLoop
2974+
_entries: 'collections.OrderedDict[_StatementCacheKey[typing.Any], _StatementCacheEntry[typing.Any]]' # noqa: E501
2975+
_max_size: int
2976+
_on_remove: OnRemove[typing.Any]
2977+
_max_lifetime: float
2978+
29472979
def __init__(self, *, loop: asyncio.AbstractEventLoop,
29482980
max_size: int, on_remove: OnRemove[typing.Any],
29492981
max_lifetime: float) -> None:
2950-
self._loop: asyncio.AbstractEventLoop = loop
2951-
self._max_size: int = max_size
2952-
self._on_remove: OnRemove[typing.Any] = on_remove
2953-
self._max_lifetime: float = max_lifetime
2982+
self._loop = loop
2983+
self._max_size = max_size
2984+
self._on_remove = on_remove
2985+
self._max_lifetime = max_lifetime
29542986

29552987
# We use an OrderedDict for LRU implementation. Operations:
29562988
#
@@ -2969,10 +3001,7 @@ def __init__(self, *, loop: asyncio.AbstractEventLoop,
29693001
# So new entries and hits are always promoted to the end of the
29703002
# entries dict, whereas the unused one will group in the
29713003
# beginning of it.
2972-
self._entries: collections.OrderedDict[
2973-
_StatementCacheKey[typing.Any],
2974-
_StatementCacheEntry[typing.Any]
2975-
] = collections.OrderedDict()
3004+
self._entries = collections.OrderedDict()
29763005

29773006
def __len__(self) -> int:
29783007
return len(self._entries)
@@ -3148,6 +3177,8 @@ def from_callable(
31483177
class _Atomic:
31493178
__slots__ = ('_acquired',)
31503179

3180+
_acquired: int
3181+
31513182
def __init__(self) -> None:
31523183
self._acquired = 0
31533184

asyncpg/connresource.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
if typing.TYPE_CHECKING:
16-
from . import connection as _connection
16+
from . import connection as _conn
1717

1818

1919
_Callable = typing.TypeVar('_Callable', bound=typing.Callable[..., typing.Any])
@@ -35,8 +35,11 @@ def _check(self: 'ConnectionResource',
3535
class ConnectionResource:
3636
__slots__ = ('_connection', '_con_release_ctr')
3737

38+
_connection: '_conn.Connection[typing.Any]'
39+
_con_release_ctr: int
40+
3841
def __init__(
39-
self, connection: '_connection.Connection[typing.Any]'
42+
self, connection: '_conn.Connection[typing.Any]'
4043
) -> None:
4144
self._connection = connection
4245
self._con_release_ctr = connection._pool_release_ctr

asyncpg/cursor.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,19 @@ class CursorFactory(connresource.ConnectionResource, typing.Generic[_Record]):
3939
'_record_class',
4040
)
4141

42+
_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
43+
_args: typing.Sequence[typing.Any]
44+
_prefetch: typing.Optional[int]
45+
_query: str
46+
_timeout: typing.Optional[float]
47+
_record_class: typing.Optional[typing.Type[_Record]]
48+
4249
@typing.overload
4350
def __init__(
4451
self: 'CursorFactory[_Record]',
4552
connection: '_connection.Connection[_Record]',
4653
query: str,
47-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
54+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
4855
args: typing.Sequence[typing.Any],
4956
prefetch: typing.Optional[int],
5057
timeout: typing.Optional[float],
@@ -57,7 +64,7 @@ def __init__(
5764
self: 'CursorFactory[_Record]',
5865
connection: '_connection.Connection[typing.Any]',
5966
query: str,
60-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
67+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
6168
args: typing.Sequence[typing.Any],
6269
prefetch: typing.Optional[int],
6370
timeout: typing.Optional[float],
@@ -69,7 +76,7 @@ def __init__(
6976
self,
7077
connection: '_connection.Connection[typing.Any]',
7178
query: str,
72-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
79+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
7380
args: typing.Sequence[typing.Any],
7481
prefetch: typing.Optional[int],
7582
timeout: typing.Optional[float],
@@ -130,12 +137,19 @@ class BaseCursor(connresource.ConnectionResource, typing.Generic[_Record]):
130137
'_record_class',
131138
)
132139

140+
_state: typing.Optional['_cprotocol.PreparedStatementState[_Record]']
141+
_args: typing.Sequence[typing.Any]
142+
_portal_name: typing.Optional[str]
143+
_exhausted: bool
144+
_query: str
145+
_record_class: typing.Optional[typing.Type[_Record]]
146+
133147
@typing.overload
134148
def __init__(
135149
self: 'BaseCursor[_Record]',
136150
connection: '_connection.Connection[_Record]',
137151
query: str,
138-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
152+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
139153
args: typing.Sequence[typing.Any],
140154
record_class: None
141155
) -> None:
@@ -146,7 +160,7 @@ def __init__(
146160
self: 'BaseCursor[_Record]',
147161
connection: '_connection.Connection[typing.Any]',
148162
query: str,
149-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
163+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
150164
args: typing.Sequence[typing.Any],
151165
record_class: typing.Type[_Record]
152166
) -> None:
@@ -156,7 +170,7 @@ def __init__(
156170
self,
157171
connection: '_connection.Connection[typing.Any]',
158172
query: str,
159-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
173+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
160174
args: typing.Sequence[typing.Any],
161175
record_class: typing.Optional[typing.Type[_Record]]
162176
) -> None:
@@ -165,7 +179,7 @@ def __init__(
165179
self._state = state
166180
if state is not None:
167181
state.attach()
168-
self._portal_name: typing.Optional[str] = None
182+
self._portal_name = None
169183
self._exhausted = False
170184
self._query = query
171185
self._record_class = record_class
@@ -260,12 +274,16 @@ class CursorIterator(BaseCursor[_Record]):
260274

261275
__slots__ = ('_buffer', '_prefetch', '_timeout')
262276

277+
_buffer: typing.Deque[_Record]
278+
_prefetch: int
279+
_timeout: typing.Optional[float]
280+
263281
@typing.overload
264282
def __init__(
265283
self: 'CursorIterator[_Record]',
266284
connection: '_connection.Connection[_Record]',
267285
query: str,
268-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
286+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
269287
args: typing.Sequence[typing.Any],
270288
record_class: None,
271289
prefetch: int,
@@ -278,7 +296,7 @@ def __init__(
278296
self: 'CursorIterator[_Record]',
279297
connection: '_connection.Connection[typing.Any]',
280298
query: str,
281-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
299+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
282300
args: typing.Sequence[typing.Any],
283301
record_class: typing.Type[_Record],
284302
prefetch: int,
@@ -290,7 +308,7 @@ def __init__(
290308
self,
291309
connection: '_connection.Connection[typing.Any]',
292310
query: str,
293-
state: 'typing.Optional[_cprotocol.PreparedStatementState[_Record]]',
311+
state: typing.Optional['_cprotocol.PreparedStatementState[_Record]'],
294312
args: typing.Sequence[typing.Any],
295313
record_class: typing.Optional[typing.Type[_Record]],
296314
prefetch: int,
@@ -302,7 +320,7 @@ def __init__(
302320
raise exceptions.InterfaceError(
303321
'prefetch argument must be greater than zero')
304322

305-
self._buffer: typing.Deque[_Record] = collections.deque()
323+
self._buffer = collections.deque()
306324
self._prefetch = prefetch
307325
self._timeout = timeout
308326

0 commit comments

Comments
 (0)