Skip to content

Commit 9d20929

Browse files
committed
Wait for a send event, rather than rely on sleep time. Excpect cancel errors.
1 parent 34e9263 commit 9d20929

File tree

1 file changed

+42
-28
lines changed

1 file changed

+42
-28
lines changed

tests/test_asyncio/test_cwe_404.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import sys
32
import urllib.parse
43

54
import pytest
@@ -20,23 +19,12 @@ def redis_addr(request):
2019
return netloc, "6379"
2120

2221

23-
async def pipe(
24-
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
25-
):
26-
while True:
27-
data = await reader.read(1000)
28-
if not data:
29-
break
30-
await asyncio.sleep(delay)
31-
writer.write(data)
32-
await writer.drain()
33-
34-
3522
class DelayProxy:
3623
def __init__(self, addr, redis_addr, delay: float):
3724
self.addr = addr
3825
self.redis_addr = redis_addr
3926
self.delay = delay
27+
self.send_event = asyncio.Event()
4028

4129
async def start(self):
4230
# test that we can connect to redis
@@ -49,10 +37,10 @@ async def start(self):
4937
async def handle(self, reader, writer):
5038
# establish connection to redis
5139
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
52-
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
53-
pipe2 = asyncio.create_task(
54-
pipe(redis_reader, writer, self.delay, "from redis:")
40+
pipe1 = asyncio.create_task(
41+
self.pipe(reader, redis_writer, "to redis:", self.send_event)
5542
)
43+
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:"))
5644
await asyncio.gather(pipe1, pipe2)
5745

5846
async def stop(self):
@@ -61,6 +49,24 @@ async def stop(self):
6149
loop = self.server.get_loop()
6250
await loop.shutdown_asyncgens()
6351

52+
async def pipe(
53+
self,
54+
reader: asyncio.StreamReader,
55+
writer: asyncio.StreamWriter,
56+
name="",
57+
event: asyncio.Event = None,
58+
):
59+
while True:
60+
data = await reader.read(1000)
61+
if not data:
62+
break
63+
# print(f"{name} read {len(data)} delay {self.delay}")
64+
if event:
65+
event.set()
66+
await asyncio.sleep(self.delay)
67+
writer.write(data)
68+
await writer.drain()
69+
6470

6571
@pytest.mark.onlynoncluster
6672
@pytest.mark.parametrize("delay", argvalues=[0.05, 0.5, 1, 2])
@@ -78,17 +84,18 @@ async def test_standalone(delay, redis_addr):
7884
await r.set("foo", "foo")
7985
await r.set("bar", "bar")
8086

87+
dp.send_event.clear()
8188
t = asyncio.create_task(r.get("foo"))
82-
await asyncio.sleep(delay)
89+
# Wait until the task has sent, and then some, to make sure it has
90+
# settled on the read.
91+
await dp.send_event.wait()
92+
await asyncio.sleep(0.01) # a little extra time for prudence
8393
t.cancel()
84-
try:
94+
with pytest.raises(asyncio.CancelledError):
8595
await t
86-
sys.stderr.write("try again, we did not cancel the task in time\n")
87-
except asyncio.CancelledError:
88-
sys.stderr.write(
89-
"canceled task, connection is left open with unread response\n"
90-
)
9196

97+
# make sure that our previous request, cancelled while waiting for
98+
# a repsponse, didn't leave the connection open andin a bad state
9299
assert await r.get("bar") == b"bar"
93100
assert await r.ping()
94101
assert await r.get("foo") == b"foo"
@@ -113,10 +120,17 @@ async def test_standalone_pipeline(delay, redis_addr):
113120
pipe2.ping()
114121
pipe2.get("foo")
115122

123+
dp.send_event.clear()
116124
t = asyncio.create_task(pipe.get("foo").execute())
117-
await asyncio.sleep(delay)
125+
# wait until task has settled on the read
126+
await dp.send_event.wait()
127+
await asyncio.sleep(0.01)
118128
t.cancel()
129+
with pytest.raises(asyncio.CancelledError):
130+
await t
119131

132+
# we have now cancelled the pieline in the middle of a request, make sure
133+
# that the connection is still usable
120134
pipe.get("bar")
121135
pipe.ping()
122136
pipe.get("foo")
@@ -147,13 +161,13 @@ async def test_cluster(request, redis_addr):
147161
await r.set("foo", "foo")
148162
await r.set("bar", "bar")
149163

164+
dp.send_event.clear()
150165
t = asyncio.create_task(r.get("foo"))
151-
await asyncio.sleep(0.050)
166+
await dp.send_event.wait()
167+
await asyncio.sleep(0.01)
152168
t.cancel()
153-
try:
169+
with pytest.raises(asyncio.CancelledError):
154170
await t
155-
except asyncio.CancelledError:
156-
pytest.fail("connection is left open with unread response")
157171

158172
assert await r.get("bar") == b"bar"
159173
assert await r.ping()

0 commit comments

Comments
 (0)