Skip to content

Commit d824333

Browse files
author
Jesse De Loore
committed
Update based on review.
1 parent 4e97acc commit d824333

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

asyncpg/connect_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def parse(cls, sslmode):
5757
'direct_tls',
5858
'connect_timeout',
5959
'server_settings',
60-
'target_session_attribute',
60+
'target_session_attrs',
6161
])
6262

6363

@@ -258,7 +258,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
258258
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
259259
password, passfile, database, ssl,
260260
direct_tls, connect_timeout, server_settings,
261-
target_session_attribute):
261+
target_session_attrs):
262262
# `auth_hosts` is the version of host information for the purposes
263263
# of reading the pgpass file.
264264
auth_hosts = None
@@ -595,11 +595,24 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
595595
'server_settings is expected to be None or '
596596
'a Dict[str, str]')
597597

598+
if target_session_attrs is None:
599+
600+
target_session_attrs = os.getenv("PGTARGETSESSIONATTRS", SessionAttribute.any)
601+
try:
602+
603+
target_session_attrs = SessionAttribute(target_session_attrs)
604+
except ValueError as exc:
605+
raise exceptions.InterfaceError(
606+
"target_session_attrs is expected to be one of "
607+
"{!r}"
608+
", got {!r}".format(SessionAttribute.__members__.values, target_session_attrs)
609+
) from exc
610+
598611
params = _ConnectionParameters(
599612
user=user, password=password, database=database, ssl=ssl,
600613
sslmode=sslmode, direct_tls=direct_tls,
601614
connect_timeout=connect_timeout, server_settings=server_settings,
602-
target_session_attribute=target_session_attribute)
615+
target_session_attrs=target_session_attrs)
603616

604617
return addrs, params
605618

@@ -610,7 +623,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
610623
max_cached_statement_lifetime,
611624
max_cacheable_statement_size,
612625
ssl, direct_tls, server_settings,
613-
target_session_attribute):
626+
target_session_attrs):
614627
local_vars = locals()
615628
for var_name in {'max_cacheable_statement_size',
616629
'max_cached_statement_lifetime',
@@ -639,7 +652,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
639652
password=password, passfile=passfile, ssl=ssl,
640653
direct_tls=direct_tls, database=database,
641654
connect_timeout=timeout, server_settings=server_settings,
642-
target_session_attribute=target_session_attribute)
655+
target_session_attrs=target_session_attrs)
643656

644657
config = _ClientConfiguration(
645658
command_timeout=command_timeout,
@@ -941,7 +954,7 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
941954
loop = asyncio.get_event_loop()
942955

943956
addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
944-
target_attr = params.target_session_attribute
957+
target_attr = params.target_session_attrs
945958

946959
candidates = []
947960
chosen_connection = None

asyncpg/connection.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,7 @@ async def connect(dsn=None, *,
17941794
connection_class=Connection,
17951795
record_class=protocol.Record,
17961796
server_settings=None,
1797-
target_session_attribute=SessionAttribute.any):
1797+
target_session_attrs=None):
17981798
r"""A coroutine to establish a connection to a PostgreSQL server.
17991799
18001800
The connection parameters may be specified either as a connection
@@ -2005,16 +2005,21 @@ async def connect(dsn=None, *,
20052005
this connection object. Must be a subclass of
20062006
:class:`~asyncpg.Record`.
20072007
2008-
:param SessionAttribute target_session_attribute:
2008+
:param SessionAttribute target_session_attrs:
20092009
If specified, check that the host has the correct attribute.
20102010
Can be one of:
20112011
"any": the first successfully connected host
20122012
"primary": the host must NOT be in hot standby mode
20132013
"standby": the host must be in hot standby mode
2014+
"read-write": the host must allow writes
2015+
"read-only": the host most NOT allow writes
20142016
"prefer-standby": first try to find a standby host, but if
20152017
none of the listed hosts is a standby server,
20162018
return any of them.
20172019
2020+
If not specified will try to use PGTARGETSESSIONATTRS from the environment.
2021+
Defaults to "any" if no value is set.
2022+
20182023
:return: A :class:`~asyncpg.connection.Connection` instance.
20192024
20202025
Example:
@@ -2099,15 +2104,6 @@ async def connect(dsn=None, *,
20992104
if record_class is not protocol.Record:
21002105
_check_record_class(record_class)
21012106

2102-
try:
2103-
target_session_attribute = SessionAttribute(target_session_attribute)
2104-
except ValueError as exc:
2105-
raise exceptions.InterfaceError(
2106-
"target_session_attribute is expected to be one of "
2107-
"'any', 'primary', 'standby' or 'prefer-standby'"
2108-
", got {!r}".format(target_session_attribute)
2109-
) from exc
2110-
21112107
if loop is None:
21122108
loop = asyncio.get_event_loop()
21132109

@@ -2130,7 +2126,7 @@ async def connect(dsn=None, *,
21302126
statement_cache_size=statement_cache_size,
21312127
max_cached_statement_lifetime=max_cached_statement_lifetime,
21322128
max_cacheable_statement_size=max_cacheable_statement_size,
2133-
target_session_attribute=target_session_attribute
2129+
target_session_attrs=target_session_attrs
21342130
)
21352131

21362132

tests/test_connect.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def run_testcase(self, testcase):
788788
database = testcase.get('database')
789789
sslmode = testcase.get('ssl')
790790
server_settings = testcase.get('server_settings')
791-
target_session_attribute = testcase.get('target_session_attribute')
791+
target_session_attrs = testcase.get('target_session_attrs')
792792

793793
expected = testcase.get('result')
794794
expected_error = testcase.get('error')
@@ -813,7 +813,7 @@ def run_testcase(self, testcase):
813813
passfile=passfile, database=database, ssl=sslmode,
814814
direct_tls=False, connect_timeout=None,
815815
server_settings=server_settings,
816-
target_session_attribute=target_session_attribute)
816+
target_session_attrs=target_session_attrs)
817817

818818
params = {
819819
k: v for k, v in params._asdict().items()
@@ -1750,7 +1750,7 @@ class TestConnectionAttributes(tb.HotStandbyTestCase):
17501750
async def _run_connection_test(
17511751
self, connect, target_attribute, expected_port
17521752
):
1753-
conn = await connect(target_session_attribute=target_attribute)
1753+
conn = await connect(target_session_attrs=target_attribute)
17541754
self.assertTrue(_get_connected_host(conn).endswith(expected_port))
17551755
await conn.close()
17561756

@@ -1790,7 +1790,7 @@ async def test_target_attribute_not_matched(self):
17901790

17911791
for connect, target_attr in tests:
17921792
with self.assertRaises(exceptions.TargetServerAttributeNotMatched):
1793-
await connect(target_session_attribute=target_attr)
1793+
await connect(target_session_attrs=target_attr)
17941794

17951795
if self.master_cluster.get_pg_version()[0] < 14:
17961796
self.skipTest("PostgreSQL<14 does not support these features")
@@ -1801,12 +1801,12 @@ async def test_target_attribute_not_matched(self):
18011801

18021802
for connect, target_attr in tests:
18031803
with self.assertRaises(exceptions.TargetServerAttributeNotMatched):
1804-
await connect(target_session_attribute=target_attr)
1804+
await connect(target_session_attrs=target_attr)
18051805

18061806
async def test_prefer_standby_when_standby_is_up(self):
18071807
if self.master_cluster.get_pg_version()[0] == 11:
18081808
self.skipTest("PostgreSQL 11 seems to have issues with this test")
1809-
con = await self.connect(target_session_attribute='prefer-standby')
1809+
con = await self.connect(target_session_attrs='prefer-standby')
18101810
standby_port = self.standby_cluster.get_connection_spec()['port']
18111811
connected_host = _get_connected_host(con)
18121812
self.assertTrue(connected_host.endswith(standby_port))
@@ -1824,7 +1824,7 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
18241824
'port': [primary_spec['port'], 15345],
18251825
'database': primary_spec['database'],
18261826
'user': primary_spec['user'],
1827-
'target_session_attribute': 'prefer-standby'
1827+
'target_session_attrs': 'prefer-standby'
18281828
}
18291829

18301830
con = await self.connect(**connection_spec)

0 commit comments

Comments
 (0)