1
- import asyncio
1
+ import functools
2
2
import random
3
3
import sys
4
4
from typing import Union
5
5
from urllib .parse import urlparse
6
6
7
- if sys .version_info [0 :2 ] == (3 , 6 ):
8
- import pytest as pytest_asyncio
9
- else :
10
- import pytest_asyncio
11
-
12
7
import pytest
13
8
from packaging .version import Version
14
9
26
21
27
22
from .compat import mock
28
23
24
+ if sys .version_info [0 :2 ] == (3 , 6 ):
25
+ import pytest as pytest_asyncio
26
+
27
+ pytestmark = pytest .mark .asyncio
28
+ else :
29
+ import pytest_asyncio
30
+
29
31
30
32
async def _get_info (redis_url ):
31
33
client = redis .Redis .from_url (redis_url )
@@ -69,11 +71,13 @@ async def _get_info(redis_url):
69
71
"pool-hiredis" ,
70
72
],
71
73
)
72
- def create_redis (request , event_loop : asyncio . BaseEventLoop ):
74
+ async def create_redis (request ):
73
75
"""Wrapper around redis.create_redis."""
74
76
single_connection , parser_cls = request .param
75
77
76
- async def f (
78
+ teardown_clients = []
79
+
80
+ async def client_factory (
77
81
url : str = request .config .getoption ("--redis-url" ),
78
82
cls = redis .Redis ,
79
83
flushdb = True ,
@@ -95,56 +99,50 @@ async def f(
95
99
client = client .client ()
96
100
await client .initialize ()
97
101
98
- def teardown ():
99
- async def ateardown ():
100
- if not cluster_mode :
101
- if "username" in kwargs :
102
- return
103
- if flushdb :
104
- try :
105
- await client .flushdb ()
106
- except redis .ConnectionError :
107
- # handle cases where a test disconnected a client
108
- # just manually retry the flushdb
109
- await client .flushdb ()
110
- await client .close ()
111
- await client .connection_pool .disconnect ()
112
- else :
113
- if flushdb :
114
- try :
115
- await client .flushdb (target_nodes = "primaries" )
116
- except redis .ConnectionError :
117
- # handle cases where a test disconnected a client
118
- # just manually retry the flushdb
119
- await client .flushdb (target_nodes = "primaries" )
120
- await client .close ()
121
-
122
- if event_loop .is_running ():
123
- event_loop .create_task (ateardown ())
102
+ async def teardown ():
103
+ if not cluster_mode :
104
+ if flushdb and "username" not in kwargs :
105
+ try :
106
+ await client .flushdb ()
107
+ except redis .ConnectionError :
108
+ # handle cases where a test disconnected a client
109
+ # just manually retry the flushdb
110
+ await client .flushdb ()
111
+ await client .close ()
112
+ await client .connection_pool .disconnect ()
124
113
else :
125
- event_loop .run_until_complete (ateardown ())
126
-
127
- request .addfinalizer (teardown )
128
-
114
+ if flushdb :
115
+ try :
116
+ await client .flushdb (target_nodes = "primaries" )
117
+ except redis .ConnectionError :
118
+ # handle cases where a test disconnected a client
119
+ # just manually retry the flushdb
120
+ await client .flushdb (target_nodes = "primaries" )
121
+ await client .close ()
122
+
123
+ teardown_clients .append (teardown )
129
124
return client
130
125
131
- return f
126
+ yield client_factory
127
+
128
+ for teardown in teardown_clients :
129
+ await teardown ()
132
130
133
131
134
132
@pytest_asyncio .fixture ()
135
- async def r (request , create_redis ):
136
- yield await create_redis ()
133
+ async def r (create_redis ):
134
+ return await create_redis ()
137
135
138
136
139
137
@pytest_asyncio .fixture ()
140
138
async def r2 (create_redis ):
141
139
"""A second client for tests that need multiple"""
142
- yield await create_redis ()
140
+ return await create_redis ()
143
141
144
142
145
143
@pytest_asyncio .fixture ()
146
144
async def modclient (request , create_redis ):
147
- yield await create_redis (
145
+ return await create_redis (
148
146
url = request .config .getoption ("--redismod-url" ), decode_responses = True
149
147
)
150
148
@@ -222,7 +220,7 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
222
220
def master_host (request ):
223
221
url = request .config .getoption ("--redis-url" )
224
222
parts = urlparse (url )
225
- yield parts .hostname
223
+ return parts .hostname
226
224
227
225
228
226
async def wait_for_command (
@@ -246,3 +244,41 @@ async def wait_for_command(
246
244
return monitor_response
247
245
if key in monitor_response ["command" ]:
248
246
return None
247
+
248
+
249
+ # python 3.6 doesn't have the asynccontextmanager decorator. Provide it here.
250
+ class AsyncContextManager :
251
+ def __init__ (self , async_generator ):
252
+ self .gen = async_generator
253
+
254
+ async def __aenter__ (self ):
255
+ try :
256
+ return await self .gen .__anext__ ()
257
+ except StopAsyncIteration as err :
258
+ raise RuntimeError ("Pickles" ) from err
259
+
260
+ async def __aexit__ (self , exc_type , exc_inst , tb ):
261
+ if exc_type :
262
+ await self .gen .athrow (exc_type , exc_inst , tb )
263
+ return True
264
+ try :
265
+ await self .gen .__anext__ ()
266
+ except StopAsyncIteration :
267
+ return
268
+ raise RuntimeError ("More pickles" )
269
+
270
+
271
+ if sys .version_info [0 :2 ] == (3 , 6 ):
272
+
273
+ def asynccontextmanager (func ):
274
+ @functools .wraps (func )
275
+ def wrapper (* args , ** kwargs ):
276
+ return AsyncContextManager (func (* args , ** kwargs ))
277
+
278
+ return wrapper
279
+
280
+ else :
281
+ from contextlib import asynccontextmanager as _asynccontextmanager
282
+
283
+ def asynccontextmanager (func ):
284
+ return _asynccontextmanager (func )
0 commit comments