Skip to content

Add connect kwarg to Pool to better support GCP's CloudSQL #1170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
connect=None,
setup=None,
init=None,
loop=None,
Expand All @@ -271,12 +272,18 @@ def create_pool(dsn=None, *,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs)
**connect_kwargs,
)


class ClusterTestCase(TestCase):
Expand Down
49 changes: 41 additions & 8 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class Pool:

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect_args', '_connect_kwargs',
'_init', '_connect', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
Expand All @@ -324,8 +324,9 @@ def __init__(self, *connect_args,
max_size,
max_queries,
max_inactive_connection_lifetime,
setup,
init,
connect=None,
setup=None,
init=None,
loop,
connection_class,
record_class,
Expand Down Expand Up @@ -385,11 +386,14 @@ def __init__(self, *connect_args,
self._closing = False
self._closed = False
self._generation = 0
self._init = init

self._connect = connect if connect is not None else connection.connect
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

self._setup = setup
self._init = init

self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
max_inactive_connection_lifetime
Expand Down Expand Up @@ -503,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
self._connect_kwargs = connect_kwargs

async def _get_new_connection(self):
con = await connection.connect(
con = await self._connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
if not isinstance(con, self._connection_class):
good = self._connection_class
good_n = f'{good.__module__}.{good.__name__}'
bad = type(con)
if bad.__module__ == "builtins":
bad_n = bad.__name__
else:
bad_n = f'{bad.__module__}.{bad.__name__}'
raise exceptions.InterfaceError(
"expected pool connect callback to return an instance of "
f"'{good_n}', got " f"'{bad_n}'"
)

if self._init is not None:
try:
Expand Down Expand Up @@ -1001,6 +1017,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
connect=None,
setup=None,
init=None,
loop=None,
Expand Down Expand Up @@ -1083,6 +1100,13 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.

:param coroutine connect:
A coroutine that is called instead of
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
new connection. Must return an instance of type specified by
*connection_class* or :class:`~asyncpg.connection.Connection` if
*connection_class* was not specified.

:param coroutine setup:
A coroutine to prepare a connection right before it is returned
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
Expand Down Expand Up @@ -1123,12 +1147,21 @@ def create_pool(dsn=None, *,

.. versionchanged:: 0.22.0
Added the *record_class* parameter.

.. versionchanged:: 0.30.0
Added the *connect* parameter.
"""
return Pool(
dsn,
connection_class=connection_class,
record_class=record_class,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
**connect_kwargs,
)
21 changes: 20 additions & 1 deletion tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ async def setup(con):

async def test_pool_07(self):
cons = set()
connect_called = 0

async def connect(*args, **kwargs):
nonlocal connect_called
connect_called += 1
return await pg_connection.connect(*args, **kwargs)

async def setup(con):
if con._con not in cons: # `con` is `PoolConnectionProxy`.
Expand All @@ -152,13 +158,26 @@ async def user(pool):
raise RuntimeError('init was not called')

async with self.create_pool(database='postgres',
min_size=2, max_size=5,
min_size=2,
max_size=5,
connect=connect,
init=init,
setup=setup) as pool:
users = asyncio.gather(*[user(pool) for _ in range(10)])
await users

self.assertEqual(len(cons), 5)
self.assertEqual(connect_called, 5)

async def bad_connect(*args, **kwargs):
return 1

with self.assertRaisesRegex(
asyncpg.InterfaceError,
"expected pool connect callback to return an instance of "
"'asyncpg\\.connection\\.Connection', got 'int'"
):
await self.create_pool(database='postgres', connect=bad_connect)

async def test_pool_08(self):
pool = await self.create_pool(database='postgres',
Expand Down
Loading