7
7
8
8
import asyncio
9
9
import collections
10
+ import enum
10
11
import functools
11
12
import getpass
12
13
import os
28
29
from . import protocol
29
30
30
31
32
+ class SSLMode (enum .IntEnum ):
33
+ disable = 0
34
+ allow = 1
35
+ prefer = 2
36
+ require = 3
37
+ verify_ca = 4
38
+ verify_full = 5
39
+
40
+ @classmethod
41
+ def parse (cls , sslmode ):
42
+ if isinstance (sslmode , cls ):
43
+ return sslmode
44
+ return getattr (cls , sslmode .replace ('-' , '_' ))
45
+
46
+
31
47
_ConnectionParameters = collections .namedtuple (
32
48
'ConnectionParameters' ,
33
49
[
34
50
'user' ,
35
51
'password' ,
36
52
'database' ,
37
53
'ssl' ,
38
- 'ssl_is_advisory ' ,
54
+ 'sslmode ' ,
39
55
'connect_timeout' ,
40
56
'server_settings' ,
41
57
])
@@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402
418
if ssl is None and have_tcp_addrs :
403
419
ssl = 'prefer'
404
420
405
- # ssl_is_advisory is only allowed to come from the sslmode parameter.
406
- ssl_is_advisory = None
407
- if isinstance (ssl , str ):
408
- SSLMODES = {
409
- 'disable' : 0 ,
410
- 'allow' : 1 ,
411
- 'prefer' : 2 ,
412
- 'require' : 3 ,
413
- 'verify-ca' : 4 ,
414
- 'verify-full' : 5 ,
415
- }
421
+ if isinstance (ssl , (str , SSLMode )):
416
422
try :
417
- sslmode = SSLMODES [ ssl ]
418
- except KeyError :
419
- modes = ', ' .join (SSLMODES . keys () )
423
+ sslmode = SSLMode . parse ( ssl )
424
+ except AttributeError :
425
+ modes = ', ' .join (m . name . replace ( '_' , '-' ) for m in SSLMode )
420
426
raise exceptions .InterfaceError (
421
427
'`sslmode` parameter must be one of: {}' .format (modes ))
422
428
423
- # sslmode 'allow' is currently handled as 'prefer' because we're
424
- # missing the "retry with SSL" behavior for 'allow', but do have the
425
- # "retry without SSL" behavior for 'prefer'.
426
- # Not changing 'allow' to 'prefer' here would be effectively the same
427
- # as changing 'allow' to 'disable'.
428
- if sslmode == SSLMODES ['allow' ]:
429
- sslmode = SSLMODES ['prefer' ]
430
-
431
429
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432
430
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433
- if sslmode <= SSLMODES [ ' allow' ] :
431
+ if sslmode < SSLMode . allow :
434
432
ssl = False
435
- ssl_is_advisory = sslmode >= SSLMODES ['allow' ]
436
433
else :
437
434
ssl = ssl_module .create_default_context ()
438
- ssl .check_hostname = sslmode >= SSLMODES [ 'verify-full' ]
435
+ ssl .check_hostname = sslmode >= SSLMode . verify_full
439
436
ssl .verify_mode = ssl_module .CERT_REQUIRED
440
- if sslmode <= SSLMODES [ ' require' ] :
437
+ if sslmode <= SSLMode . require :
441
438
ssl .verify_mode = ssl_module .CERT_NONE
442
- ssl_is_advisory = sslmode <= SSLMODES ['prefer' ]
443
439
elif ssl is True :
444
440
ssl = ssl_module .create_default_context ()
441
+ sslmode = SSLMode .verify_full
442
+ else :
443
+ sslmode = SSLMode .disable
445
444
446
445
if server_settings is not None and (
447
446
not isinstance (server_settings , dict ) or
@@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453
452
454
453
params = _ConnectionParameters (
455
454
user = user , password = password , database = database , ssl = ssl ,
456
- ssl_is_advisory = ssl_is_advisory , connect_timeout = connect_timeout ,
455
+ sslmode = sslmode , connect_timeout = connect_timeout ,
457
456
server_settings = server_settings )
458
457
459
458
return addrs , params
@@ -520,9 +519,8 @@ def data_received(self, data):
520
519
data == b'N' ):
521
520
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522
521
# since the only way to get ssl_is_advisory is from
523
- # sslmode=prefer (or sslmode=allow). But be extra sure to
524
- # disallow insecure connections when the ssl context asks for
525
- # real security.
522
+ # sslmode=prefer. But be extra sure to disallow insecure
523
+ # connections when the ssl context asks for real security.
526
524
self .on_data .set_result (False )
527
525
else :
528
526
self .on_data .set_exception (
@@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566
564
new_tr = tr
567
565
568
566
pg_proto = protocol_factory ()
567
+ pg_proto .is_ssl = do_ssl_upgrade
569
568
pg_proto .connection_made (new_tr )
570
569
new_tr .set_protocol (pg_proto )
571
570
@@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584
583
tr .close ()
585
584
586
585
try :
587
- return await conn_factory (sock = sock )
586
+ new_tr , pg_proto = await conn_factory (sock = sock )
587
+ pg_proto .is_ssl = do_ssl_upgrade
588
+ return new_tr , pg_proto
588
589
except (Exception , asyncio .CancelledError ):
589
590
sock .close ()
590
591
raise
@@ -605,8 +606,6 @@ async def _connect_addr(
605
606
if timeout <= 0 :
606
607
raise asyncio .TimeoutError
607
608
608
- connected = _create_future (loop )
609
-
610
609
params_input = params
611
610
if callable (params .password ):
612
611
if inspect .iscoroutinefunction (params .password ):
@@ -615,6 +614,49 @@ async def _connect_addr(
615
614
password = params .password ()
616
615
617
616
params = params ._replace (password = password )
617
+ args = (addr , loop , config , connection_class , record_class , params_input )
618
+
619
+ # prepare the params (which attempt has ssl) for the 2 attempts
620
+ if params .sslmode == SSLMode .allow :
621
+ params_retry = params
622
+ params = params ._replace (ssl = None )
623
+ elif params .sslmode == SSLMode .prefer :
624
+ params_retry = params ._replace (ssl = None )
625
+ else :
626
+ # skip retry if we don't have to
627
+ return await __connect_addr (params , timeout , False , * args )
628
+
629
+ # first attempt
630
+ before = time .monotonic ()
631
+ try :
632
+ return await __connect_addr (params , timeout , True , * args )
633
+ except _Retry :
634
+ pass
635
+
636
+ # second attempt
637
+ timeout -= time .monotonic () - before
638
+ if timeout <= 0 :
639
+ raise asyncio .TimeoutError
640
+ else :
641
+ return await __connect_addr (params_retry , timeout , False , * args )
642
+
643
+
644
+ class _Retry (Exception ):
645
+ pass
646
+
647
+
648
+ async def __connect_addr (
649
+ params ,
650
+ timeout ,
651
+ retry ,
652
+ addr ,
653
+ loop ,
654
+ config ,
655
+ connection_class ,
656
+ record_class ,
657
+ params_input ,
658
+ ):
659
+ connected = _create_future (loop )
618
660
619
661
proto_factory = lambda : protocol .Protocol (
620
662
addr , connected , params , record_class , loop )
@@ -625,7 +667,7 @@ async def _connect_addr(
625
667
elif params .ssl :
626
668
connector = _create_ssl_connection (
627
669
proto_factory , * addr , loop = loop , ssl_context = params .ssl ,
628
- ssl_is_advisory = params .ssl_is_advisory )
670
+ ssl_is_advisory = params .sslmode == SSLMode . prefer )
629
671
else :
630
672
connector = loop .create_connection (proto_factory , * addr )
631
673
@@ -638,6 +680,35 @@ async def _connect_addr(
638
680
if timeout <= 0 :
639
681
raise asyncio .TimeoutError
640
682
await compat .wait_for (connected , timeout = timeout )
683
+ except (
684
+ exceptions .InvalidAuthorizationSpecificationError ,
685
+ exceptions .ConnectionDoesNotExistError , # seen on Windows
686
+ ):
687
+ tr .close ()
688
+
689
+ # retry=True here is a redundant check because we don't want to
690
+ # accidentally raise the internal _Retry to the outer world
691
+ if retry and (
692
+ params .sslmode == SSLMode .allow and not pr .is_ssl or
693
+ params .sslmode == SSLMode .prefer and pr .is_ssl
694
+ ):
695
+ # Trigger retry when:
696
+ # 1. First attempt with sslmode=allow, ssl=None failed
697
+ # 2. First attempt with sslmode=prefer, ssl=ctx failed while the
698
+ # server claimed to support SSL (returning "S" for SSLRequest)
699
+ # (likely because pg_hba.conf rejected the connection)
700
+ raise _Retry ()
701
+
702
+ else :
703
+ # but will NOT retry if:
704
+ # 1. First attempt with sslmode=prefer failed but the server
705
+ # doesn't support SSL (returning 'N' for SSLRequest), because
706
+ # we already tried to connect without SSL thru ssl_is_advisory
707
+ # 2. Second attempt with sslmode=prefer, ssl=None failed
708
+ # 3. Second attempt with sslmode=allow, ssl=ctx failed
709
+ # 4. Any other sslmode
710
+ raise
711
+
641
712
except (Exception , asyncio .CancelledError ):
642
713
tr .close ()
643
714
raise
@@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):
684
755
685
756
def __init__ (self ):
686
757
self .on_disconnect = _create_future (loop )
758
+ self .is_ssl = False
687
759
688
760
def connection_lost (self , exc ):
689
761
if not self .on_disconnect .done ():
@@ -692,13 +764,13 @@ def connection_lost(self, exc):
692
764
if isinstance (addr , str ):
693
765
tr , pr = await loop .create_unix_connection (CancelProto , addr )
694
766
else :
695
- if params .ssl :
767
+ if params .ssl and params . sslmode != SSLMode . allow :
696
768
tr , pr = await _create_ssl_connection (
697
769
CancelProto ,
698
770
* addr ,
699
771
loop = loop ,
700
772
ssl_context = params .ssl ,
701
- ssl_is_advisory = params .ssl_is_advisory )
773
+ ssl_is_advisory = params .sslmode == SSLMode . prefer )
702
774
else :
703
775
tr , pr = await loop .create_connection (
704
776
CancelProto , * addr )
0 commit comments