Skip to content

Add support for coroutine functions as listener callbacks #802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 57 additions & 45 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
import collections.abc
import functools
import itertools
import inspect
import os
import sys
import time
import traceback
import typing
import warnings
import weakref

Expand Down Expand Up @@ -133,27 +135,32 @@ async def add_listener(self, channel, callback):
:param str channel: Channel to listen on.

:param callable callback:
A callable receiving the following arguments:
A callable or a coroutine function receiving the following
arguments:
**connection**: a Connection the callback is registered with;
**pid**: PID of the Postgres server that sent the notification;
**channel**: name of the channel the notification was sent to;
**payload**: the payload.

.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
self._check_open()
if channel not in self._listeners:
await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
self._listeners[channel] = set()
self._listeners[channel].add(callback)
self._listeners[channel].add(_Callback.from_callable(callback))

async def remove_listener(self, channel, callback):
"""Remove a listening callback on the specified channel."""
if self.is_closed():
return
if channel not in self._listeners:
return
if callback not in self._listeners[channel]:
cb = _Callback.from_callable(callback)
if cb not in self._listeners[channel]:
return
self._listeners[channel].remove(callback)
self._listeners[channel].remove(cb)
if not self._listeners[channel]:
del self._listeners[channel]
await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel)))
Expand All @@ -166,44 +173,51 @@ def add_log_listener(self, callback):
DEBUG, INFO, or LOG.

:param callable callback:
A callable receiving the following arguments:
A callable or a coroutine function receiving the following
arguments:
**connection**: a Connection the callback is registered with;
**message**: the `exceptions.PostgresLogMessage` message.

.. versionadded:: 0.12.0

.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
if self.is_closed():
raise exceptions.InterfaceError('connection is closed')
self._log_listeners.add(callback)
self._log_listeners.add(_Callback.from_callable(callback))

def remove_log_listener(self, callback):
"""Remove a listening callback for log messages.

.. versionadded:: 0.12.0
"""
self._log_listeners.discard(callback)
self._log_listeners.discard(_Callback.from_callable(callback))

def add_termination_listener(self, callback):
"""Add a listener that will be called when the connection is closed.

:param callable callback:
A callable receiving one argument:
A callable or a coroutine function receiving one argument:
**connection**: a Connection the callback is registered with.

.. versionadded:: 0.21.0

.. versionchanged:: 0.24.0
The ``callback`` argument may be a coroutine function.
"""
self._termination_listeners.add(callback)
self._termination_listeners.add(_Callback.from_callable(callback))

def remove_termination_listener(self, callback):
"""Remove a listening callback for connection termination.

:param callable callback:
The callable that was passed to
The callable or coroutine function that was passed to
:meth:`Connection.add_termination_listener`.

.. versionadded:: 0.21.0
"""
self._termination_listeners.discard(callback)
self._termination_listeners.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
Expand Down Expand Up @@ -1430,35 +1444,21 @@ def _process_log_message(self, fields, last_query):

con_ref = self._unwrap()
for cb in self._log_listeners:
self._loop.call_soon(
self._call_log_listener, cb, con_ref, message)

def _call_log_listener(self, cb, con_ref, message):
try:
cb(con_ref, message)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg log message '
'listener callback {!r}'.format(cb),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, message))
else:
self._loop.call_soon(cb.cb, con_ref, message)

def _call_termination_listeners(self):
if not self._termination_listeners:
return

con_ref = self._unwrap()
for cb in self._termination_listeners:
try:
cb(con_ref)
except Exception as ex:
self._loop.call_exception_handler({
'message': (
'Unhandled exception in asyncpg connection '
'termination listener callback {!r}'.format(cb)
),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref))
else:
self._loop.call_soon(cb.cb, con_ref)

self._termination_listeners.clear()

Expand All @@ -1468,18 +1468,10 @@ def _process_notification(self, pid, channel, payload):

con_ref = self._unwrap()
for cb in self._listeners[channel]:
self._loop.call_soon(
self._call_listener, cb, con_ref, pid, channel, payload)

def _call_listener(self, cb, con_ref, pid, channel, payload):
try:
cb(con_ref, pid, channel, payload)
except Exception as ex:
self._loop.call_exception_handler({
'message': 'Unhandled exception in asyncpg notification '
'listener callback {!r}'.format(cb),
'exception': ex
})
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, pid, channel, payload))
else:
self._loop.call_soon(cb.cb, con_ref, pid, channel, payload)

def _unwrap(self):
if self._proxy is None:
Expand Down Expand Up @@ -2154,6 +2146,26 @@ def _maybe_cleanup(self):
self._on_remove(old_entry._statement)


class _Callback(typing.NamedTuple):

cb: typing.Callable[..., None]
is_async: bool

@classmethod
def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback':
if inspect.iscoroutinefunction(cb):
is_async = True
elif callable(cb):
is_async = False
else:
raise exceptions.InterfaceError(
'expected a callable or an `async def` function,'
'got {!r}'.format(cb)
)

return cls(cb, is_async)


class _Atomic:
__slots__ = ('_acquired',)

Expand Down
39 changes: 39 additions & 0 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@ async def test_listen_01(self):

q1 = asyncio.Queue()
q2 = asyncio.Queue()
q3 = asyncio.Queue()

def listener1(*args):
q1.put_nowait(args)

def listener2(*args):
q2.put_nowait(args)

async def async_listener3(*args):
q3.put_nowait(args)

await con.add_listener('test', listener1)
await con.add_listener('test', listener2)
await con.add_listener('test', async_listener3)

await con.execute("NOTIFY test, 'aaaa'")

Expand All @@ -41,8 +46,12 @@ def listener2(*args):
self.assertEqual(
await q2.get(),
(con, con.get_server_pid(), 'test', 'aaaa'))
self.assertEqual(
await q3.get(),
(con, con.get_server_pid(), 'test', 'aaaa'))

await con.remove_listener('test', listener2)
await con.remove_listener('test', async_listener3)

await con.execute("NOTIFY test, 'aaaa'")

Expand Down Expand Up @@ -117,13 +126,20 @@ class TestLogListeners(tb.ConnectedTestCase):
})
async def test_log_listener_01(self):
q1 = asyncio.Queue()
q2 = asyncio.Queue()

def notice_callb(con, message):
# Message fields depend on PG version, hide some values.
dct = message.as_dict()
del dct['server_source_line']
q1.put_nowait((con, type(message), dct))

async def async_notice_callb(con, message):
# Message fields depend on PG version, hide some values.
dct = message.as_dict()
del dct['server_source_line']
q2.put_nowait((con, type(message), dct))

async def raise_notice():
await self.con.execute(
"""DO $$
Expand All @@ -140,6 +156,7 @@ async def raise_warning():

con = self.con
con.add_log_listener(notice_callb)
con.add_log_listener(async_notice_callb)

expected_msg = {
'context': 'PL/pgSQL function inline_code_block line 2 at RAISE',
Expand Down Expand Up @@ -182,7 +199,21 @@ async def raise_warning():
msg,
(con, exceptions.PostgresWarning, expected_msg_warn))

msg = await q2.get()
msg[2].pop('server_source_filename', None)
self.assertEqual(
msg,
(con, exceptions.PostgresLogMessage, expected_msg_notice))

msg = await q2.get()
msg[2].pop('server_source_filename', None)
self.assertEqual(
msg,
(con, exceptions.PostgresWarning, expected_msg_warn))

con.remove_log_listener(notice_callb)
con.remove_log_listener(async_notice_callb)

await raise_notice()
self.assertTrue(q1.empty())

Expand Down Expand Up @@ -291,19 +322,26 @@ class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):
async def test_connection_termination_callback_called_on_remote(self):

called = False
async_called = False

def close_cb(con):
nonlocal called
called = True

async def async_close_cb(con):
nonlocal async_called
async_called = True

con = await self.connect()
con.add_termination_listener(close_cb)
con.add_termination_listener(async_close_cb)
self.proxy.close_all_connections()
try:
await con.fetchval('SELECT 1')
except Exception:
pass
self.assertTrue(called)
self.assertTrue(async_called)

async def test_connection_termination_callback_called_on_local(self):

Expand All @@ -316,4 +354,5 @@ def close_cb(con):
con = await self.connect()
con.add_termination_listener(close_cb)
await con.close()
await asyncio.sleep(0)
self.assertTrue(called)