Skip to content

Commit 8951230

Browse files
committed
Improve typing for connect(port=)
1 parent 96c0a4c commit 8951230

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

asyncpg/cluster.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,14 @@ def _test_connection(self, timeout: int = 60) -> str:
543543
try:
544544
con: 'connection.Connection[typing.Any]' = \
545545
loop.run_until_complete(
546-
asyncpg.connect( # type: ignore[arg-type] # noqa: E501
546+
asyncpg.connect(
547547
database='postgres',
548548
user='postgres',
549549
timeout=5, loop=loop,
550-
**self._connection_addr
550+
**typing.cast(
551+
_ConnectionSpec,
552+
self._connection_addr
553+
)
551554
)
552555
)
553556
except (OSError, asyncio.TimeoutError,

asyncpg/connect_utils.py

+38-16
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,16 @@
5858
]
5959
SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool]
6060
HostType = typing.Union[typing.List[str], str]
61-
PortType = typing.Union[typing.List[int], int]
61+
PortListType = typing.Union[
62+
typing.List[typing.Union[int, str]],
63+
typing.List[int],
64+
typing.List[str],
65+
]
66+
PortType = typing.Union[
67+
PortListType,
68+
int,
69+
str
70+
]
6271

6372

6473
class SSLMode(enum.IntEnum):
@@ -192,26 +201,42 @@ def _read_password_from_pgpass(
192201
return None
193202

194203

195-
def _validate_port_spec(hosts: typing.List[str],
196-
port: PortType) \
197-
-> typing.List[int]:
204+
@typing.overload
205+
def _validate_port_spec(
206+
hosts: typing.List[str],
207+
port: PortListType
208+
) -> typing.List[int]:
209+
...
210+
211+
212+
@typing.overload
213+
def _validate_port_spec(
214+
hosts: typing.List[str],
215+
port: typing.Union[int, str]
216+
) -> typing.List[int]:
217+
...
218+
219+
220+
def _validate_port_spec(
221+
hosts: typing.List[str],
222+
port: PortType
223+
) -> typing.List[int]:
198224
if isinstance(port, list):
199225
# If there is a list of ports, its length must
200226
# match that of the host list.
201227
if len(port) != len(hosts):
202228
raise exceptions.InterfaceError(
203229
'could not match {} port numbers to {} hosts'.format(
204230
len(port), len(hosts)))
231+
return [int(p) for p in port]
205232
else:
206-
port = [port for _ in range(len(hosts))]
207-
208-
return port
233+
return [int(port) for _ in range(len(hosts))]
209234

210235

211236
def _parse_hostlist(hostlist: str,
212237
port: typing.Optional[PortType],
213238
*, unquote: bool = False) \
214-
-> typing.Tuple[typing.List[str], typing.List[int]]:
239+
-> typing.Tuple[typing.List[str], PortListType]:
215240
if ',' in hostlist:
216241
# A comma-separated list of host addresses.
217242
hostspecs = hostlist.split(',')
@@ -242,7 +267,7 @@ def _parse_hostlist(hostlist: str,
242267
if hostspec[0] == '/':
243268
# Unix socket
244269
addr = hostspec
245-
hostspec_port = ''
270+
hostspec_port: str = ''
246271
elif hostspec[0] == '[':
247272
# IPv6 address
248273
m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
@@ -470,13 +495,10 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
470495
else:
471496
port = 5432
472497

473-
elif isinstance(port, (list, tuple)):
474-
port = [int(p) for p in port]
475-
476-
else:
498+
elif not isinstance(port, (list, tuple)):
477499
port = int(port)
478500

479-
port = _validate_port_spec(host, port)
501+
validated_ports = _validate_port_spec(host, port)
480502

481503
if user is None:
482504
user = os.getenv('PGUSER')
@@ -517,13 +539,13 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
517539

518540
if passfile_path is not None:
519541
password = _read_password_from_pgpass(
520-
hosts=auth_hosts, ports=port,
542+
hosts=auth_hosts, ports=validated_ports,
521543
database=database, user=user,
522544
passfile=passfile_path)
523545

524546
addrs: typing.List[AddrType] = []
525547
have_tcp_addrs = False
526-
for h, p in zip(host, port):
548+
for h, p in zip(host, validated_ports):
527549
if h.startswith('/'):
528550
# UNIX socket name
529551
if '.s.PGSQL.' not in h:

0 commit comments

Comments
 (0)