|
18 | 18 | import ssl as ssl_module
|
19 | 19 | import stat
|
20 | 20 | import struct
|
| 21 | +import sys |
21 | 22 | import time
|
22 | 23 | import typing
|
23 | 24 | import urllib.parse
|
@@ -220,13 +221,27 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
|
220 | 221 | return hosts, port
|
221 | 222 |
|
222 | 223 |
|
| 224 | +def _parse_tls_version(tls_version): |
| 225 | + if tls_version.startswith('SSL'): |
| 226 | + raise ValueError( |
| 227 | + f"Unsupported TLS version: {tls_version}" |
| 228 | + ) |
| 229 | + try: |
| 230 | + return ssl_module.TLSVersion[tls_version.replace('.', '_')] |
| 231 | + except KeyError: |
| 232 | + raise ValueError( |
| 233 | + f"No such TLS version: {tls_version}" |
| 234 | + ) |
| 235 | + |
| 236 | + |
223 | 237 | def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
224 | 238 | password, passfile, database, ssl,
|
225 | 239 | connect_timeout, server_settings):
|
226 | 240 | # `auth_hosts` is the version of host information for the purposes
|
227 | 241 | # of reading the pgpass file.
|
228 | 242 | auth_hosts = None
|
229 |
| - sslcert = sslkey = sslrootcert = sslcrl = None |
| 243 | + sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None |
| 244 | + sslcompression = ssl_min_protocol_version = ssl_max_protocol_version = None |
230 | 245 |
|
231 | 246 | if dsn:
|
232 | 247 | parsed = urllib.parse.urlparse(dsn)
|
@@ -312,24 +327,28 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
312 | 327 | ssl = val
|
313 | 328 |
|
314 | 329 | if 'sslcert' in query:
|
315 |
| - val = query.pop('sslcert') |
316 |
| - if sslcert is None: |
317 |
| - sslcert = val |
| 330 | + sslcert = query.pop('sslcert') |
318 | 331 |
|
319 | 332 | if 'sslkey' in query:
|
320 |
| - val = query.pop('sslkey') |
321 |
| - if sslkey is None: |
322 |
| - sslkey = val |
| 333 | + sslkey = query.pop('sslkey') |
323 | 334 |
|
324 | 335 | if 'sslrootcert' in query:
|
325 |
| - val = query.pop('sslrootcert') |
326 |
| - if sslrootcert is None: |
327 |
| - sslrootcert = val |
| 336 | + sslrootcert = query.pop('sslrootcert') |
328 | 337 |
|
329 | 338 | if 'sslcrl' in query:
|
330 |
| - val = query.pop('sslcrl') |
331 |
| - if sslcrl is None: |
332 |
| - sslcrl = val |
| 339 | + sslcrl = query.pop('sslcrl') |
| 340 | + |
| 341 | + if 'sslpassword' in query: |
| 342 | + sslpassword = query.pop('sslpassword') |
| 343 | + |
| 344 | + if 'sslcompression' in query: |
| 345 | + sslcompression = query.pop('sslcompression') |
| 346 | + |
| 347 | + if 'ssl_min_protocol_version' in query: |
| 348 | + ssl_min_protocol_version = query.pop('ssl_min_protocol_version') |
| 349 | + |
| 350 | + if 'ssl_max_protocol_version' in query: |
| 351 | + ssl_max_protocol_version = query.pop('ssl_max_protocol_version') |
333 | 352 |
|
334 | 353 | if query:
|
335 | 354 | if server_settings is None:
|
@@ -451,34 +470,98 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
|
451 | 470 | if sslmode < SSLMode.allow:
|
452 | 471 | ssl = False
|
453 | 472 | else:
|
454 |
| - ssl = ssl_module.create_default_context( |
455 |
| - ssl_module.Purpose.SERVER_AUTH) |
| 473 | + ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT) |
456 | 474 | ssl.check_hostname = sslmode >= SSLMode.verify_full
|
457 |
| - ssl.verify_mode = ssl_module.CERT_REQUIRED |
458 |
| - if sslmode <= SSLMode.require: |
| 475 | + if sslmode < SSLMode.require: |
459 | 476 | ssl.verify_mode = ssl_module.CERT_NONE
|
| 477 | + else: |
| 478 | + if sslrootcert is None: |
| 479 | + sslrootcert = os.getenv('PGSSLROOTCERT') |
| 480 | + if sslrootcert: |
| 481 | + ssl.load_verify_locations(cafile=sslrootcert) |
| 482 | + ssl.verify_mode = ssl_module.CERT_REQUIRED |
| 483 | + else: |
| 484 | + sslrootcert = os.path.expanduser('~/.postgresql/root.crt') |
| 485 | + try: |
| 486 | + ssl.load_verify_locations(cafile=sslrootcert) |
| 487 | + except FileNotFoundError: |
| 488 | + if sslmode > SSLMode.require: |
| 489 | + raise ValueError( |
| 490 | + f'root certificate file "{sslrootcert}" does ' |
| 491 | + f'not exist\nEither provide the file or ' |
| 492 | + f'change sslmode to disable server ' |
| 493 | + f'certificate verification.' |
| 494 | + ) |
| 495 | + elif sslmode == SSLMode.require: |
| 496 | + ssl.verify_mode = ssl_module.CERT_NONE |
| 497 | + else: |
| 498 | + assert False, 'unreachable' |
| 499 | + else: |
| 500 | + ssl.verify_mode = ssl_module.CERT_REQUIRED |
460 | 501 |
|
461 |
| - if sslcert is None: |
462 |
| - sslcert = os.getenv('PGSSLCERT') |
| 502 | + if sslcrl is None: |
| 503 | + sslcrl = os.getenv('PGSSLCRL') |
| 504 | + if sslcrl: |
| 505 | + ssl.load_verify_locations(cafile=sslcrl) |
| 506 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
| 507 | + else: |
| 508 | + sslcrl = os.path.expanduser('~/.postgresql/root.crl') |
| 509 | + try: |
| 510 | + ssl.load_verify_locations(cafile=sslcrl) |
| 511 | + except FileNotFoundError: |
| 512 | + pass |
| 513 | + else: |
| 514 | + ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
463 | 515 |
|
464 | 516 | if sslkey is None:
|
465 | 517 | sslkey = os.getenv('PGSSLKEY')
|
466 |
| - |
467 |
| - if sslrootcert is None: |
468 |
| - sslrootcert = os.getenv('PGSSLROOTCERT') |
469 |
| - |
470 |
| - if sslcrl is None: |
471 |
| - sslcrl = os.getenv('PGSSLCRL') |
472 |
| - |
| 518 | + if not sslkey: |
| 519 | + sslkey = os.path.expanduser('~/.postgresql/postgresql.key') |
| 520 | + if not os.path.exists(sslkey): |
| 521 | + sslkey = None |
| 522 | + if not sslpassword: |
| 523 | + sslpassword = '' |
| 524 | + if sslcert is None: |
| 525 | + sslcert = os.getenv('PGSSLCERT') |
473 | 526 | if sslcert:
|
474 |
| - ssl.load_cert_chain(sslcert, keyfile=sslkey) |
475 |
| - |
476 |
| - if sslrootcert: |
477 |
| - ssl.load_verify_locations(cafile=sslrootcert) |
478 |
| - |
479 |
| - if sslcrl: |
480 |
| - ssl.load_verify_locations(cafile=sslcrl) |
481 |
| - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN |
| 527 | + ssl.load_cert_chain( |
| 528 | + sslcert, keyfile=sslkey, password=lambda: sslpassword |
| 529 | + ) |
| 530 | + else: |
| 531 | + sslcert = os.path.expanduser('~/.postgresql/postgresql.crt') |
| 532 | + try: |
| 533 | + ssl.load_cert_chain( |
| 534 | + sslcert, keyfile=sslkey, password=lambda: sslpassword |
| 535 | + ) |
| 536 | + except FileNotFoundError: |
| 537 | + pass |
| 538 | + |
| 539 | + # OpenSSL 1.1.1 keylog file, copied from create_default_context() |
| 540 | + if hasattr(ssl, 'keylog_filename'): |
| 541 | + keylogfile = os.environ.get('SSLKEYLOGFILE') |
| 542 | + if keylogfile and not sys.flags.ignore_environment: |
| 543 | + ssl.keylog_filename = keylogfile |
| 544 | + |
| 545 | + if sslcompression is None: |
| 546 | + sslcompression = os.getenv('PGSSLCOMPRESSION') |
| 547 | + if sslcompression == '1': |
| 548 | + ssl.verify_flags ^= ssl_module.OP_NO_COMPRESSION |
| 549 | + |
| 550 | + if ssl_min_protocol_version is None: |
| 551 | + ssl_min_protocol_version = os.getenv( |
| 552 | + 'PGSSLMINPROTOCOLVERSION', 'TLSv1.2' |
| 553 | + ) |
| 554 | + if ssl_min_protocol_version: |
| 555 | + ssl.minimum_version = _parse_tls_version( |
| 556 | + ssl_min_protocol_version |
| 557 | + ) |
| 558 | + |
| 559 | + if ssl_max_protocol_version is None: |
| 560 | + ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION') |
| 561 | + if ssl_max_protocol_version: |
| 562 | + ssl.maximum_version = _parse_tls_version( |
| 563 | + ssl_max_protocol_version |
| 564 | + ) |
482 | 565 |
|
483 | 566 | elif ssl is True:
|
484 | 567 | ssl = ssl_module.create_default_context()
|
|
0 commit comments