Skip to content

Commit 966052c

Browse files
committed
Use context manager to clean up connections in connection pool for unit tests
1 parent 28affc4 commit 966052c

File tree

1 file changed

+83
-71
lines changed

1 file changed

+83
-71
lines changed

tests/test_asyncio/test_connection_pool.py

Lines changed: 83 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import os
34
import re
45
import sys
@@ -114,7 +115,8 @@ async def can_read(self, timeout: float = 0):
114115

115116

116117
class TestConnectionPool:
117-
def get_pool(
118+
@contextlib.asynccontextmanager
119+
async def get_pool(
118120
self,
119121
connection_kwargs=None,
120122
max_connections=None,
@@ -126,79 +128,88 @@ def get_pool(
126128
max_connections=max_connections,
127129
**connection_kwargs,
128130
)
129-
return pool
131+
try:
132+
yield pool
133+
finally:
134+
await pool.disconnect(inuse_connections=True)
130135

131136
async def test_connection_creation(self):
132137
connection_kwargs = {"foo": "bar", "biz": "baz"}
133-
pool = self.get_pool(
138+
async with self.get_pool(
134139
connection_kwargs=connection_kwargs, connection_class=DummyConnection
135-
)
136-
connection = await pool.get_connection("_")
137-
assert isinstance(connection, DummyConnection)
138-
assert connection.kwargs == connection_kwargs
140+
) as pool:
141+
connection = await pool.get_connection("_")
142+
assert isinstance(connection, DummyConnection)
143+
assert connection.kwargs == connection_kwargs
139144

140145
async def test_multiple_connections(self, master_host):
141146
connection_kwargs = {"host": master_host}
142-
pool = self.get_pool(connection_kwargs=connection_kwargs)
143-
c1 = await pool.get_connection("_")
144-
c2 = await pool.get_connection("_")
145-
assert c1 != c2
147+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
148+
c1 = await pool.get_connection("_")
149+
c2 = await pool.get_connection("_")
150+
assert c1 != c2
146151

147152
async def test_max_connections(self, master_host):
148153
connection_kwargs = {"host": master_host}
149-
pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs)
150-
await pool.get_connection("_")
151-
await pool.get_connection("_")
152-
with pytest.raises(redis.ConnectionError):
154+
async with self.get_pool(
155+
max_connections=2, connection_kwargs=connection_kwargs
156+
) as pool:
157+
await pool.get_connection("_")
153158
await pool.get_connection("_")
159+
with pytest.raises(redis.ConnectionError):
160+
await pool.get_connection("_")
154161

155162
async def test_reuse_previously_released_connection(self, master_host):
156163
connection_kwargs = {"host": master_host}
157-
pool = self.get_pool(connection_kwargs=connection_kwargs)
158-
c1 = await pool.get_connection("_")
159-
await pool.release(c1)
160-
c2 = await pool.get_connection("_")
161-
assert c1 == c2
164+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
165+
c1 = await pool.get_connection("_")
166+
await pool.release(c1)
167+
c2 = await pool.get_connection("_")
168+
assert c1 == c2
162169

163-
def test_repr_contains_db_info_tcp(self):
170+
async def test_repr_contains_db_info_tcp(self):
164171
connection_kwargs = {
165172
"host": "localhost",
166173
"port": 6379,
167174
"db": 1,
168175
"client_name": "test-client",
169176
}
170-
pool = self.get_pool(
177+
async with self.get_pool(
171178
connection_kwargs=connection_kwargs, connection_class=redis.Connection
172-
)
173-
expected = (
174-
"ConnectionPool<Connection<"
175-
"host=localhost,port=6379,db=1,client_name=test-client>>"
176-
)
177-
assert repr(pool) == expected
179+
) as pool:
180+
expected = (
181+
"ConnectionPool<Connection<"
182+
"host=localhost,port=6379,db=1,client_name=test-client>>"
183+
)
184+
assert repr(pool) == expected
178185

179-
def test_repr_contains_db_info_unix(self):
186+
async def test_repr_contains_db_info_unix(self):
180187
connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"}
181-
pool = self.get_pool(
188+
async with self.get_pool(
182189
connection_kwargs=connection_kwargs,
183190
connection_class=redis.UnixDomainSocketConnection,
184-
)
185-
expected = (
186-
"ConnectionPool<UnixDomainSocketConnection<"
187-
"path=/abc,db=1,client_name=test-client>>"
188-
)
189-
assert repr(pool) == expected
191+
) as pool:
192+
expected = (
193+
"ConnectionPool<UnixDomainSocketConnection<"
194+
"path=/abc,db=1,client_name=test-client>>"
195+
)
196+
assert repr(pool) == expected
190197

191198

192199
class TestBlockingConnectionPool:
193-
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
200+
@contextlib.asynccontextmanager
201+
async def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
194202
connection_kwargs = connection_kwargs or {}
195203
pool = redis.BlockingConnectionPool(
196204
connection_class=DummyConnection,
197205
max_connections=max_connections,
198206
timeout=timeout,
199207
**connection_kwargs,
200208
)
201-
return pool
209+
try:
210+
yield pool
211+
finally:
212+
await pool.disconnect(inuse_connections=True)
202213

203214
async def test_connection_creation(self, master_host):
204215
connection_kwargs = {
@@ -207,10 +218,10 @@ async def test_connection_creation(self, master_host):
207218
"host": master_host[0],
208219
"port": master_host[1],
209220
}
210-
pool = self.get_pool(connection_kwargs=connection_kwargs)
211-
connection = await pool.get_connection("_")
212-
assert isinstance(connection, DummyConnection)
213-
assert connection.kwargs == connection_kwargs
221+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
222+
connection = await pool.get_connection("_")
223+
assert isinstance(connection, DummyConnection)
224+
assert connection.kwargs == connection_kwargs
214225

215226
async def test_disconnect(self, master_host):
216227
"""A regression test for #1047"""
@@ -220,57 +231,58 @@ async def test_disconnect(self, master_host):
220231
"host": master_host[0],
221232
"port": master_host[1],
222233
}
223-
pool = self.get_pool(connection_kwargs=connection_kwargs)
224-
await pool.get_connection("_")
225-
await pool.disconnect()
234+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
235+
await pool.get_connection("_")
236+
await pool.disconnect()
226237

227238
async def test_multiple_connections(self, master_host):
228239
connection_kwargs = {"host": master_host[0], "port": master_host[1]}
229-
pool = self.get_pool(connection_kwargs=connection_kwargs)
230-
c1 = await pool.get_connection("_")
231-
c2 = await pool.get_connection("_")
232-
assert c1 != c2
240+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
241+
c1 = await pool.get_connection("_")
242+
c2 = await pool.get_connection("_")
243+
assert c1 != c2
233244

234245
async def test_connection_pool_blocks_until_timeout(self, master_host):
235246
"""When out of connections, block for timeout seconds, then raise"""
236247
connection_kwargs = {"host": master_host}
237-
pool = self.get_pool(
248+
async with self.get_pool(
238249
max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs
239-
)
240-
await pool.get_connection("_")
250+
) as pool:
251+
c1 = await pool.get_connection("_")
241252

242-
start = asyncio.get_event_loop().time()
243-
with pytest.raises(redis.ConnectionError):
244-
await pool.get_connection("_")
245-
# we should have waited at least 0.1 seconds
246-
assert asyncio.get_event_loop().time() - start >= 0.1
253+
start = asyncio.get_event_loop().time()
254+
with pytest.raises(redis.ConnectionError):
255+
await pool.get_connection("_")
256+
# we should have waited at least 0.1 seconds
257+
assert asyncio.get_event_loop().time() - start >= 0.1
258+
await c1.disconnect()
247259

248260
async def test_connection_pool_blocks_until_conn_available(self, master_host):
249261
"""
250262
When out of connections, block until another connection is released
251263
to the pool
252264
"""
253265
connection_kwargs = {"host": master_host[0], "port": master_host[1]}
254-
pool = self.get_pool(
266+
async with self.get_pool(
255267
max_connections=1, timeout=2, connection_kwargs=connection_kwargs
256-
)
257-
c1 = await pool.get_connection("_")
268+
) as pool:
269+
c1 = await pool.get_connection("_")
258270

259-
async def target():
260-
await asyncio.sleep(0.1)
261-
await pool.release(c1)
271+
async def target():
272+
await asyncio.sleep(0.1)
273+
await pool.release(c1)
262274

263-
start = asyncio.get_event_loop().time()
264-
await asyncio.gather(target(), pool.get_connection("_"))
265-
assert asyncio.get_event_loop().time() - start >= 0.1
275+
start = asyncio.get_event_loop().time()
276+
await asyncio.gather(target(), pool.get_connection("_"))
277+
assert asyncio.get_event_loop().time() - start >= 0.1
266278

267279
async def test_reuse_previously_released_connection(self, master_host):
268280
connection_kwargs = {"host": master_host}
269-
pool = self.get_pool(connection_kwargs=connection_kwargs)
270-
c1 = await pool.get_connection("_")
271-
await pool.release(c1)
272-
c2 = await pool.get_connection("_")
273-
assert c1 == c2
281+
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
282+
c1 = await pool.get_connection("_")
283+
await pool.release(c1)
284+
c2 = await pool.get_connection("_")
285+
assert c1 == c2
274286

275287
def test_repr_contains_db_info_tcp(self):
276288
pool = redis.ConnectionPool(

0 commit comments

Comments
 (0)