Skip to content

Commit 172b8f6

Browse files
Handle environments without home dir (#1011)
1 parent 247b1a5 commit 172b8f6

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

asyncpg/compat.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import pathlib
1010
import platform
11+
import typing
1112

1213

1314
SYSTEM = platform.uname().system
@@ -18,7 +19,7 @@
1819

1920
CSIDL_APPDATA = 0x001a
2021

21-
def get_pg_home_directory() -> pathlib.Path:
22+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
2223
# We cannot simply use expanduser() as that returns the user's
2324
# home directory, whereas Postgres stores its config in
2425
# %AppData% on Windows.
@@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
3031
return pathlib.Path(buf.value) / 'postgresql'
3132

3233
else:
33-
def get_pg_home_directory() -> pathlib.Path:
34-
return pathlib.Path.home()
34+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
35+
try:
36+
return pathlib.Path.home()
37+
except (RuntimeError, KeyError):
38+
return None
3539

3640

3741
async def wait_closed(stream):

asyncpg/connect_utils.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,13 @@ def _parse_tls_version(tls_version):
249249
)
250250

251251

252-
def _dot_postgresql_path(filename) -> pathlib.Path:
253-
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
252+
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
253+
try:
254+
homedir = pathlib.Path.home()
255+
except (RuntimeError, KeyError):
256+
return None
257+
258+
return (homedir / '.postgresql' / filename).resolve()
254259

255260

256261
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
@@ -501,11 +506,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
501506
ssl.load_verify_locations(cafile=sslrootcert)
502507
ssl.verify_mode = ssl_module.CERT_REQUIRED
503508
else:
504-
sslrootcert = _dot_postgresql_path('root.crt')
505509
try:
510+
sslrootcert = _dot_postgresql_path('root.crt')
511+
assert sslrootcert is not None
506512
ssl.load_verify_locations(cafile=sslrootcert)
507-
except FileNotFoundError:
513+
except (AssertionError, FileNotFoundError):
508514
if sslmode > SSLMode.require:
515+
if sslrootcert is None:
516+
raise RuntimeError(
517+
'Cannot determine home directory'
518+
)
509519
raise ValueError(
510520
f'root certificate file "{sslrootcert}" does '
511521
f'not exist\nEither provide the file or '
@@ -526,18 +536,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
526536
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
527537
else:
528538
sslcrl = _dot_postgresql_path('root.crl')
529-
try:
530-
ssl.load_verify_locations(cafile=sslcrl)
531-
except FileNotFoundError:
532-
pass
533-
else:
534-
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
539+
if sslcrl is not None:
540+
try:
541+
ssl.load_verify_locations(cafile=sslcrl)
542+
except FileNotFoundError:
543+
pass
544+
else:
545+
ssl.verify_flags |= \
546+
ssl_module.VERIFY_CRL_CHECK_CHAIN
535547

536548
if sslkey is None:
537549
sslkey = os.getenv('PGSSLKEY')
538550
if not sslkey:
539551
sslkey = _dot_postgresql_path('postgresql.key')
540-
if not sslkey.exists():
552+
if sslkey is not None and not sslkey.exists():
541553
sslkey = None
542554
if not sslpassword:
543555
sslpassword = ''
@@ -549,12 +561,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
549561
)
550562
else:
551563
sslcert = _dot_postgresql_path('postgresql.crt')
552-
try:
553-
ssl.load_cert_chain(
554-
sslcert, keyfile=sslkey, password=lambda: sslpassword
555-
)
556-
except FileNotFoundError:
557-
pass
564+
if sslcert is not None:
565+
try:
566+
ssl.load_cert_chain(
567+
sslcert,
568+
keyfile=sslkey,
569+
password=lambda: sslpassword
570+
)
571+
except FileNotFoundError:
572+
pass
558573

559574
# OpenSSL 1.1.1 keylog file, copied from create_default_context()
560575
if hasattr(ssl, 'keylog_filename'):

tests/test_connect.py

+29
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
7171
yield
7272

7373

74+
@contextlib.contextmanager
75+
def mock_no_home_dir():
76+
with unittest.mock.patch(
77+
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
78+
):
79+
yield
80+
81+
7482
class TestSettings(tb.ConnectedTestCase):
7583

7684
async def test_get_settings_01(self):
@@ -1257,6 +1265,27 @@ async def test_connection_implicit_host(self):
12571265
user=conn_spec.get('user'))
12581266
await con.close()
12591267

1268+
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
1269+
async def test_connection_no_home_dir(self):
1270+
with mock_no_home_dir():
1271+
con = await self.connect(
1272+
dsn='postgresql://foo/',
1273+
user='postgres',
1274+
database='postgres',
1275+
host='localhost')
1276+
await con.fetchval('SELECT 42')
1277+
await con.close()
1278+
1279+
with self.assertRaisesRegex(
1280+
RuntimeError,
1281+
'Cannot determine home directory'
1282+
):
1283+
with mock_no_home_dir():
1284+
await self.connect(
1285+
host='localhost',
1286+
user='ssl_user',
1287+
ssl='verify-full')
1288+
12601289

12611290
class BaseTestSSLConnection(tb.ConnectedTestCase):
12621291
@classmethod

0 commit comments

Comments
 (0)