Skip to content

Add support for the sslnegotiation parameter #1187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import enum
import pathlib
import platform
import typing
Expand Down Expand Up @@ -78,3 +79,10 @@ def markcoroutinefunction(c): # type: ignore
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)

if sys.version_info < (3, 11):
class StrEnum(str, enum.Enum):
__str__ = str.__str__
__repr__ = enum.Enum.__repr__
else:
from enum import StrEnum as StrEnum # noqa: F401
44 changes: 38 additions & 6 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def parse(cls, sslmode):
return getattr(cls, sslmode.replace('-', '_'))


class SSLNegotiation(compat.StrEnum):
postgres = "postgres"
direct = "direct"


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
Expand All @@ -53,7 +58,7 @@ def parse(cls, sslmode):
'database',
'ssl',
'sslmode',
'direct_tls',
'ssl_negotiation',
'server_settings',
'target_session_attrs',
'krbsrvname',
Expand Down Expand Up @@ -269,6 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
auth_hosts = None
sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
ssl_min_protocol_version = ssl_max_protocol_version = None
sslnegotiation = None

if dsn:
parsed = urllib.parse.urlparse(dsn)
Expand Down Expand Up @@ -362,6 +368,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if 'sslrootcert' in query:
sslrootcert = query.pop('sslrootcert')

if 'sslnegotiation' in query:
sslnegotiation = query.pop('sslnegotiation')

if 'sslcrl' in query:
sslcrl = query.pop('sslcrl')

Expand Down Expand Up @@ -503,13 +512,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'

if direct_tls is not None:
sslneg = (
SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
)
else:
if sslnegotiation is None:
sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

if sslnegotiation is not None:
try:
sslneg = SSLNegotiation(sslnegotiation)
except ValueError:
modes = ', '.join(
m.name.replace('_', '-')
for m in SSLNegotiation
)
raise exceptions.ClientConfigurationError(
f'`sslnegotiation` parameter must be one of: {modes}'
) from None
else:
sslneg = SSLNegotiation.postgres

if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.ClientConfigurationError(
'`sslmode` parameter must be one of: {}'.format(modes))
'`sslmode` parameter must be one of: {}'.format(modes)
) from None

# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
if sslmode < SSLMode.allow:
Expand Down Expand Up @@ -676,7 +708,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
sslmode=sslmode, ssl_negotiation=sslneg,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
Expand Down Expand Up @@ -882,9 +914,9 @@ async def __connect_addr(
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)

elif params.ssl and params.direct_tls:
# if ssl and direct_tls are given, skip STARTTLS and perform direct
# SSL connection
elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
# if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
# direct SSL connection
connector = loop.create_connection(
proto_factory, *addr, ssl=params.ssl
)
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,7 +2001,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
direct_tls=False,
direct_tls=None,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ exclude_lines = [
show_missing = true

[tool.mypy]
exclude = [
"^.eggs",
"^.github",
"^.vscode",
"^build",
"^dist",
"^docs",
"^tests",
]
incremental = true
strict = true
implicit_reexport = true
Expand Down
59 changes: 58 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,58 @@ class TestConnectParams(tb.TestCase):
'target_session_attrs': 'any'})
},

{
'name': 'params_ssl_negotiation_dsn',
'env': {
'PGSSLNEGOTIATION': 'postgres'
},

'dsn': 'postgres://u:p@localhost/d?sslnegotiation=direct',

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'direct',
'target_session_attrs': 'any',
})
},

{
'name': 'params_ssl_negotiation_env',
'env': {
'PGSSLNEGOTIATION': 'direct'
},

'dsn': 'postgres://u:p@localhost/d',

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'direct',
'target_session_attrs': 'any',
})
},

{
'name': 'params_ssl_negotiation_params',
'env': {
'PGSSLNEGOTIATION': 'direct'
},

'dsn': 'postgres://u:p@localhost/d',
'direct_tls': False,

'result': ([('localhost', 5432)], {
'user': 'u',
'password': 'p',
'database': 'd',
'ssl_negotiation': 'postgres',
'target_session_attrs': 'any',
})
},

{
'name': 'dsn_overrides_env_partially_ssl_prefer',
'env': {
Expand Down Expand Up @@ -1067,6 +1119,7 @@ def run_testcase(self, testcase):
passfile = testcase.get('passfile')
database = testcase.get('database')
sslmode = testcase.get('ssl')
direct_tls = testcase.get('direct_tls')
server_settings = testcase.get('server_settings')
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
Expand All @@ -1093,7 +1146,7 @@ def run_testcase(self, testcase):
addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=sslmode,
direct_tls=False,
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
Expand All @@ -1118,6 +1171,10 @@ def run_testcase(self, testcase):
# Avoid the hassle of specifying direct_tls
# unless explicitly tested for
params.pop('direct_tls', False)
if 'ssl_negotiation' not in expected[1]:
# Avoid the hassle of specifying sslnegotiation
# unless explicitly tested for
params.pop('ssl_negotiation', False)
if 'gsslib' not in expected[1]:
# Avoid the hassle of specifying gsslib
# unless explicitly tested for
Expand Down
Loading