diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 4a656124..85c420c6 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -11,10 +11,12 @@ import collections.abc import functools import itertools +import inspect import os import sys import time import traceback +import typing import warnings import weakref @@ -133,17 +135,21 @@ async def add_listener(self, channel, callback): :param str channel: Channel to listen on. :param callable callback: - A callable receiving the following arguments: + A callable or a coroutine function receiving the following + arguments: **connection**: a Connection the callback is registered with; **pid**: PID of the Postgres server that sent the notification; **channel**: name of the channel the notification was sent to; **payload**: the payload. + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. """ self._check_open() if channel not in self._listeners: await self.fetch('LISTEN {}'.format(utils._quote_ident(channel))) self._listeners[channel] = set() - self._listeners[channel].add(callback) + self._listeners[channel].add(_Callback.from_callable(callback)) async def remove_listener(self, channel, callback): """Remove a listening callback on the specified channel.""" @@ -151,9 +157,10 @@ async def remove_listener(self, channel, callback): return if channel not in self._listeners: return - if callback not in self._listeners[channel]: + cb = _Callback.from_callable(callback) + if cb not in self._listeners[channel]: return - self._listeners[channel].remove(callback) + self._listeners[channel].remove(cb) if not self._listeners[channel]: del self._listeners[channel] await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel))) @@ -166,44 +173,51 @@ def add_log_listener(self, callback): DEBUG, INFO, or LOG. :param callable callback: - A callable receiving the following arguments: + A callable or a coroutine function receiving the following + arguments: **connection**: a Connection the callback is registered with; **message**: the `exceptions.PostgresLogMessage` message. .. versionadded:: 0.12.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. """ if self.is_closed(): raise exceptions.InterfaceError('connection is closed') - self._log_listeners.add(callback) + self._log_listeners.add(_Callback.from_callable(callback)) def remove_log_listener(self, callback): """Remove a listening callback for log messages. .. versionadded:: 0.12.0 """ - self._log_listeners.discard(callback) + self._log_listeners.discard(_Callback.from_callable(callback)) def add_termination_listener(self, callback): """Add a listener that will be called when the connection is closed. :param callable callback: - A callable receiving one argument: + A callable or a coroutine function receiving one argument: **connection**: a Connection the callback is registered with. .. versionadded:: 0.21.0 + + .. versionchanged:: 0.24.0 + The ``callback`` argument may be a coroutine function. """ - self._termination_listeners.add(callback) + self._termination_listeners.add(_Callback.from_callable(callback)) def remove_termination_listener(self, callback): """Remove a listening callback for connection termination. :param callable callback: - The callable that was passed to + The callable or coroutine function that was passed to :meth:`Connection.add_termination_listener`. .. versionadded:: 0.21.0 """ - self._termination_listeners.discard(callback) + self._termination_listeners.discard(_Callback.from_callable(callback)) def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" @@ -1430,18 +1444,10 @@ def _process_log_message(self, fields, last_query): con_ref = self._unwrap() for cb in self._log_listeners: - self._loop.call_soon( - self._call_log_listener, cb, con_ref, message) - - def _call_log_listener(self, cb, con_ref, message): - try: - cb(con_ref, message) - except Exception as ex: - self._loop.call_exception_handler({ - 'message': 'Unhandled exception in asyncpg log message ' - 'listener callback {!r}'.format(cb), - 'exception': ex - }) + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, message)) + else: + self._loop.call_soon(cb.cb, con_ref, message) def _call_termination_listeners(self): if not self._termination_listeners: @@ -1449,16 +1455,10 @@ def _call_termination_listeners(self): con_ref = self._unwrap() for cb in self._termination_listeners: - try: - cb(con_ref) - except Exception as ex: - self._loop.call_exception_handler({ - 'message': ( - 'Unhandled exception in asyncpg connection ' - 'termination listener callback {!r}'.format(cb) - ), - 'exception': ex - }) + if cb.is_async: + self._loop.create_task(cb.cb(con_ref)) + else: + self._loop.call_soon(cb.cb, con_ref) self._termination_listeners.clear() @@ -1468,18 +1468,10 @@ def _process_notification(self, pid, channel, payload): con_ref = self._unwrap() for cb in self._listeners[channel]: - self._loop.call_soon( - self._call_listener, cb, con_ref, pid, channel, payload) - - def _call_listener(self, cb, con_ref, pid, channel, payload): - try: - cb(con_ref, pid, channel, payload) - except Exception as ex: - self._loop.call_exception_handler({ - 'message': 'Unhandled exception in asyncpg notification ' - 'listener callback {!r}'.format(cb), - 'exception': ex - }) + if cb.is_async: + self._loop.create_task(cb.cb(con_ref, pid, channel, payload)) + else: + self._loop.call_soon(cb.cb, con_ref, pid, channel, payload) def _unwrap(self): if self._proxy is None: @@ -2154,6 +2146,26 @@ def _maybe_cleanup(self): self._on_remove(old_entry._statement) +class _Callback(typing.NamedTuple): + + cb: typing.Callable[..., None] + is_async: bool + + @classmethod + def from_callable(cls, cb: typing.Callable[..., None]) -> '_Callback': + if inspect.iscoroutinefunction(cb): + is_async = True + elif callable(cb): + is_async = False + else: + raise exceptions.InterfaceError( + 'expected a callable or an `async def` function,' + 'got {!r}'.format(cb) + ) + + return cls(cb, is_async) + + class _Atomic: __slots__ = ('_acquired',) diff --git a/tests/test_listeners.py b/tests/test_listeners.py index 1af9627c..7fdf0312 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -23,6 +23,7 @@ async def test_listen_01(self): q1 = asyncio.Queue() q2 = asyncio.Queue() + q3 = asyncio.Queue() def listener1(*args): q1.put_nowait(args) @@ -30,8 +31,12 @@ def listener1(*args): def listener2(*args): q2.put_nowait(args) + async def async_listener3(*args): + q3.put_nowait(args) + await con.add_listener('test', listener1) await con.add_listener('test', listener2) + await con.add_listener('test', async_listener3) await con.execute("NOTIFY test, 'aaaa'") @@ -41,8 +46,12 @@ def listener2(*args): self.assertEqual( await q2.get(), (con, con.get_server_pid(), 'test', 'aaaa')) + self.assertEqual( + await q3.get(), + (con, con.get_server_pid(), 'test', 'aaaa')) await con.remove_listener('test', listener2) + await con.remove_listener('test', async_listener3) await con.execute("NOTIFY test, 'aaaa'") @@ -117,6 +126,7 @@ class TestLogListeners(tb.ConnectedTestCase): }) async def test_log_listener_01(self): q1 = asyncio.Queue() + q2 = asyncio.Queue() def notice_callb(con, message): # Message fields depend on PG version, hide some values. @@ -124,6 +134,12 @@ def notice_callb(con, message): del dct['server_source_line'] q1.put_nowait((con, type(message), dct)) + async def async_notice_callb(con, message): + # Message fields depend on PG version, hide some values. + dct = message.as_dict() + del dct['server_source_line'] + q2.put_nowait((con, type(message), dct)) + async def raise_notice(): await self.con.execute( """DO $$ @@ -140,6 +156,7 @@ async def raise_warning(): con = self.con con.add_log_listener(notice_callb) + con.add_log_listener(async_notice_callb) expected_msg = { 'context': 'PL/pgSQL function inline_code_block line 2 at RAISE', @@ -182,7 +199,21 @@ async def raise_warning(): msg, (con, exceptions.PostgresWarning, expected_msg_warn)) + msg = await q2.get() + msg[2].pop('server_source_filename', None) + self.assertEqual( + msg, + (con, exceptions.PostgresLogMessage, expected_msg_notice)) + + msg = await q2.get() + msg[2].pop('server_source_filename', None) + self.assertEqual( + msg, + (con, exceptions.PostgresWarning, expected_msg_warn)) + con.remove_log_listener(notice_callb) + con.remove_log_listener(async_notice_callb) + await raise_notice() self.assertTrue(q1.empty()) @@ -291,19 +322,26 @@ class TestConnectionTerminationListener(tb.ProxiedClusterTestCase): async def test_connection_termination_callback_called_on_remote(self): called = False + async_called = False def close_cb(con): nonlocal called called = True + async def async_close_cb(con): + nonlocal async_called + async_called = True + con = await self.connect() con.add_termination_listener(close_cb) + con.add_termination_listener(async_close_cb) self.proxy.close_all_connections() try: await con.fetchval('SELECT 1') except Exception: pass self.assertTrue(called) + self.assertTrue(async_called) async def test_connection_termination_callback_called_on_local(self): @@ -316,4 +354,5 @@ def close_cb(con): con = await self.connect() con.add_termination_listener(close_cb) await con.close() + await asyncio.sleep(0) self.assertTrue(called)