Skip to content

Commit 8185e17

Browse files
committed
Add aclose() to asyncio.client.PubSub
close() and reset() retained as aliases
1 parent 386ccf4 commit 8185e17

File tree

3 files changed

+85
-32
lines changed

3 files changed

+85
-32
lines changed

redis/asyncio/client.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
List,
1515
Mapping,
1616
MutableMapping,
17-
NoReturn,
1817
Optional,
1918
Set,
2019
Tuple,
@@ -729,13 +728,18 @@ async def __aenter__(self):
729728
return self
730729

731730
async def __aexit__(self, exc_type, exc_value, traceback):
732-
await self.reset()
731+
await self.aclose()
733732

734733
def __del__(self):
735734
if self.connection:
736735
self.connection.clear_connect_callbacks()
737736

738-
async def reset(self):
737+
async def aclose(self):
738+
# In case a connection property does not yet exist
739+
# (due to a crash earlier in the Redis() constructor), return
740+
# immediately as there is nothing to clean-up.
741+
if not hasattr(self, "connection"):
742+
return
739743
async with self._lock:
740744
if self.connection:
741745
await self.connection.disconnect()
@@ -747,13 +751,13 @@ async def reset(self):
747751
self.patterns = {}
748752
self.pending_unsubscribe_patterns = set()
749753

750-
def close(self) -> Awaitable[NoReturn]:
751-
# In case a connection property does not yet exist
752-
# (due to a crash earlier in the Redis() constructor), return
753-
# immediately as there is nothing to clean-up.
754-
if not hasattr(self, "connection"):
755-
return
756-
return self.reset()
754+
async def close(self) -> None:
755+
"""Alias for aclose(), for backwards compatibility"""
756+
await self.aclose()
757+
758+
async def reset(self) -> None:
759+
"""alias for aclose(), for backwards compatibility"""
760+
await self.aclose()
757761

758762
async def on_connect(self, connection: Connection):
759763
"""Re-subscribe to any channels and patterns previously subscribed to"""
@@ -1229,14 +1233,14 @@ async def _disconnect_reset_raise(self, conn, error):
12291233
# valid since this connection has died. raise a WatchError, which
12301234
# indicates the user should retry this transaction.
12311235
if self.watching:
1232-
await self.reset()
1236+
await self.aclose()
12331237
raise WatchError(
12341238
"A ConnectionError occurred on while watching one or more keys"
12351239
)
12361240
# if retry_on_timeout is not set, or the error is not
12371241
# a TimeoutError, raise it
12381242
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
1239-
await self.reset()
1243+
await self.aclose()
12401244
raise
12411245

12421246
async def immediate_execute_command(self, *args, **options):

tests/test_asyncio/compat.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@
66
except AttributeError:
77
import mock
88

9+
try:
10+
from contextlib import aclosing
11+
except ImportError:
12+
import contextlib
13+
14+
@contextlib.asynccontextmanager
15+
async def aclosing(thing):
16+
try:
17+
yield thing
18+
finally:
19+
await thing.aclose()
20+
921

1022
def create_task(coroutine):
1123
return asyncio.create_task(coroutine)

tests/test_asyncio/test_pubsub.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from redis.utils import HIREDIS_AVAILABLE
2121
from tests.conftest import get_protocol_version, skip_if_server_version_lt
2222

23-
from .compat import create_task, mock
23+
from .compat import aclosing, create_task, mock
2424

2525

2626
def with_timeout(t):
@@ -84,9 +84,8 @@ def make_subscribe_test_data(pubsub, type):
8484

8585
@pytest_asyncio.fixture()
8686
async def pubsub(r: redis.Redis):
87-
p = r.pubsub()
88-
yield p
89-
await p.close()
87+
async with r.pubsub() as p:
88+
yield p
9089

9190

9291
@pytest.mark.onlynoncluster
@@ -217,6 +216,44 @@ async def test_subscribe_property_with_patterns(self, pubsub):
217216
kwargs = make_subscribe_test_data(pubsub, "pattern")
218217
await self._test_subscribed_property(**kwargs)
219218

219+
async def test_aclosing(self, r: redis.Redis):
220+
p = r.pubsub()
221+
async with aclosing(p):
222+
assert p.subscribed is False
223+
await p.subscribe("foo")
224+
assert p.subscribed is True
225+
assert p.subscribed is False
226+
227+
async def test_context_manager(self, r: redis.Redis):
228+
p = r.pubsub()
229+
async with p:
230+
assert p.subscribed is False
231+
await p.subscribe("foo")
232+
assert p.subscribed is True
233+
assert p.subscribed is False
234+
235+
async def test_close_is_aclose(self, r: redis.Redis):
236+
"""
237+
Test backwards compatible close method
238+
"""
239+
p = r.pubsub()
240+
assert p.subscribed is False
241+
await p.subscribe("foo")
242+
assert p.subscribed is True
243+
await p.close()
244+
assert p.subscribed is False
245+
246+
async def test_reset_is_aclose(self, r: redis.Redis):
247+
"""
248+
Test backwards compatible reset method
249+
"""
250+
p = r.pubsub()
251+
assert p.subscribed is False
252+
await p.subscribe("foo")
253+
assert p.subscribed is True
254+
await p.reset()
255+
assert p.subscribed is False
256+
220257
async def test_ignore_all_subscribe_messages(self, r: redis.Redis):
221258
p = r.pubsub(ignore_subscribe_messages=True)
222259

@@ -233,7 +270,7 @@ async def test_ignore_all_subscribe_messages(self, r: redis.Redis):
233270
assert p.subscribed is True
234271
assert await wait_for_message(p) is None
235272
assert p.subscribed is False
236-
await p.close()
273+
await p.aclose()
237274

238275
async def test_ignore_individual_subscribe_messages(self, pubsub):
239276
p = pubsub
@@ -350,7 +387,7 @@ async def test_channel_message_handler(self, r: redis.Redis):
350387
assert await r.publish("foo", "test message") == 1
351388
assert await wait_for_message(p) is None
352389
assert self.message == make_message("message", "foo", "test message")
353-
await p.close()
390+
await p.aclose()
354391

355392
async def test_channel_async_message_handler(self, r):
356393
p = r.pubsub(ignore_subscribe_messages=True)
@@ -359,7 +396,7 @@ async def test_channel_async_message_handler(self, r):
359396
assert await r.publish("foo", "test message") == 1
360397
assert await wait_for_message(p) is None
361398
assert self.async_message == make_message("message", "foo", "test message")
362-
await p.close()
399+
await p.aclose()
363400

364401
async def test_channel_sync_async_message_handler(self, r):
365402
p = r.pubsub(ignore_subscribe_messages=True)
@@ -371,7 +408,7 @@ async def test_channel_sync_async_message_handler(self, r):
371408
assert await wait_for_message(p) is None
372409
assert self.message == make_message("message", "foo", "test message")
373410
assert self.async_message == make_message("message", "bar", "test message 2")
374-
await p.close()
411+
await p.aclose()
375412

376413
@pytest.mark.onlynoncluster
377414
async def test_pattern_message_handler(self, r: redis.Redis):
@@ -383,7 +420,7 @@ async def test_pattern_message_handler(self, r: redis.Redis):
383420
assert self.message == make_message(
384421
"pmessage", "foo", "test message", pattern="f*"
385422
)
386-
await p.close()
423+
await p.aclose()
387424

388425
async def test_unicode_channel_message_handler(self, r: redis.Redis):
389426
p = r.pubsub(ignore_subscribe_messages=True)
@@ -394,7 +431,7 @@ async def test_unicode_channel_message_handler(self, r: redis.Redis):
394431
assert await r.publish(channel, "test message") == 1
395432
assert await wait_for_message(p) is None
396433
assert self.message == make_message("message", channel, "test message")
397-
await p.close()
434+
await p.aclose()
398435

399436
@pytest.mark.onlynoncluster
400437
# see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html
@@ -410,7 +447,7 @@ async def test_unicode_pattern_message_handler(self, r: redis.Redis):
410447
assert self.message == make_message(
411448
"pmessage", channel, "test message", pattern=pattern
412449
)
413-
await p.close()
450+
await p.aclose()
414451

415452
async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub):
416453
p = pubsub
@@ -524,7 +561,7 @@ async def test_channel_message_handler(self, r: redis.Redis):
524561
await r.publish(self.channel, new_data)
525562
assert await wait_for_message(p) is None
526563
assert self.message == self.make_message("message", self.channel, new_data)
527-
await p.close()
564+
await p.aclose()
528565

529566
async def test_pattern_message_handler(self, r: redis.Redis):
530567
p = r.pubsub(ignore_subscribe_messages=True)
@@ -546,7 +583,7 @@ async def test_pattern_message_handler(self, r: redis.Redis):
546583
assert self.message == self.make_message(
547584
"pmessage", self.channel, new_data, pattern=self.pattern
548585
)
549-
await p.close()
586+
await p.aclose()
550587

551588
async def test_context_manager(self, r: redis.Redis):
552589
async with r.pubsub() as pubsub:
@@ -556,7 +593,7 @@ async def test_context_manager(self, r: redis.Redis):
556593
assert pubsub.connection is None
557594
assert pubsub.channels == {}
558595
assert pubsub.patterns == {}
559-
await pubsub.close()
596+
await pubsub.aclose()
560597

561598

562599
@pytest.mark.onlynoncluster
@@ -597,9 +634,9 @@ async def test_pubsub_numsub(self, r: redis.Redis):
597634

598635
channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)]
599636
assert await r.pubsub_numsub("foo", "bar", "baz") == channels
600-
await p1.close()
601-
await p2.close()
602-
await p3.close()
637+
await p1.aclose()
638+
await p2.aclose()
639+
await p3.aclose()
603640

604641
@skip_if_server_version_lt("2.8.0")
605642
async def test_pubsub_numpat(self, r: redis.Redis):
@@ -608,7 +645,7 @@ async def test_pubsub_numpat(self, r: redis.Redis):
608645
for i in range(3):
609646
assert (await wait_for_message(p))["type"] == "psubscribe"
610647
assert await r.pubsub_numpat() == 3
611-
await p.close()
648+
await p.aclose()
612649

613650

614651
@pytest.mark.onlynoncluster
@@ -621,7 +658,7 @@ async def test_send_pubsub_ping(self, r: redis.Redis):
621658
assert await wait_for_message(p) == make_message(
622659
type="pong", channel=None, data="", pattern=None
623660
)
624-
await p.close()
661+
await p.aclose()
625662

626663
@skip_if_server_version_lt("3.0.0")
627664
async def test_send_pubsub_ping_message(self, r: redis.Redis):
@@ -631,7 +668,7 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis):
631668
assert await wait_for_message(p) == make_message(
632669
type="pong", channel=None, data="hello world", pattern=None
633670
)
634-
await p.close()
671+
await p.aclose()
635672

636673

637674
@pytest.mark.onlynoncluster

0 commit comments

Comments
 (0)