Skip to content

Commit c337261

Browse files
committed
Handle environments without home dir
1 parent 247b1a5 commit c337261

File tree

3 files changed

+67
-20
lines changed

3 files changed

+67
-20
lines changed

asyncpg/compat.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import pathlib
1010
import platform
11+
import typing
1112

1213

1314
SYSTEM = platform.uname().system
@@ -18,7 +19,7 @@
1819

1920
CSIDL_APPDATA = 0x001a
2021

21-
def get_pg_home_directory() -> pathlib.Path:
22+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
2223
# We cannot simply use expanduser() as that returns the user's
2324
# home directory, whereas Postgres stores its config in
2425
# %AppData% on Windows.
@@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
3031
return pathlib.Path(buf.value) / 'postgresql'
3132

3233
else:
33-
def get_pg_home_directory() -> pathlib.Path:
34-
return pathlib.Path.home()
34+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
35+
try:
36+
return pathlib.Path.home()
37+
except (RuntimeError, KeyError):
38+
return None
3539

3640

3741
async def wait_closed(stream):

asyncpg/connect_utils.py

+31-17
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,12 @@ def _parse_tls_version(tls_version):
249249
)
250250

251251

252-
def _dot_postgresql_path(filename) -> pathlib.Path:
253-
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
252+
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
253+
homedir = compat.get_pg_home_directory()
254+
if homedir is None:
255+
return None
256+
257+
return (homedir / '.postgresql' / filename).resolve()
254258

255259

256260
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
@@ -501,11 +505,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
501505
ssl.load_verify_locations(cafile=sslrootcert)
502506
ssl.verify_mode = ssl_module.CERT_REQUIRED
503507
else:
504-
sslrootcert = _dot_postgresql_path('root.crt')
505508
try:
509+
sslrootcert = _dot_postgresql_path('root.crt')
510+
assert sslrootcert is not None
506511
ssl.load_verify_locations(cafile=sslrootcert)
507-
except FileNotFoundError:
512+
except (AssertionError, FileNotFoundError):
508513
if sslmode > SSLMode.require:
514+
if sslrootcert is None:
515+
raise RuntimeError(
516+
'Cannot determine home directory'
517+
)
509518
raise ValueError(
510519
f'root certificate file "{sslrootcert}" does '
511520
f'not exist\nEither provide the file or '
@@ -526,18 +535,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
526535
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
527536
else:
528537
sslcrl = _dot_postgresql_path('root.crl')
529-
try:
530-
ssl.load_verify_locations(cafile=sslcrl)
531-
except FileNotFoundError:
532-
pass
533-
else:
534-
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
538+
if sslcrl is not None:
539+
try:
540+
ssl.load_verify_locations(cafile=sslcrl)
541+
except FileNotFoundError:
542+
pass
543+
else:
544+
ssl.verify_flags |= \
545+
ssl_module.VERIFY_CRL_CHECK_CHAIN
535546

536547
if sslkey is None:
537548
sslkey = os.getenv('PGSSLKEY')
538549
if not sslkey:
539550
sslkey = _dot_postgresql_path('postgresql.key')
540-
if not sslkey.exists():
551+
if sslkey is not None and not sslkey.exists():
541552
sslkey = None
542553
if not sslpassword:
543554
sslpassword = ''
@@ -549,12 +560,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
549560
)
550561
else:
551562
sslcert = _dot_postgresql_path('postgresql.crt')
552-
try:
553-
ssl.load_cert_chain(
554-
sslcert, keyfile=sslkey, password=lambda: sslpassword
555-
)
556-
except FileNotFoundError:
557-
pass
563+
if sslcert is not None:
564+
try:
565+
ssl.load_cert_chain(
566+
sslcert,
567+
keyfile=sslkey,
568+
password=lambda: sslpassword
569+
)
570+
except FileNotFoundError:
571+
pass
558572

559573
# OpenSSL 1.1.1 keylog file, copied from create_default_context()
560574
if hasattr(ssl, 'keylog_filename'):

tests/test_connect.py

+29
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
7171
yield
7272

7373

74+
@contextlib.contextmanager
75+
def mock_no_home_dir():
76+
with unittest.mock.patch(
77+
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
78+
):
79+
yield
80+
81+
7482
class TestSettings(tb.ConnectedTestCase):
7583

7684
async def test_get_settings_01(self):
@@ -1257,6 +1265,27 @@ async def test_connection_implicit_host(self):
12571265
user=conn_spec.get('user'))
12581266
await con.close()
12591267

1268+
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
1269+
async def test_connection_no_home_dir(self):
1270+
with mock_no_home_dir():
1271+
con = await self.connect(
1272+
dsn='postgresql://foo/',
1273+
user='postgres',
1274+
database='postgres',
1275+
host='localhost')
1276+
await con.fetchval('SELECT 42')
1277+
await con.close()
1278+
1279+
with self.assertRaisesRegex(
1280+
RuntimeError,
1281+
'Cannot determine home directory'
1282+
):
1283+
with mock_no_home_dir():
1284+
await self.connect(
1285+
host='localhost',
1286+
user='ssl_user',
1287+
ssl='verify-full')
1288+
12601289

12611290
class BaseTestSSLConnection(tb.ConnectedTestCase):
12621291
@classmethod

0 commit comments

Comments
 (0)