Skip to content

Commit 439fb96

Browse files
LeonardBessonlezram
authored andcommitted
Handle environments without home dir (MagicStack#1011)
1 parent 6155213 commit 439fb96

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

asyncpg/compat.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import pathlib
1010
import platform
11+
import typing
1112

1213

1314
SYSTEM = platform.uname().system
@@ -18,7 +19,7 @@
1819

1920
CSIDL_APPDATA = 0x001a
2021

21-
def get_pg_home_directory() -> pathlib.Path:
22+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
2223
# We cannot simply use expanduser() as that returns the user's
2324
# home directory, whereas Postgres stores its config in
2425
# %AppData% on Windows.
@@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
3031
return pathlib.Path(buf.value) / 'postgresql'
3132

3233
else:
33-
def get_pg_home_directory() -> pathlib.Path:
34-
return pathlib.Path.home()
34+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
35+
try:
36+
return pathlib.Path.home()
37+
except (RuntimeError, KeyError):
38+
return None
3539

3640

3741
async def wait_closed(stream):

asyncpg/connect_utils.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,13 @@ def _parse_tls_version(tls_version):
251251
)
252252

253253

254-
def _dot_postgresql_path(filename) -> pathlib.Path:
255-
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
254+
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
255+
try:
256+
homedir = pathlib.Path.home()
257+
except (RuntimeError, KeyError):
258+
return None
259+
260+
return (homedir / '.postgresql' / filename).resolve()
256261

257262

258263
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
@@ -504,11 +509,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
504509
ssl.load_verify_locations(cafile=sslrootcert)
505510
ssl.verify_mode = ssl_module.CERT_REQUIRED
506511
else:
507-
sslrootcert = _dot_postgresql_path('root.crt')
508512
try:
513+
sslrootcert = _dot_postgresql_path('root.crt')
514+
assert sslrootcert is not None
509515
ssl.load_verify_locations(cafile=sslrootcert)
510-
except FileNotFoundError:
516+
except (AssertionError, FileNotFoundError):
511517
if sslmode > SSLMode.require:
518+
if sslrootcert is None:
519+
raise RuntimeError(
520+
'Cannot determine home directory'
521+
)
512522
raise ValueError(
513523
f'root certificate file "{sslrootcert}" does '
514524
f'not exist\nEither provide the file or '
@@ -529,18 +539,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
529539
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
530540
else:
531541
sslcrl = _dot_postgresql_path('root.crl')
532-
try:
533-
ssl.load_verify_locations(cafile=sslcrl)
534-
except FileNotFoundError:
535-
pass
536-
else:
537-
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
542+
if sslcrl is not None:
543+
try:
544+
ssl.load_verify_locations(cafile=sslcrl)
545+
except FileNotFoundError:
546+
pass
547+
else:
548+
ssl.verify_flags |= \
549+
ssl_module.VERIFY_CRL_CHECK_CHAIN
538550

539551
if sslkey is None:
540552
sslkey = os.getenv('PGSSLKEY')
541553
if not sslkey:
542554
sslkey = _dot_postgresql_path('postgresql.key')
543-
if not sslkey.exists():
555+
if sslkey is not None and not sslkey.exists():
544556
sslkey = None
545557
if not sslpassword:
546558
sslpassword = ''
@@ -552,12 +564,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
552564
)
553565
else:
554566
sslcert = _dot_postgresql_path('postgresql.crt')
555-
try:
556-
ssl.load_cert_chain(
557-
sslcert, keyfile=sslkey, password=lambda: sslpassword
558-
)
559-
except FileNotFoundError:
560-
pass
567+
if sslcert is not None:
568+
try:
569+
ssl.load_cert_chain(
570+
sslcert,
571+
keyfile=sslkey,
572+
password=lambda: sslpassword
573+
)
574+
except FileNotFoundError:
575+
pass
561576

562577
# OpenSSL 1.1.1 keylog file, copied from create_default_context()
563578
if hasattr(ssl, 'keylog_filename'):

tests/test_connect.py

+29
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
7171
yield
7272

7373

74+
@contextlib.contextmanager
75+
def mock_no_home_dir():
76+
with unittest.mock.patch(
77+
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
78+
):
79+
yield
80+
81+
7482
class TestSettings(tb.ConnectedTestCase):
7583

7684
async def test_get_settings_01(self):
@@ -1299,6 +1307,27 @@ async def test_connection_implicit_host(self):
12991307
user=conn_spec.get('user'))
13001308
await con.close()
13011309

1310+
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
1311+
async def test_connection_no_home_dir(self):
1312+
with mock_no_home_dir():
1313+
con = await self.connect(
1314+
dsn='postgresql://foo/',
1315+
user='postgres',
1316+
database='postgres',
1317+
host='localhost')
1318+
await con.fetchval('SELECT 42')
1319+
await con.close()
1320+
1321+
with self.assertRaisesRegex(
1322+
RuntimeError,
1323+
'Cannot determine home directory'
1324+
):
1325+
with mock_no_home_dir():
1326+
await self.connect(
1327+
host='localhost',
1328+
user='ssl_user',
1329+
ssl='verify-full')
1330+
13021331

13031332
class BaseTestSSLConnection(tb.ConnectedTestCase):
13041333
@classmethod

0 commit comments

Comments
 (0)