|
58 | 58 | ]
|
59 | 59 | SSLType = typing.Union[_ParsedSSLType, SSLStringValues, bool]
|
60 | 60 | HostType = typing.Union[typing.List[str], str]
|
61 |
| -PortType = typing.Union[typing.List[int], int] |
| 61 | +PortListType = typing.Union[ |
| 62 | + typing.List[typing.Union[int, str]], |
| 63 | + typing.List[int], |
| 64 | + typing.List[str], |
| 65 | +] |
| 66 | +PortType = typing.Union[ |
| 67 | + PortListType, |
| 68 | + int, |
| 69 | + str |
| 70 | +] |
62 | 71 |
|
63 | 72 |
|
64 | 73 | class SSLMode(enum.IntEnum):
|
@@ -192,26 +201,42 @@ def _read_password_from_pgpass(
|
192 | 201 | return None
|
193 | 202 |
|
194 | 203 |
|
195 |
| -def _validate_port_spec(hosts: typing.List[str], |
196 |
| - port: PortType) \ |
197 |
| - -> typing.List[int]: |
| 204 | +@typing.overload |
| 205 | +def _validate_port_spec( |
| 206 | + hosts: typing.List[str], |
| 207 | + port: PortListType |
| 208 | +) -> typing.List[int]: |
| 209 | + ... |
| 210 | + |
| 211 | + |
| 212 | +@typing.overload |
| 213 | +def _validate_port_spec( |
| 214 | + hosts: typing.List[str], |
| 215 | + port: typing.Union[int, str] |
| 216 | +) -> typing.List[int]: |
| 217 | + ... |
| 218 | + |
| 219 | + |
| 220 | +def _validate_port_spec( |
| 221 | + hosts: typing.List[str], |
| 222 | + port: PortType |
| 223 | +) -> typing.List[int]: |
198 | 224 | if isinstance(port, list):
|
199 | 225 | # If there is a list of ports, its length must
|
200 | 226 | # match that of the host list.
|
201 | 227 | if len(port) != len(hosts):
|
202 | 228 | raise exceptions.InterfaceError(
|
203 | 229 | 'could not match {} port numbers to {} hosts'.format(
|
204 | 230 | len(port), len(hosts)))
|
| 231 | + return [int(p) for p in port] |
205 | 232 | else:
|
206 |
| - port = [port for _ in range(len(hosts))] |
207 |
| - |
208 |
| - return port |
| 233 | + return [int(port) for _ in range(len(hosts))] |
209 | 234 |
|
210 | 235 |
|
211 | 236 | def _parse_hostlist(hostlist: str,
|
212 | 237 | port: typing.Optional[PortType],
|
213 | 238 | *, unquote: bool = False) \
|
214 |
| - -> typing.Tuple[typing.List[str], typing.List[int]]: |
| 239 | + -> typing.Tuple[typing.List[str], PortListType]: |
215 | 240 | if ',' in hostlist:
|
216 | 241 | # A comma-separated list of host addresses.
|
217 | 242 | hostspecs = hostlist.split(',')
|
@@ -242,7 +267,7 @@ def _parse_hostlist(hostlist: str,
|
242 | 267 | if hostspec[0] == '/':
|
243 | 268 | # Unix socket
|
244 | 269 | addr = hostspec
|
245 |
| - hostspec_port = '' |
| 270 | + hostspec_port: str = '' |
246 | 271 | elif hostspec[0] == '[':
|
247 | 272 | # IPv6 address
|
248 | 273 | m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
|
@@ -470,13 +495,10 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
|
470 | 495 | else:
|
471 | 496 | port = 5432
|
472 | 497 |
|
473 |
| - elif isinstance(port, (list, tuple)): |
474 |
| - port = [int(p) for p in port] |
475 |
| - |
476 |
| - else: |
| 498 | + elif not isinstance(port, (list, tuple)): |
477 | 499 | port = int(port)
|
478 | 500 |
|
479 |
| - port = _validate_port_spec(host, port) |
| 501 | + validated_ports = _validate_port_spec(host, port) |
480 | 502 |
|
481 | 503 | if user is None:
|
482 | 504 | user = os.getenv('PGUSER')
|
@@ -517,13 +539,13 @@ def _parse_connect_dsn_and_args(*, dsn: typing.Optional[str],
|
517 | 539 |
|
518 | 540 | if passfile_path is not None:
|
519 | 541 | password = _read_password_from_pgpass(
|
520 |
| - hosts=auth_hosts, ports=port, |
| 542 | + hosts=auth_hosts, ports=validated_ports, |
521 | 543 | database=database, user=user,
|
522 | 544 | passfile=passfile_path)
|
523 | 545 |
|
524 | 546 | addrs: typing.List[AddrType] = []
|
525 | 547 | have_tcp_addrs = False
|
526 |
| - for h, p in zip(host, port): |
| 548 | + for h, p in zip(host, validated_ports): |
527 | 549 | if h.startswith('/'):
|
528 | 550 | # UNIX socket name
|
529 | 551 | if '.s.PGSQL.' not in h:
|
|
0 commit comments