1
1
import asyncio
2
+ import contextlib
2
3
import os
3
4
import re
4
5
import sys
@@ -114,7 +115,8 @@ async def can_read(self, timeout: float = 0):
114
115
115
116
116
117
class TestConnectionPool :
117
- def get_pool (
118
+ @contextlib .asynccontextmanager
119
+ async def get_pool (
118
120
self ,
119
121
connection_kwargs = None ,
120
122
max_connections = None ,
@@ -126,79 +128,88 @@ def get_pool(
126
128
max_connections = max_connections ,
127
129
** connection_kwargs ,
128
130
)
129
- return pool
131
+ try :
132
+ yield pool
133
+ finally :
134
+ await pool .disconnect (inuse_connections = True )
130
135
131
136
async def test_connection_creation (self ):
132
137
connection_kwargs = {"foo" : "bar" , "biz" : "baz" }
133
- pool = self .get_pool (
138
+ async with self .get_pool (
134
139
connection_kwargs = connection_kwargs , connection_class = DummyConnection
135
- )
136
- connection = await pool .get_connection ("_" )
137
- assert isinstance (connection , DummyConnection )
138
- assert connection .kwargs == connection_kwargs
140
+ ) as pool :
141
+ connection = await pool .get_connection ("_" )
142
+ assert isinstance (connection , DummyConnection )
143
+ assert connection .kwargs == connection_kwargs
139
144
140
145
async def test_multiple_connections (self , master_host ):
141
146
connection_kwargs = {"host" : master_host }
142
- pool = self .get_pool (connection_kwargs = connection_kwargs )
143
- c1 = await pool .get_connection ("_" )
144
- c2 = await pool .get_connection ("_" )
145
- assert c1 != c2
147
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
148
+ c1 = await pool .get_connection ("_" )
149
+ c2 = await pool .get_connection ("_" )
150
+ assert c1 != c2
146
151
147
152
async def test_max_connections (self , master_host ):
148
153
connection_kwargs = {"host" : master_host }
149
- pool = self .get_pool (max_connections = 2 , connection_kwargs = connection_kwargs )
150
- await pool . get_connection ( "_" )
151
- await pool . get_connection ( "_" )
152
- with pytest . raises ( redis . ConnectionError ):
154
+ async with self .get_pool (
155
+ max_connections = 2 , connection_kwargs = connection_kwargs
156
+ ) as pool :
157
+ await pool . get_connection ( "_" )
153
158
await pool .get_connection ("_" )
159
+ with pytest .raises (redis .ConnectionError ):
160
+ await pool .get_connection ("_" )
154
161
155
162
async def test_reuse_previously_released_connection (self , master_host ):
156
163
connection_kwargs = {"host" : master_host }
157
- pool = self .get_pool (connection_kwargs = connection_kwargs )
158
- c1 = await pool .get_connection ("_" )
159
- await pool .release (c1 )
160
- c2 = await pool .get_connection ("_" )
161
- assert c1 == c2
164
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
165
+ c1 = await pool .get_connection ("_" )
166
+ await pool .release (c1 )
167
+ c2 = await pool .get_connection ("_" )
168
+ assert c1 == c2
162
169
163
- def test_repr_contains_db_info_tcp (self ):
170
+ async def test_repr_contains_db_info_tcp (self ):
164
171
connection_kwargs = {
165
172
"host" : "localhost" ,
166
173
"port" : 6379 ,
167
174
"db" : 1 ,
168
175
"client_name" : "test-client" ,
169
176
}
170
- pool = self .get_pool (
177
+ async with self .get_pool (
171
178
connection_kwargs = connection_kwargs , connection_class = redis .Connection
172
- )
173
- expected = (
174
- "ConnectionPool<Connection<"
175
- "host=localhost,port=6379,db=1,client_name=test-client>>"
176
- )
177
- assert repr (pool ) == expected
179
+ ) as pool :
180
+ expected = (
181
+ "ConnectionPool<Connection<"
182
+ "host=localhost,port=6379,db=1,client_name=test-client>>"
183
+ )
184
+ assert repr (pool ) == expected
178
185
179
- def test_repr_contains_db_info_unix (self ):
186
+ async def test_repr_contains_db_info_unix (self ):
180
187
connection_kwargs = {"path" : "/abc" , "db" : 1 , "client_name" : "test-client" }
181
- pool = self .get_pool (
188
+ async with self .get_pool (
182
189
connection_kwargs = connection_kwargs ,
183
190
connection_class = redis .UnixDomainSocketConnection ,
184
- )
185
- expected = (
186
- "ConnectionPool<UnixDomainSocketConnection<"
187
- "path=/abc,db=1,client_name=test-client>>"
188
- )
189
- assert repr (pool ) == expected
191
+ ) as pool :
192
+ expected = (
193
+ "ConnectionPool<UnixDomainSocketConnection<"
194
+ "path=/abc,db=1,client_name=test-client>>"
195
+ )
196
+ assert repr (pool ) == expected
190
197
191
198
192
199
class TestBlockingConnectionPool :
193
- def get_pool (self , connection_kwargs = None , max_connections = 10 , timeout = 20 ):
200
+ @contextlib .asynccontextmanager
201
+ async def get_pool (self , connection_kwargs = None , max_connections = 10 , timeout = 20 ):
194
202
connection_kwargs = connection_kwargs or {}
195
203
pool = redis .BlockingConnectionPool (
196
204
connection_class = DummyConnection ,
197
205
max_connections = max_connections ,
198
206
timeout = timeout ,
199
207
** connection_kwargs ,
200
208
)
201
- return pool
209
+ try :
210
+ yield pool
211
+ finally :
212
+ await pool .disconnect (inuse_connections = True )
202
213
203
214
async def test_connection_creation (self , master_host ):
204
215
connection_kwargs = {
@@ -207,10 +218,10 @@ async def test_connection_creation(self, master_host):
207
218
"host" : master_host [0 ],
208
219
"port" : master_host [1 ],
209
220
}
210
- pool = self .get_pool (connection_kwargs = connection_kwargs )
211
- connection = await pool .get_connection ("_" )
212
- assert isinstance (connection , DummyConnection )
213
- assert connection .kwargs == connection_kwargs
221
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
222
+ connection = await pool .get_connection ("_" )
223
+ assert isinstance (connection , DummyConnection )
224
+ assert connection .kwargs == connection_kwargs
214
225
215
226
async def test_disconnect (self , master_host ):
216
227
"""A regression test for #1047"""
@@ -220,57 +231,58 @@ async def test_disconnect(self, master_host):
220
231
"host" : master_host [0 ],
221
232
"port" : master_host [1 ],
222
233
}
223
- pool = self .get_pool (connection_kwargs = connection_kwargs )
224
- await pool .get_connection ("_" )
225
- await pool .disconnect ()
234
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
235
+ await pool .get_connection ("_" )
236
+ await pool .disconnect ()
226
237
227
238
async def test_multiple_connections (self , master_host ):
228
239
connection_kwargs = {"host" : master_host [0 ], "port" : master_host [1 ]}
229
- pool = self .get_pool (connection_kwargs = connection_kwargs )
230
- c1 = await pool .get_connection ("_" )
231
- c2 = await pool .get_connection ("_" )
232
- assert c1 != c2
240
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
241
+ c1 = await pool .get_connection ("_" )
242
+ c2 = await pool .get_connection ("_" )
243
+ assert c1 != c2
233
244
234
245
async def test_connection_pool_blocks_until_timeout (self , master_host ):
235
246
"""When out of connections, block for timeout seconds, then raise"""
236
247
connection_kwargs = {"host" : master_host }
237
- pool = self .get_pool (
248
+ async with self .get_pool (
238
249
max_connections = 1 , timeout = 0.1 , connection_kwargs = connection_kwargs
239
- )
240
- await pool .get_connection ("_" )
250
+ ) as pool :
251
+ c1 = await pool .get_connection ("_" )
241
252
242
- start = asyncio .get_event_loop ().time ()
243
- with pytest .raises (redis .ConnectionError ):
244
- await pool .get_connection ("_" )
245
- # we should have waited at least 0.1 seconds
246
- assert asyncio .get_event_loop ().time () - start >= 0.1
253
+ start = asyncio .get_event_loop ().time ()
254
+ with pytest .raises (redis .ConnectionError ):
255
+ await pool .get_connection ("_" )
256
+ # we should have waited at least 0.1 seconds
257
+ assert asyncio .get_event_loop ().time () - start >= 0.1
258
+ await c1 .disconnect ()
247
259
248
260
async def test_connection_pool_blocks_until_conn_available (self , master_host ):
249
261
"""
250
262
When out of connections, block until another connection is released
251
263
to the pool
252
264
"""
253
265
connection_kwargs = {"host" : master_host [0 ], "port" : master_host [1 ]}
254
- pool = self .get_pool (
266
+ async with self .get_pool (
255
267
max_connections = 1 , timeout = 2 , connection_kwargs = connection_kwargs
256
- )
257
- c1 = await pool .get_connection ("_" )
268
+ ) as pool :
269
+ c1 = await pool .get_connection ("_" )
258
270
259
- async def target ():
260
- await asyncio .sleep (0.1 )
261
- await pool .release (c1 )
271
+ async def target ():
272
+ await asyncio .sleep (0.1 )
273
+ await pool .release (c1 )
262
274
263
- start = asyncio .get_event_loop ().time ()
264
- await asyncio .gather (target (), pool .get_connection ("_" ))
265
- assert asyncio .get_event_loop ().time () - start >= 0.1
275
+ start = asyncio .get_event_loop ().time ()
276
+ await asyncio .gather (target (), pool .get_connection ("_" ))
277
+ assert asyncio .get_event_loop ().time () - start >= 0.1
266
278
267
279
async def test_reuse_previously_released_connection (self , master_host ):
268
280
connection_kwargs = {"host" : master_host }
269
- pool = self .get_pool (connection_kwargs = connection_kwargs )
270
- c1 = await pool .get_connection ("_" )
271
- await pool .release (c1 )
272
- c2 = await pool .get_connection ("_" )
273
- assert c1 == c2
281
+ async with self .get_pool (connection_kwargs = connection_kwargs ) as pool :
282
+ c1 = await pool .get_connection ("_" )
283
+ await pool .release (c1 )
284
+ c2 = await pool .get_connection ("_" )
285
+ assert c1 == c2
274
286
275
287
def test_repr_contains_db_info_tcp (self ):
276
288
pool = redis .ConnectionPool (
0 commit comments