From 0c18811515660d7be3476e1acf4a48504abd8f01 Mon Sep 17 00:00:00 2001 From: Dzmitry Sauchanka Date: Fri, 12 Apr 2024 17:49:28 +0300 Subject: [PATCH] Allow to specify custom connection reset query builder The default connection reset behaviour might be an overkill for some cases e.g. when respective server capabilities are not used or resources are cleaned up explicitly before returning the connection to the pool. This provides an opt-in way to override this default behaviour. Also see https://github.com/MagicStack/asyncpg/issues/780 --- asyncpg/connect_utils.py | 8 ++++++-- asyncpg/connection.py | 32 ++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 0631f976..c3607569 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -68,6 +68,7 @@ def parse(cls, sslmode): 'statement_cache_size', 'max_cached_statement_lifetime', 'max_cacheable_statement_size', + 'get_reset_query', ]) @@ -690,7 +691,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs, krbsrvname, gsslib, + get_reset_query): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -726,7 +728,9 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size,) + max_cacheable_statement_size=max_cacheable_statement_size, + get_reset_query=get_reset_query, + ) return addrs, params, config diff --git a/asyncpg/connection.py b/asyncpg/connection.py index e54d6df8..a0ea10a0 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1659,25 +1659,25 @@ def _unwrap(self): return con_ref def _get_reset_query(self): - if self._reset_query is not None: - return self._reset_query + if self._reset_query is None: + get_reset_query = self._config.get_reset_query or self._get_default_reset_query + self._reset_query = get_reset_query(self._server_caps) - caps = self._server_caps + return self._reset_query + def _get_default_reset_query(self, server_caps): _reset_query = [] - if caps.advisory_locks: + + if server_caps.advisory_locks: _reset_query.append('SELECT pg_advisory_unlock_all();') - if caps.sql_close_all: + if server_caps.sql_close_all: _reset_query.append('CLOSE ALL;') - if caps.notifications and caps.plpgsql: + if server_caps.notifications and server_caps.plpgsql: _reset_query.append('UNLISTEN *;') - if caps.sql_reset: + if server_caps.sql_reset: _reset_query.append('RESET ALL;') - _reset_query = '\n'.join(_reset_query) - self._reset_query = _reset_query - - return _reset_query + return '\n'.join(_reset_query) def _set_proxy(self, proxy): if self._proxy is not None and proxy is not None: @@ -2009,7 +2009,8 @@ async def connect(dsn=None, *, server_settings=None, target_session_attrs=None, krbsrvname=None, - gsslib=None): + gsslib=None, + get_reset_query=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2245,6 +2246,12 @@ async def connect(dsn=None, *, GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. + :param get_reset_query: + Function to build a query that should be executed when resetting + the connection. Takes a single argument of type `~.asyncpg.connection.ServerCapabilities` + that communicates auto-detected server capabilities. + Defaults to `None` which means use the default reset query builder + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2360,6 +2367,7 @@ async def connect(dsn=None, *, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib, + get_reset_query=get_reset_query, )