Skip to content

Target session attr (2) #987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,93 @@ def tearDown(self):
self.con = None
finally:
super().tearDown()


class HotStandbyTestCase(ClusterTestCase):

@classmethod
def setup_cluster(cls):
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
cls.start_cluster(
cls.master_cluster,
server_settings={
'max_wal_senders': 10,
'wal_level': 'hot_standby'
}
)

con = None

try:
con = cls.loop.run_until_complete(
cls.master_cluster.connect(
database='postgres', user='postgres', loop=cls.loop))

cls.loop.run_until_complete(
con.execute('''
CREATE ROLE replication WITH LOGIN REPLICATION
'''))

cls.master_cluster.trust_local_replication_by('replication')

conn_spec = cls.master_cluster.get_connection_spec()

cls.standby_cluster = cls.new_cluster(
pg_cluster.HotStandbyCluster,
cluster_kwargs={
'master': conn_spec,
'replication_user': 'replication'
}
)
cls.start_cluster(
cls.standby_cluster,
server_settings={
'hot_standby': True
}
)

finally:
if con is not None:
cls.loop.run_until_complete(con.close())

@classmethod
def get_cluster_connection_spec(cls, cluster, kwargs={}):
conn_spec = cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec

@classmethod
def get_connection_spec(cls, kwargs={}):
primary_spec = cls.get_cluster_connection_spec(
cls.master_cluster, kwargs
)
standby_spec = cls.get_cluster_connection_spec(
cls.standby_cluster, kwargs
)
return {
'host': [primary_spec['host'], standby_spec['host']],
'port': [primary_spec['port'], standby_spec['port']],
'database': primary_spec['database'],
'user': primary_spec['user'],
**kwargs
}

@classmethod
def connect_primary(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)

@classmethod
def connect_standby(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(
cls.standby_cluster,
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)
105 changes: 97 additions & 8 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
Expand Down Expand Up @@ -56,6 +57,7 @@ def parse(cls, sslmode):
'direct_tls',
'connect_timeout',
'server_settings',
'target_session_attribute',
])


Expand Down Expand Up @@ -255,7 +257,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, connect_timeout, server_settings):
direct_tls, connect_timeout, server_settings,
target_session_attribute):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -595,7 +598,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings)
connect_timeout=connect_timeout, server_settings=server_settings,
target_session_attribute=target_session_attribute)

return addrs, params

Expand All @@ -605,8 +609,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings):

ssl, direct_tls, server_settings,
target_session_attribute):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -634,7 +638,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings)
connect_timeout=timeout, server_settings=server_settings,
target_session_attribute=target_session_attribute)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down Expand Up @@ -867,18 +872,84 @@ async def __connect_addr(
return con


class SessionAttribute(str, enum.Enum):
any = 'any'
primary = 'primary'
standby = 'standby'
prefer_standby = 'prefer-standby'
read_write = "read-write"
read_only = "read-only"


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
"""
If the server didn't report "in_hot_standby" at startup, we must determine
the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
If the server allows a connection and states it is in recovery it must
be a replica/standby server.
"""
async def can_be_used(connection):
settings = connection.get_settings()
hot_standby_status = getattr(settings, 'in_hot_standby', None)
if hot_standby_status is not None:
is_in_hot_standby = hot_standby_status == 'on'
else:
is_in_hot_standby = await connection.fetchval(
"SELECT pg_catalog.pg_is_in_recovery()"
)
return is_in_hot_standby == should_be_in_hot_standby

return can_be_used


def _accept_read_only(should_be_read_only: bool):
"""
Verify the server has not set default_transaction_read_only=True
"""
async def can_be_used(connection):
settings = connection.get_settings()
is_readonly = getattr(settings, 'default_transaction_read_only', 'off')

if is_readonly == "on":
return should_be_read_only

return await _accept_in_hot_standby(should_be_read_only)(connection)
return can_be_used


async def _accept_any(_):
return True


target_attrs_check = {
SessionAttribute.any: _accept_any,
SessionAttribute.primary: _accept_in_hot_standby(False),
SessionAttribute.standby: _accept_in_hot_standby(True),
SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
SessionAttribute.read_write: _accept_read_only(False),
SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
can_use = target_attrs_check[attr]
return await can_use(connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
target_attr = params.target_session_attribute

candidates = []
chosen_connection = None
last_error = None
addr = None
for addr in addrs:
before = time.monotonic()
try:
return await _connect_addr(
conn = await _connect_addr(
addr=addr,
loop=loop,
timeout=timeout,
Expand All @@ -887,12 +958,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
connection_class=connection_class,
record_class=record_class,
)
candidates.append(conn)
if await _can_use_connection(conn, target_attr):
chosen_connection = conn
break
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
finally:
timeout -= time.monotonic() - before
else:
if target_attr == SessionAttribute.prefer_standby and candidates:
chosen_connection = random.choice(candidates)

await asyncio.gather(
(c.close() for c in candidates if c is not chosen_connection),
return_exceptions=True
)

if chosen_connection:
return chosen_connection

raise last_error
raise last_error or exceptions.TargetServerAttributeNotMatched(
'None of the hosts match the target attribute requirement '
'{!r}'.format(target_attr)
)


async def _cancel(*, loop, addr, params: _ConnectionParameters,
Expand Down
24 changes: 23 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from . import serverversion
from . import transaction
from . import utils
from .connect_utils import SessionAttribute


class ConnectionMeta(type):
Expand Down Expand Up @@ -1792,7 +1793,8 @@ async def connect(dsn=None, *,
direct_tls=False,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None):
server_settings=None,
target_session_attribute=SessionAttribute.any):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For improved libpq compatibility let's add support for the PGTARGETSESSIONATTRS environment variable also. Hence, the default here should be None (unspecified).

r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2003,6 +2005,16 @@ async def connect(dsn=None, *,
this connection object. Must be a subclass of
:class:`~asyncpg.Record`.

:param SessionAttribute target_session_attribute:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libpq uses target_session_attrs as the spelling, so let's follow that.

If specified, check that the host has the correct attribute.
Can be one of:
"any": the first successfully connected host
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also read-write and read-only.

"primary": the host must NOT be in hot standby mode
"standby": the host must be in hot standby mode
"prefer-standby": first try to find a standby host, but if
none of the listed hosts is a standby server,
return any of them.

:return: A :class:`~asyncpg.connection.Connection` instance.

Example:
Expand Down Expand Up @@ -2087,6 +2099,15 @@ async def connect(dsn=None, *,
if record_class is not protocol.Record:
_check_record_class(record_class)

try:
target_session_attribute = SessionAttribute(target_session_attribute)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move this check to _parse_connect_dsn_and_args. That is also the place where the PGTARGETSESSIONATTRS can be handled in case an explicit argument wasn't passed to connect().

except ValueError as exc:
raise exceptions.InterfaceError(
"target_session_attribute is expected to be one of "
"'any', 'primary', 'standby' or 'prefer-standby'"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest joining __members__._value_ to avoid repeating yourself and future-proof the message.

", got {!r}".format(target_session_attribute)
) from exc

if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -2109,6 +2130,7 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attribute=target_session_attribute
)


Expand Down
6 changes: 5 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError')
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""


class TargetServerAttributeNotMatched(InternalClientError):
"""Could not find a host that satisfies the target attribute requirement"""


class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""

Expand Down
Loading