Skip to content

Implement connection service file functionality #1223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 125 additions & 3 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
Expand Down Expand Up @@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
PGPASSFILE = '.pgpass'


PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:

Expand Down Expand Up @@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
password, passfile, database, ssl, service,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
Expand All @@ -278,6 +282,120 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl_min_protocol_version = ssl_max_protocol_version = None
sslnegotiation = None

if dsn:
parsed = urllib.parse.urlparse(dsn)
if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]

if 'service' in query:
val = query.pop('service')
if not service and val:
service = val

connection_service_file = os.getenv('PGSERVICEFILE')
if connection_service_file is None:
homedir = compat.get_pg_home_directory()
if homedir:
connection_service_file = homedir / PG_SERVICEFILE
else:
connection_service_file = None
else:
connection_service_file = pathlib.Path(connection_service_file)

if connection_service_file is not None and service is not None:
# TODO Open and parse connection service file
pg_service = configparser.ConfigParser()
pg_service.read(connection_service_file)
if service in pg_service.sections():
service_params = pg_service[service]
if 'port' in service_params:
val = service_params.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]

if 'host' in service_params:
val = service_params.pop('host')
if not host and val:
host, port = _parse_hostlist(val, port)

if 'dbname' in service_params:
val = service_params.pop('dbname')
if database is None:
database = val

if 'database' in service_params:
val = service_params.pop('database')
if database is None:
database = val

if 'user' in service_params:
val = service_params.pop('user')
if user is None:
user = val

if 'password' in service_params:
val = service_params.pop('password')
if password is None:
password = val

if 'passfile' in service_params:
val = service_params.pop('passfile')
if passfile is None:
passfile = val

if 'sslmode' in service_params:
val = service_params.pop('sslmode')
if ssl is None:
ssl = val

if 'sslcert' in service_params:
sslcert = service_params.pop('sslcert')

if 'sslkey' in service_params:
sslkey = service_params.pop('sslkey')

if 'sslrootcert' in service_params:
sslrootcert = service_params.pop('sslrootcert')

if 'sslnegotiation' in service_params:
sslnegotiation = service_params.pop('sslnegotiation')

if 'sslcrl' in service_params:
sslcrl = service_params.pop('sslcrl')

if 'sslpassword' in service_params:
sslpassword = service_params.pop('sslpassword')

if 'ssl_min_protocol_version' in service_params:
ssl_min_protocol_version = service_params.pop(
'ssl_min_protocol_version'
)

if 'ssl_max_protocol_version' in service_params:
ssl_max_protocol_version = service_params.pop(
'ssl_max_protocol_version'
)

if 'target_session_attrs' in service_params:
dsn_target_session_attrs = service_params.pop(
'target_session_attrs'
)
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in service_params:
val = service_params.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if 'gsslib' in service_params:
val = service_params.pop('gsslib')
if gsslib is None:
gsslib = val

if dsn:
parsed = urllib.parse.urlparse(dsn)

Expand Down Expand Up @@ -406,6 +524,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if gsslib is None:
gsslib = val

if 'service' in query:
val = query.pop('service')

if query:
if server_settings is None:
server_settings = query
Expand Down Expand Up @@ -491,6 +612,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database=database, user=user,
passfile=passfile)


addrs = []
have_tcp_addrs = False
for h, p in zip(host, port):
Expand Down Expand Up @@ -724,7 +846,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
target_session_attrs, krbsrvname, gsslib, service):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -754,7 +876,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
6 changes: 6 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,7 @@ async def _do_execute(
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
service=None,
database=None,
loop=None,
timeout=60,
Expand Down Expand Up @@ -2183,6 +2184,10 @@ async def connect(dsn=None, *,
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).

:param service:
The name of the postgres connection service stored in the postgres
connection service file.

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -2428,6 +2433,7 @@ async def connect(dsn=None, *,
user=user,
password=password,
passfile=passfile,
service=service,
ssl=ssl,
direct_tls=direct_tls,
database=database,
Expand Down
70 changes: 68 additions & 2 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,8 @@ def run_testcase(self, testcase):
env = testcase.get('env', {})
test_env = {'PGHOST': None, 'PGPORT': None,
'PGUSER': None, 'PGPASSWORD': None,
'PGDATABASE': None, 'PGSSLMODE': None}
'PGDATABASE': None, 'PGSSLMODE': None,
'PGSERVICE': None, }
test_env.update(env)

dsn = testcase.get('dsn')
Expand All @@ -1132,6 +1133,7 @@ def run_testcase(self, testcase):
target_session_attrs = testcase.get('target_session_attrs')
krbsrvname = testcase.get('krbsrvname')
gsslib = testcase.get('gsslib')
service = testcase.get('service')

expected = testcase.get('result')
expected_error = testcase.get('error')
Expand All @@ -1157,7 +1159,7 @@ def run_testcase(self, testcase):
direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname, gsslib=gsslib)
krbsrvname=krbsrvname, gsslib=gsslib, service=service)

params = {
k: v for k, v in params._asdict().items()
Expand Down Expand Up @@ -1236,6 +1238,70 @@ def test_connect_params(self):
for testcase in self.TESTS:
self.run_testcase(testcase)

def test_connect_connection_service_file(self):
connection_service_file = tempfile.NamedTemporaryFile('w+t', delete=False)
connection_service_file.write(textwrap.dedent(f'''
[test_service_dbname]
port=5433
host=somehost
dbname=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi

[test_service_database]
port=5433
host=somehost
database=test_dbname
user=admin
password=test_password
target_session_attrs=primary
krbsrvname=fakekrbsrvname
gsslib=sspi
'''))
connection_service_file.close()
os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR)
try:
# passfile path in env
self.run_testcase({
'dsn': 'postgresql://?service=test_service_dbname',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
self.run_testcase({
'dsn': 'postgresql://?service=test_service_database',
'env': {
'PGSERVICEFILE': connection_service_file.name
},
'result': (
[('somehost', 5433)],
{
'user': 'admin',
'password': 'test_password',
'database': 'test_dbname',
'target_session_attrs': 'primary',
'krbsrvname': 'fakekrbsrvname',
'gsslib': 'sspi',
}
)
})
finally:
os.unlink(connection_service_file.name)

def test_connect_pgpass_regular(self):
passfile = tempfile.NamedTemporaryFile('w+t', delete=False)
passfile.write(textwrap.dedent(R'''
Expand Down