Skip to content

Commit 96c0a4c

Browse files
committed
Tweaks for pyright and mypy
1 parent 292290d commit 96c0a4c

11 files changed

+272
-341
lines changed

asyncpg/cluster.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from asyncpg import exceptions
3232

3333
if typing.TYPE_CHECKING:
34+
import _typeshed
3435
from . import connection
3536
from . import types
3637

@@ -652,8 +653,7 @@ class TempCluster(Cluster):
652653
def __init__(self, *,
653654
data_dir_suffix: typing.Optional[str] = None,
654655
data_dir_prefix: typing.Optional[str] = None,
655-
data_dir_parent: typing.Optional[
656-
'tempfile._DirT[str]'] = None,
656+
data_dir_parent: typing.Optional['_typeshed.StrPath'] = None,
657657
pg_config_path: typing.Optional[str] = None) -> None:
658658
self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
659659
prefix=data_dir_prefix,
@@ -669,8 +669,7 @@ def __init__(self, *,
669669
master: _ConnectionSpec, replication_user: str,
670670
data_dir_suffix: typing.Optional[str] = None,
671671
data_dir_prefix: typing.Optional[str] = None,
672-
data_dir_parent: typing.Optional[
673-
'tempfile._DirT[str]'] = None,
672+
data_dir_parent: typing.Optional['_typeshed.StrPath'] = None,
674673
pg_config_path: typing.Optional[str] = None) -> None:
675674
self._master = master
676675
self._repl_user = replication_user

asyncpg/connect_utils.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import enum
1010
import functools
1111
import getpass
12+
import inspect
1213
import os
1314
import pathlib
1415
import platform
@@ -268,7 +269,9 @@ def _parse_hostlist(hostlist: str,
268269
hostspec_port = urllib.parse.unquote(hostspec_port)
269270
hostlist_ports.append(int(hostspec_port))
270271
else:
271-
hostlist_ports.append(default_port[i])
272+
hostlist_ports.append(
273+
default_port[i] # pyright: ignore [reportUnknownArgumentType, reportUnboundVariable] # noqa: E501
274+
)
272275

273276
if not ports:
274277
ports = hostlist_ports
@@ -892,7 +895,7 @@ async def _connect_addr(
892895
if inspect.isawaitable(password):
893896
password = await password
894897

895-
params = params._replace(password=password)
898+
params = params._replace(password=typing.cast(str, password))
896899
args = (addr, loop, config, connection_class, record_class, params_input)
897900

898901
# prepare the params (which attempt has ssl) for the 2 attempts
@@ -954,8 +957,13 @@ async def __connect_addr(
954957
elif params.ssl and params.direct_tls:
955958
# if ssl and direct_tls are given, skip STARTTLS and perform direct
956959
# SSL connection
957-
connector = loop.create_connection(
958-
proto_factory, *addr, ssl=params.ssl
960+
connector = typing.cast(
961+
typing.Coroutine[
962+
typing.Any, None, _TPTupleType['protocol.Protocol[_Record]']
963+
],
964+
loop.create_connection(
965+
proto_factory, *addr, ssl=params.ssl
966+
)
959967
)
960968

961969
elif params.ssl:

0 commit comments

Comments
 (0)