Skip to content

Commit 77d4742

Browse files
committed
Add sslmode=allow support and fix =prefer retry
We didn't really retry the connection without SSL if the first SSL connection fails, that led to an issue when the server has SSL support but explicitly denies SSL connection through pg_hba.conf. This commit adds a retry in a new connection, which makes it easy to implement the sslmode=allow retry. Fixes MagicStack#716
1 parent 53bea98 commit 77d4742

File tree

5 files changed

+242
-49
lines changed

5 files changed

+242
-49
lines changed

asyncpg/connect_utils.py

+100-30
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'password',
3636
'database',
3737
'ssl',
38-
'ssl_is_advisory',
38+
'alt_retry_ssl_first',
3939
'connect_timeout',
4040
'server_settings',
4141
])
@@ -402,8 +402,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402402
if ssl is None and have_tcp_addrs:
403403
ssl = 'prefer'
404404

405-
# ssl_is_advisory is only allowed to come from the sslmode parameter.
406-
ssl_is_advisory = None
405+
# alt_retry_ssl_first is particularly for "allow" and "prefer"
406+
# to alternatively try SSL/non-SSL connections (once each if supported):
407+
# False - allow (try non-SSL first)
408+
# True - prefer (try SSL first)
409+
# None - other (don't retry, stick with the "ssl" parameter)
410+
alt_retry_ssl_first = None
411+
407412
if isinstance(ssl, str):
408413
SSLMODES = {
409414
'disable': 0,
@@ -420,26 +425,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
420425
raise exceptions.InterfaceError(
421426
'`sslmode` parameter must be one of: {}'.format(modes))
422427

423-
# sslmode 'allow' is currently handled as 'prefer' because we're
424-
# missing the "retry with SSL" behavior for 'allow', but do have the
425-
# "retry without SSL" behavior for 'prefer'.
426-
# Not changing 'allow' to 'prefer' here would be effectively the same
427-
# as changing 'allow' to 'disable'.
428428
if sslmode == SSLMODES['allow']:
429-
sslmode = SSLMODES['prefer']
429+
alt_retry_ssl_first = False
430+
elif sslmode == SSLMODES['prefer']:
431+
alt_retry_ssl_first = True
430432

431433
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432434
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433-
if sslmode <= SSLMODES['allow']:
435+
if sslmode < SSLMODES['allow']:
434436
ssl = False
435-
ssl_is_advisory = sslmode >= SSLMODES['allow']
436437
else:
437438
ssl = ssl_module.create_default_context()
438439
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
439440
ssl.verify_mode = ssl_module.CERT_REQUIRED
440441
if sslmode <= SSLMODES['require']:
441442
ssl.verify_mode = ssl_module.CERT_NONE
442-
ssl_is_advisory = sslmode <= SSLMODES['prefer']
443443
elif ssl is True:
444444
ssl = ssl_module.create_default_context()
445445

@@ -453,7 +453,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453453

454454
params = _ConnectionParameters(
455455
user=user, password=password, database=database, ssl=ssl,
456-
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
456+
alt_retry_ssl_first=alt_retry_ssl_first,
457+
connect_timeout=connect_timeout,
457458
server_settings=server_settings)
458459

459460
return addrs, params
@@ -520,9 +521,8 @@ def data_received(self, data):
520521
data == b'N'):
521522
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522523
# since the only way to get ssl_is_advisory is from
523-
# sslmode=prefer (or sslmode=allow). But be extra sure to
524-
# disallow insecure connections when the ssl context asks for
525-
# real security.
524+
# sslmode=prefer. But be extra sure to disallow insecure
525+
# connections when the ssl context asks for real security.
526526
self.on_data.set_result(False)
527527
else:
528528
self.on_data.set_exception(
@@ -566,6 +566,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566566
new_tr = tr
567567

568568
pg_proto = protocol_factory()
569+
pg_proto.is_ssl = do_ssl_upgrade
569570
pg_proto.connection_made(new_tr)
570571
new_tr.set_protocol(pg_proto)
571572

@@ -584,7 +585,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584585
tr.close()
585586

586587
try:
587-
return await conn_factory(sock=sock)
588+
new_tr, pg_proto = await conn_factory(sock=sock)
589+
pg_proto.is_ssl = do_ssl_upgrade
590+
return new_tr, pg_proto
588591
except (Exception, asyncio.CancelledError):
589592
sock.close()
590593
raise
@@ -605,8 +608,6 @@ async def _connect_addr(
605608
if timeout <= 0:
606609
raise asyncio.TimeoutError
607610

608-
connected = _create_future(loop)
609-
610611
params_input = params
611612
if callable(params.password):
612613
if inspect.iscoroutinefunction(params.password):
@@ -615,6 +616,44 @@ async def _connect_addr(
615616
password = params.password()
616617

617618
params = params._replace(password=password)
619+
args = (addr, loop, config, connection_class, record_class, params_input)
620+
621+
# skip retry if alt_retry is not enabled
622+
if params.alt_retry_ssl_first is None:
623+
return await __connect_addr(params, timeout, *args)
624+
625+
# prepare the params (which attempt has ssl) for the 2 attempts
626+
params_retry = params._replace(ssl=None)
627+
if not params.alt_retry_ssl_first:
628+
params, params_retry = params_retry, params
629+
630+
# first attempt
631+
before = time.monotonic()
632+
try:
633+
return await __connect_addr(params, timeout, *args)
634+
except ConnectionError:
635+
pass
636+
637+
# the second attempt with alt_retry_ssl_first=None
638+
timeout -= time.monotonic() - before
639+
if timeout <= 0:
640+
raise asyncio.TimeoutError
641+
else:
642+
params_retry = params_retry._replace(alt_retry_ssl_first=None)
643+
return await __connect_addr(params_retry, timeout, *args)
644+
645+
646+
async def __connect_addr(
647+
params,
648+
timeout,
649+
addr,
650+
loop,
651+
config,
652+
connection_class,
653+
record_class,
654+
params_input,
655+
):
656+
connected = _create_future(loop)
618657

619658
proto_factory = lambda: protocol.Protocol(
620659
addr, connected, params, record_class, loop)
@@ -625,7 +664,7 @@ async def _connect_addr(
625664
elif params.ssl:
626665
connector = _create_ssl_connection(
627666
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
628-
ssl_is_advisory=params.ssl_is_advisory)
667+
ssl_is_advisory=params.alt_retry_ssl_first)
629668
else:
630669
connector = loop.create_connection(proto_factory, *addr)
631670

@@ -638,6 +677,23 @@ async def _connect_addr(
638677
if timeout <= 0:
639678
raise asyncio.TimeoutError
640679
await compat.wait_for(connected, timeout=timeout)
680+
except exceptions.InvalidAuthorizationSpecificationError:
681+
tr.close()
682+
683+
# pr.is_ssl is a bool, so this equal test implies
684+
# alt_retry_ssl_first is not None (should do alt_retry)
685+
if params.alt_retry_ssl_first == pr.is_ssl:
686+
# Elevate the error to ConnectionError to trigger retry
687+
raise ConnectionError("Connection rejected trying {} SSL".format(
688+
'with' if pr.is_ssl else 'without'))
689+
690+
else:
691+
# Don't retry if alt_retry_ssl_first is None, or we don't need to
692+
# (alt_retry_ssl_first=True and pr.is_ssl=False means the server
693+
# doesn't support SSL, and we've already tried to Startup without
694+
# SSL but failed; The opposite case doesn't exist).
695+
raise
696+
641697
except (Exception, asyncio.CancelledError):
642698
tr.close()
643699
raise
@@ -684,6 +740,7 @@ class CancelProto(asyncio.Protocol):
684740

685741
def __init__(self):
686742
self.on_disconnect = _create_future(loop)
743+
self.is_ssl = False
687744

688745
def connection_lost(self, exc):
689746
if not self.on_disconnect.done():
@@ -692,17 +749,30 @@ def connection_lost(self, exc):
692749
if isinstance(addr, str):
693750
tr, pr = await loop.create_unix_connection(CancelProto, addr)
694751
else:
695-
if params.ssl:
696-
tr, pr = await _create_ssl_connection(
697-
CancelProto,
698-
*addr,
699-
loop=loop,
700-
ssl_context=params.ssl,
701-
ssl_is_advisory=params.ssl_is_advisory)
752+
async def _connect(params_in, ssl_is_advisory):
753+
if params_in.ssl:
754+
return await _create_ssl_connection(
755+
CancelProto,
756+
*addr,
757+
loop=loop,
758+
ssl_context=params_in.ssl,
759+
ssl_is_advisory=ssl_is_advisory)
760+
else:
761+
return await loop.create_connection(
762+
CancelProto, *addr)
763+
_set_nodelay(_get_socket(tr))
764+
765+
if params.alt_retry_ssl_first is None:
766+
tr, pr = await _connect(params, False)
702767
else:
703-
tr, pr = await loop.create_connection(
704-
CancelProto, *addr)
705-
_set_nodelay(_get_socket(tr))
768+
params_retry = params._replace(ssl=None)
769+
if not params.alt_retry_ssl_first:
770+
params, params_retry = params_retry, params
771+
try:
772+
tr, pr = await _connect(params, True)
773+
except ConnectionError:
774+
tr, pr = await _connect(
775+
params._replace(alt_retry_ssl_first=None), False)
706776

707777
# Pack a CancelRequest message
708778
msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

asyncpg/connection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
18791879
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
18801880
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
18811881
if SSL connection fails
1882-
- ``'allow'`` - currently equivalent to ``'prefer'``
1882+
- ``'allow'`` - try without SSL first, then retry with SSL if the first
1883+
attempt fails.
18831884
- ``'require'`` - only try an SSL connection. Certificate
18841885
verification errors are ignored
18851886
- ``'verify-ca'`` - only try an SSL connection, and verify

asyncpg/protocol/protocol.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):
5252

5353
readonly uint64_t queries_count
5454

55+
bint _is_ssl
56+
5557
PreparedStatementState statement
5658

5759
cdef get_connection(self)

asyncpg/protocol/protocol.pyx

+10
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):
103103

104104
self.queries_count = 0
105105

106+
self._is_ssl = False
107+
106108
try:
107109
self.create_future = loop.create_future
108110
except AttributeError:
@@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
943945
def resume_writing(self):
944946
self.writing_allowed.set()
945947

948+
@property
949+
def is_ssl(self):
950+
return self._is_ssl
951+
952+
@is_ssl.setter
953+
def is_ssl(self, value):
954+
self._is_ssl = value
955+
946956

947957
class Timer:
948958
def __init__(self, budget):

0 commit comments

Comments
 (0)