Skip to content

Commit 86a2396

Browse files
authored
Restore async concurrency safety to websocket compressor (#7865)
Fixes #7859
1 parent 17c7d95 commit 86a2396

File tree

4 files changed

+97
-20
lines changed

4 files changed

+97
-20
lines changed

CHANGES/7865.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Restore async concurrency safety to websocket compressor

aiohttp/compression_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,25 @@ def __init__(
6262
self._compressor = zlib.compressobj(
6363
wbits=self._mode, strategy=strategy, level=level
6464
)
65+
self._compress_lock = asyncio.Lock()
6566

6667
def compress_sync(self, data: bytes) -> bytes:
6768
return self._compressor.compress(data)
6869

6970
async def compress(self, data: bytes) -> bytes:
70-
if (
71-
self._max_sync_chunk_size is not None
72-
and len(data) > self._max_sync_chunk_size
73-
):
74-
return await asyncio.get_event_loop().run_in_executor(
75-
self._executor, self.compress_sync, data
76-
)
77-
return self.compress_sync(data)
71+
async with self._compress_lock:
72+
# To ensure the stream is consistent in the event
73+
# there are multiple writers, we need to lock
74+
# the compressor so that only one writer can
75+
# compress at a time.
76+
if (
77+
self._max_sync_chunk_size is not None
78+
and len(data) > self._max_sync_chunk_size
79+
):
80+
return await asyncio.get_event_loop().run_in_executor(
81+
self._executor, self.compress_sync, data
82+
)
83+
return self.compress_sync(data)
7884

7985
def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
8086
return self._compressor.flush(mode)

aiohttp/http_websocket.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -626,21 +626,17 @@ async def _send_frame(
626626
if (compress or self.compress) and opcode < 8:
627627
if compress:
628628
# Do not set self._compress if compressing is for this frame
629-
compressobj = ZLibCompressor(
630-
level=zlib.Z_BEST_SPEED,
631-
wbits=-compress,
632-
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
633-
)
629+
compressobj = self._make_compress_obj(compress)
634630
else: # self.compress
635631
if not self._compressobj:
636-
self._compressobj = ZLibCompressor(
637-
level=zlib.Z_BEST_SPEED,
638-
wbits=-self.compress,
639-
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
640-
)
632+
self._compressobj = self._make_compress_obj(self.compress)
641633
compressobj = self._compressobj
642634

643635
message = await compressobj.compress(message)
636+
# It's critical that we do not return control to the event
637+
# loop until we have finished sending all the compressed
638+
# data. Otherwise we could end up mixing compressed frames
639+
# if there are multiple coroutines compressing data.
644640
message += compressobj.flush(
645641
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
646642
)
@@ -678,10 +674,20 @@ async def _send_frame(
678674

679675
self._output_size += len(header) + len(message)
680676

677+
# It is safe to return control to the event loop when using compression
678+
# after this point as we have already sent or buffered all the data.
679+
681680
if self._output_size > self._limit:
682681
self._output_size = 0
683682
await self.protocol._drain_helper()
684683

684+
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
685+
return ZLibCompressor(
686+
level=zlib.Z_BEST_SPEED,
687+
wbits=-compress,
688+
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
689+
)
690+
685691
def _write(self, data: bytes) -> None:
686692
if self.transport.is_closing():
687693
raise ConnectionResetError("Cannot write to closing transport")

tests/test_websocket_writer.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# type: ignore
2+
import asyncio
23
import random
3-
from typing import Any
4+
from typing import Any, Callable
45
from unittest import mock
56

67
import pytest
78

8-
from aiohttp.http import WebSocketWriter
9+
from aiohttp import DataQueue, WSMessage
10+
from aiohttp.http import WebSocketReader, WebSocketWriter
911
from aiohttp.test_utils import make_mocked_coro
1012

1113

@@ -106,3 +108,65 @@ async def test_send_compress_text_per_message(protocol: Any, transport: Any) ->
106108
writer.transport.write.assert_called_with(b"\x81\x04text")
107109
await writer.send(b"text", compress=15)
108110
writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")
111+
112+
113+
@pytest.mark.parametrize(
114+
("max_sync_chunk_size", "payload_point_generator"),
115+
(
116+
(16, lambda count: count),
117+
(4096, lambda count: count),
118+
(32, lambda count: 64 + count if count % 2 else count),
119+
),
120+
)
121+
async def test_concurrent_messages(
122+
protocol: Any,
123+
transport: Any,
124+
max_sync_chunk_size: int,
125+
payload_point_generator: Callable[[int], int],
126+
) -> None:
127+
"""Ensure messages are compressed correctly when there are multiple concurrent writers.
128+
129+
This test is parametrized to
130+
131+
- Generate messages that are larger than the patched
132+
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16
133+
where compression will run in the executor
134+
135+
- Generate messages that are smaller than the patched
136+
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096
137+
where compression will run in the event loop
138+
139+
- Interleave generated messages with a
140+
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32
141+
where compression will run in the event loop
142+
and in the executor
143+
"""
144+
with mock.patch(
145+
"aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size
146+
):
147+
writer = WebSocketWriter(protocol, transport, compress=15)
148+
queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
149+
reader = WebSocketReader(queue, 50000)
150+
writers = []
151+
payloads = []
152+
for count in range(1, 64 + 1):
153+
point = payload_point_generator(count)
154+
payload = bytes((point,)) * point
155+
payloads.append(payload)
156+
writers.append(writer.send(payload, binary=True))
157+
await asyncio.gather(*writers)
158+
159+
for call in writer.transport.write.call_args_list:
160+
call_bytes = call[0][0]
161+
result, _ = reader.feed_data(call_bytes)
162+
assert result is False
163+
msg = await queue.read()
164+
bytes_data: bytes = msg.data
165+
first_char = bytes_data[0:1]
166+
char_val = ord(first_char)
167+
assert len(bytes_data) == char_val
168+
# If we have a concurrency problem, the data
169+
# tends to get mixed up between messages so
170+
# we want to validate that all the bytes are
171+
# the same value
172+
assert bytes_data == bytes_data[0:1] * char_val

0 commit comments

Comments
 (0)