
Commit bb11101

Restore async concurrency safety to websocket compressor (#7865) (#7889)
Fixes #7859 (cherry picked from commit 86a2396)
Parent: 6dd0122
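For context, the failure mode behind #7859 can be sketched outside aiohttp. This is a minimal illustration, not code from this commit; `shared` and `send_unlocked` are hypothetical names. Two coroutines share one deflate stream, and offloading compress() to an executor yields to the event loop mid-message, so their compressed bytes can interleave:

import asyncio
import zlib

# One raw-deflate compressor shared per connection, as with
# permessage-deflate websockets.
shared = zlib.compressobj(wbits=-15)

async def send_unlocked(data: bytes) -> bytes:
    loop = asyncio.get_running_loop()
    # Awaiting the executor yields to the event loop, so another
    # coroutine may feed the same stream before we flush below.
    body = await loop.run_in_executor(None, shared.compress, data)
    return body + shared.flush(zlib.Z_SYNC_FLUSH)

async def main() -> None:
    # Both messages draw on one stream state; without a lock their
    # compressed bytes can end up mixed across the two "frames".
    await asyncio.gather(
        send_unlocked(b"a" * 100_000),
        send_unlocked(b"b" * 100_000),
    )

asyncio.run(main())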

File tree

4 files changed (+97, -19 lines)


CHANGES/7865.bugfix

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Restore async concurrency safety to websocket compressor

aiohttp/compression_utils.py

Lines changed: 14 additions & 8 deletions
@@ -62,19 +62,25 @@ def __init__(
         self._compressor = zlib.compressobj(
             wbits=self._mode, strategy=strategy, level=level
         )
+        self._compress_lock = asyncio.Lock()

     def compress_sync(self, data: bytes) -> bytes:
         return self._compressor.compress(data)

     async def compress(self, data: bytes) -> bytes:
-        if (
-            self._max_sync_chunk_size is not None
-            and len(data) > self._max_sync_chunk_size
-        ):
-            return await asyncio.get_event_loop().run_in_executor(
-                self._executor, self.compress_sync, data
-            )
-        return self.compress_sync(data)
+        async with self._compress_lock:
+            # To ensure the stream is consistent in the event
+            # there are multiple writers, we need to lock
+            # the compressor so that only one writer can
+            # compress at a time.
+            if (
+                self._max_sync_chunk_size is not None
+                and len(data) > self._max_sync_chunk_size
+            ):
+                return await asyncio.get_event_loop().run_in_executor(
+                    self._executor, self.compress_sync, data
+                )
+            return self.compress_sync(data)

     def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
         return self._compressor.flush(mode)
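A minimal usage sketch of the locked compressor (the constructor keywords mirror the websocket call sites below; the 1024-byte threshold is an arbitrary value chosen for illustration): concurrent compress() awaits are now admitted one at a time, whether they run inline or hop to the executor.

import asyncio

from aiohttp.compression_utils import ZLibCompressor

async def main() -> None:
    compressor = ZLibCompressor(wbits=-15, max_sync_chunk_size=1024)
    # Both payloads exceed max_sync_chunk_size, so each compress()
    # runs in the executor; the internal asyncio.Lock ensures the
    # second call cannot touch the stream until the first returns.
    await asyncio.gather(
        compressor.compress(b"a" * 4096),
        compressor.compress(b"b" * 4096),
    )

asyncio.run(main())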

aiohttp/http_websocket.py

Lines changed: 16 additions & 10 deletions
@@ -635,21 +635,17 @@ async def _send_frame(
         if (compress or self.compress) and opcode < 8:
             if compress:
                 # Do not set self._compress if compressing is for this frame
-                compressobj = ZLibCompressor(
-                    level=zlib.Z_BEST_SPEED,
-                    wbits=-compress,
-                    max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                )
+                compressobj = self._make_compress_obj(compress)
             else:  # self.compress
                 if not self._compressobj:
-                    self._compressobj = ZLibCompressor(
-                        level=zlib.Z_BEST_SPEED,
-                        wbits=-self.compress,
-                        max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                    )
+                    self._compressobj = self._make_compress_obj(self.compress)
                 compressobj = self._compressobj

             message = await compressobj.compress(message)
+            # It's critical that we do not return control to the event
+            # loop until we have finished sending all the compressed
+            # data. Otherwise we could end up mixing compressed frames
+            # if there are multiple coroutines compressing data.
             message += compressobj.flush(
                 zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
             )
@@ -687,10 +683,20 @@ async def _send_frame(

         self._output_size += len(header) + len(message)

+        # It is safe to return control to the event loop when using compression
+        # after this point as we have already sent or buffered all the data.
+
         if self._output_size > self._limit:
             self._output_size = 0
             await self.protocol._drain_helper()

+    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
+        return ZLibCompressor(
+            level=zlib.Z_BEST_SPEED,
+            wbits=-compress,
+            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+        )
+
     def _write(self, data: bytes) -> None:
         if self.transport is None or self.transport.is_closing():
             raise ConnectionResetError("Cannot write to closing transport")
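The two new comments encode an ordering rule that the diff states only in prose. A standalone sketch of the same rule (the function and its parameters are illustrative, not aiohttp API):

import zlib
from typing import Callable

from aiohttp.compression_utils import ZLibCompressor

async def send_compressed_frame(
    compressor: ZLibCompressor,
    write: Callable[[bytes], None],
    message: bytes,
) -> None:
    # Awaiting here is safe: the compressor's lock keeps the shared
    # deflate stream consistent across concurrent senders.
    body = await compressor.compress(message)
    # From here until write(), there must be no await: yielding would
    # let another coroutine push bytes into the shared stream, and the
    # peer would see interleaved deflate output inside this frame.
    body += compressor.flush(zlib.Z_SYNC_FLUSH)
    write(body)  # frame header omitted for brevity
    # Only after the data is written or buffered may we yield again,
    # e.g. to drain the transport.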

tests/test_websocket_writer.py

Lines changed: 66 additions & 1 deletion
@@ -1,9 +1,12 @@
+import asyncio
 import random
+from typing import Any, Callable
 from unittest import mock

 import pytest

-from aiohttp.http import WebSocketWriter
+from aiohttp import DataQueue, WSMessage
+from aiohttp.http import WebSocketReader, WebSocketWriter
 from aiohttp.test_utils import make_mocked_coro


@@ -104,3 +107,65 @@ async def test_send_compress_text_per_message(protocol, transport) -> None:
     writer.transport.write.assert_called_with(b"\x81\x04text")
     await writer.send(b"text", compress=15)
     writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")
+
+
+@pytest.mark.parametrize(
+    ("max_sync_chunk_size", "payload_point_generator"),
+    (
+        (16, lambda count: count),
+        (4096, lambda count: count),
+        (32, lambda count: 64 + count if count % 2 else count),
+    ),
+)
+async def test_concurrent_messages(
+    protocol: Any,
+    transport: Any,
+    max_sync_chunk_size: int,
+    payload_point_generator: Callable[[int], int],
+) -> None:
+    """Ensure messages are compressed correctly with multiple concurrent writers.
+
+    This test is parametrized to:
+
+    - Generate messages that are larger than the patched
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16,
+      where compression will run in the executor
+
+    - Generate messages that are smaller than the patched
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096,
+      where compression will run in the event loop
+
+    - Interleave generated messages with a patched
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32,
+      where compression will run both in the event loop
+      and in the executor
+    """
+    with mock.patch(
+        "aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size
+    ):
+        writer = WebSocketWriter(protocol, transport, compress=15)
+        queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
+        reader = WebSocketReader(queue, 50000)
+        writers = []
+        payloads = []
+        for count in range(1, 64 + 1):
+            point = payload_point_generator(count)
+            payload = bytes((point,)) * point
+            payloads.append(payload)
+            writers.append(writer.send(payload, binary=True))
+        await asyncio.gather(*writers)

+    for call in writer.transport.write.call_args_list:
+        call_bytes = call[0][0]
+        result, _ = reader.feed_data(call_bytes)
+        assert result is False
+        msg = await queue.read()
+        bytes_data: bytes = msg.data
+        first_char = bytes_data[0:1]
+        char_val = ord(first_char)
+        assert len(bytes_data) == char_val
+        # If we have a concurrency problem, the data
+        # tends to get mixed up between messages so
+        # we want to validate that all the bytes are
+        # the same value
+        assert bytes_data == bytes_data[0:1] * char_val
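The payload scheme in this test makes corruption self-evident: each message is `point` copies of the byte value `point`, so its length and its contents cross-check each other. A quick illustration:

# point == 5 yields b"\x05\x05\x05\x05\x05"
payload = bytes((5,)) * 5
assert len(payload) == payload[0]            # length equals the byte value
assert payload == payload[0:1] * payload[0]  # every byte is identical
# Any cross-message byte mixing breaks at least one of these invariants.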
