Skip to content

Commit 075114c

Browse files
authored
Add sslmode=allow support and fix =prefer retry (#720)
We didn't really retry the connection without SSL if the first SSL connection fails under sslmode=prefer, that led to an issue when the server has SSL support but explicitly denies SSL connection through pg_hba.conf. This commit adds a retry in a new connection, which makes it easy to implement the sslmode=allow retry. Fixes #716
1 parent 93a238c commit 075114c

File tree

5 files changed

+314
-57
lines changed

5 files changed

+314
-57
lines changed

asyncpg/connect_utils.py

+110-38
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import collections
10+
import enum
1011
import functools
1112
import getpass
1213
import os
@@ -28,14 +29,29 @@
2829
from . import protocol
2930

3031

32+
class SSLMode(enum.IntEnum):
33+
disable = 0
34+
allow = 1
35+
prefer = 2
36+
require = 3
37+
verify_ca = 4
38+
verify_full = 5
39+
40+
@classmethod
41+
def parse(cls, sslmode):
42+
if isinstance(sslmode, cls):
43+
return sslmode
44+
return getattr(cls, sslmode.replace('-', '_'))
45+
46+
3147
_ConnectionParameters = collections.namedtuple(
3248
'ConnectionParameters',
3349
[
3450
'user',
3551
'password',
3652
'database',
3753
'ssl',
38-
'ssl_is_advisory',
54+
'sslmode',
3955
'connect_timeout',
4056
'server_settings',
4157
])
@@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402418
if ssl is None and have_tcp_addrs:
403419
ssl = 'prefer'
404420

405-
# ssl_is_advisory is only allowed to come from the sslmode parameter.
406-
ssl_is_advisory = None
407-
if isinstance(ssl, str):
408-
SSLMODES = {
409-
'disable': 0,
410-
'allow': 1,
411-
'prefer': 2,
412-
'require': 3,
413-
'verify-ca': 4,
414-
'verify-full': 5,
415-
}
421+
if isinstance(ssl, (str, SSLMode)):
416422
try:
417-
sslmode = SSLMODES[ssl]
418-
except KeyError:
419-
modes = ', '.join(SSLMODES.keys())
423+
sslmode = SSLMode.parse(ssl)
424+
except AttributeError:
425+
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
420426
raise exceptions.InterfaceError(
421427
'`sslmode` parameter must be one of: {}'.format(modes))
422428

423-
# sslmode 'allow' is currently handled as 'prefer' because we're
424-
# missing the "retry with SSL" behavior for 'allow', but do have the
425-
# "retry without SSL" behavior for 'prefer'.
426-
# Not changing 'allow' to 'prefer' here would be effectively the same
427-
# as changing 'allow' to 'disable'.
428-
if sslmode == SSLMODES['allow']:
429-
sslmode = SSLMODES['prefer']
430-
431429
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432430
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433-
if sslmode <= SSLMODES['allow']:
431+
if sslmode < SSLMode.allow:
434432
ssl = False
435-
ssl_is_advisory = sslmode >= SSLMODES['allow']
436433
else:
437434
ssl = ssl_module.create_default_context()
438-
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
435+
ssl.check_hostname = sslmode >= SSLMode.verify_full
439436
ssl.verify_mode = ssl_module.CERT_REQUIRED
440-
if sslmode <= SSLMODES['require']:
437+
if sslmode <= SSLMode.require:
441438
ssl.verify_mode = ssl_module.CERT_NONE
442-
ssl_is_advisory = sslmode <= SSLMODES['prefer']
443439
elif ssl is True:
444440
ssl = ssl_module.create_default_context()
441+
sslmode = SSLMode.verify_full
442+
else:
443+
sslmode = SSLMode.disable
445444

446445
if server_settings is not None and (
447446
not isinstance(server_settings, dict) or
@@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453452

454453
params = _ConnectionParameters(
455454
user=user, password=password, database=database, ssl=ssl,
456-
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
455+
sslmode=sslmode, connect_timeout=connect_timeout,
457456
server_settings=server_settings)
458457

459458
return addrs, params
@@ -520,9 +519,8 @@ def data_received(self, data):
520519
data == b'N'):
521520
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522521
# since the only way to get ssl_is_advisory is from
523-
# sslmode=prefer (or sslmode=allow). But be extra sure to
524-
# disallow insecure connections when the ssl context asks for
525-
# real security.
522+
# sslmode=prefer. But be extra sure to disallow insecure
523+
# connections when the ssl context asks for real security.
526524
self.on_data.set_result(False)
527525
else:
528526
self.on_data.set_exception(
@@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566564
new_tr = tr
567565

568566
pg_proto = protocol_factory()
567+
pg_proto.is_ssl = do_ssl_upgrade
569568
pg_proto.connection_made(new_tr)
570569
new_tr.set_protocol(pg_proto)
571570

@@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584583
tr.close()
585584

586585
try:
587-
return await conn_factory(sock=sock)
586+
new_tr, pg_proto = await conn_factory(sock=sock)
587+
pg_proto.is_ssl = do_ssl_upgrade
588+
return new_tr, pg_proto
588589
except (Exception, asyncio.CancelledError):
589590
sock.close()
590591
raise
@@ -605,8 +606,6 @@ async def _connect_addr(
605606
if timeout <= 0:
606607
raise asyncio.TimeoutError
607608

608-
connected = _create_future(loop)
609-
610609
params_input = params
611610
if callable(params.password):
612611
if inspect.iscoroutinefunction(params.password):
@@ -615,6 +614,49 @@ async def _connect_addr(
615614
password = params.password()
616615

617616
params = params._replace(password=password)
617+
args = (addr, loop, config, connection_class, record_class, params_input)
618+
619+
# prepare the params (which attempt has ssl) for the 2 attempts
620+
if params.sslmode == SSLMode.allow:
621+
params_retry = params
622+
params = params._replace(ssl=None)
623+
elif params.sslmode == SSLMode.prefer:
624+
params_retry = params._replace(ssl=None)
625+
else:
626+
# skip retry if we don't have to
627+
return await __connect_addr(params, timeout, False, *args)
628+
629+
# first attempt
630+
before = time.monotonic()
631+
try:
632+
return await __connect_addr(params, timeout, True, *args)
633+
except _Retry:
634+
pass
635+
636+
# second attempt
637+
timeout -= time.monotonic() - before
638+
if timeout <= 0:
639+
raise asyncio.TimeoutError
640+
else:
641+
return await __connect_addr(params_retry, timeout, False, *args)
642+
643+
644+
class _Retry(Exception):
645+
pass
646+
647+
648+
async def __connect_addr(
649+
params,
650+
timeout,
651+
retry,
652+
addr,
653+
loop,
654+
config,
655+
connection_class,
656+
record_class,
657+
params_input,
658+
):
659+
connected = _create_future(loop)
618660

619661
proto_factory = lambda: protocol.Protocol(
620662
addr, connected, params, record_class, loop)
@@ -625,7 +667,7 @@ async def _connect_addr(
625667
elif params.ssl:
626668
connector = _create_ssl_connection(
627669
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
628-
ssl_is_advisory=params.ssl_is_advisory)
670+
ssl_is_advisory=params.sslmode == SSLMode.prefer)
629671
else:
630672
connector = loop.create_connection(proto_factory, *addr)
631673

@@ -638,6 +680,35 @@ async def _connect_addr(
638680
if timeout <= 0:
639681
raise asyncio.TimeoutError
640682
await compat.wait_for(connected, timeout=timeout)
683+
except (
684+
exceptions.InvalidAuthorizationSpecificationError,
685+
exceptions.ConnectionDoesNotExistError, # seen on Windows
686+
):
687+
tr.close()
688+
689+
# retry=True here is a redundant check because we don't want to
690+
# accidentally raise the internal _Retry to the outer world
691+
if retry and (
692+
params.sslmode == SSLMode.allow and not pr.is_ssl or
693+
params.sslmode == SSLMode.prefer and pr.is_ssl
694+
):
695+
# Trigger retry when:
696+
# 1. First attempt with sslmode=allow, ssl=None failed
697+
# 2. First attempt with sslmode=prefer, ssl=ctx failed while the
698+
# server claimed to support SSL (returning "S" for SSLRequest)
699+
# (likely because pg_hba.conf rejected the connection)
700+
raise _Retry()
701+
702+
else:
703+
# but will NOT retry if:
704+
# 1. First attempt with sslmode=prefer failed but the server
705+
# doesn't support SSL (returning 'N' for SSLRequest), because
706+
# we already tried to connect without SSL thru ssl_is_advisory
707+
# 2. Second attempt with sslmode=prefer, ssl=None failed
708+
# 3. Second attempt with sslmode=allow, ssl=ctx failed
709+
# 4. Any other sslmode
710+
raise
711+
641712
except (Exception, asyncio.CancelledError):
642713
tr.close()
643714
raise
@@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):
684755

685756
def __init__(self):
686757
self.on_disconnect = _create_future(loop)
758+
self.is_ssl = False
687759

688760
def connection_lost(self, exc):
689761
if not self.on_disconnect.done():
@@ -692,13 +764,13 @@ def connection_lost(self, exc):
692764
if isinstance(addr, str):
693765
tr, pr = await loop.create_unix_connection(CancelProto, addr)
694766
else:
695-
if params.ssl:
767+
if params.ssl and params.sslmode != SSLMode.allow:
696768
tr, pr = await _create_ssl_connection(
697769
CancelProto,
698770
*addr,
699771
loop=loop,
700772
ssl_context=params.ssl,
701-
ssl_is_advisory=params.ssl_is_advisory)
773+
ssl_is_advisory=params.sslmode == SSLMode.prefer)
702774
else:
703775
tr, pr = await loop.create_connection(
704776
CancelProto, *addr)

asyncpg/connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
18791879
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
18801880
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
18811881
if SSL connection fails
1882-
- ``'allow'`` - currently equivalent to ``'prefer'``
1882+
- ``'allow'`` - try without SSL first, then retry with SSL if the first
1883+
attempt fails.
18831884
- ``'require'`` - only try an SSL connection. Certificate
18841885
verification errors are ignored
18851886
- ``'verify-ca'`` - only try an SSL connection, and verify

asyncpg/protocol/protocol.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):
5252

5353
readonly uint64_t queries_count
5454

55+
bint _is_ssl
56+
5557
PreparedStatementState statement
5658

5759
cdef get_connection(self)

asyncpg/protocol/protocol.pyx

+10
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):
103103

104104
self.queries_count = 0
105105

106+
self._is_ssl = False
107+
106108
try:
107109
self.create_future = loop.create_future
108110
except AttributeError:
@@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
943945
def resume_writing(self):
944946
self.writing_allowed.set()
945947

948+
@property
949+
def is_ssl(self):
950+
return self._is_ssl
951+
952+
@is_ssl.setter
953+
def is_ssl(self, value):
954+
self._is_ssl = value
955+
946956

947957
class Timer:
948958
def __init__(self, budget):

0 commit comments

Comments
 (0)