Skip to content

Commit 48f9f43

Browse files
committed
Add trio message assembler.
1 parent 25c5c07 commit 48f9f43

File tree

5 files changed

+955
-3
lines changed

5 files changed

+955
-3
lines changed

src/websockets/asyncio/messages.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ class Assembler:
8181
8282
"""
8383

84-
# coverage reports incorrectly: "line NN didn't jump to the function exit"
85-
def __init__( # pragma: no cover
84+
def __init__(
8685
self,
8786
high: int | None = None,
8887
low: int | None = None,

src/websockets/trio/messages.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
from __future__ import annotations
2+
3+
import codecs
4+
import math
5+
from collections.abc import AsyncIterator
6+
from typing import Any, Callable, Literal, TypeVar, overload
7+
8+
import trio
9+
10+
from ..exceptions import ConcurrencyError
11+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
12+
from ..typing import Data
13+
14+
15+
# Names exported by ``from ... import *``.
__all__ = ["Assembler"]
16+
17+
# Incremental UTF-8 decoder class. get_iter() instantiates it so that text
# messages can be decoded frame by frame as they arrive.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
18+
19+
# Generic type variable. NOTE(review): not referenced in the visible code.
T = TypeVar("T")
20+
21+
22+
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames. The stream of frames must
    respect the protocol; if it doesn't, the behavior is undefined.

    Args:
        high: High water mark for the buffer of frames. ``pause`` is called
            when the buffer grows above it. Defaults to ``4 * low`` when only
            ``low`` is given. When both ``high`` and ``low`` are :obj:`None`,
            flow control is disabled.
        low: Low water mark for the buffer of frames. ``resume`` is called
            when the buffer shrinks to it or below. Defaults to ``high // 4``
            when only ``high`` is given.
        pause: Called when the buffer of frames goes above the high water mark;
            should pause reading from the network.
        resume: Called when the buffer of frames goes below the low water mark;
            should resume reading from the network.

    Raises:
        ValueError: If ``low`` is negative or if ``high`` is lower than ``low``.

    """

    def __init__(
        self,
        high: int | None = None,
        low: int | None = None,
        pause: Callable[[], Any] = lambda: None,
        resume: Callable[[], Any] = lambda: None,
    ) -> None:
        # Queue of incoming frames.
        self.send_frames: trio.MemorySendChannel[Frame]
        self.recv_frames: trio.MemoryReceiveChannel[Frame]
        self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf)

        # We cannot put a hard limit on the size of the queue because a single
        # call to Protocol.data_received() could produce thousands of frames,
        # which must be buffered. Instead, we pause reading when the buffer goes
        # above the high limit and we resume when it goes under the low limit.
        if high is not None and low is None:
            low = high // 4
        if high is None and low is not None:
            high = low * 4
        if high is not None and low is not None:
            if low < 0:
                raise ValueError("low must be positive or equal to zero")
            if high < low:
                raise ValueError("high must be greater than or equal to low")
        self.high, self.low = high, low
        self.pause = pause
        self.resume = resume
        self.paused = False

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # This flag marks the end of the connection.
        self.closed = False

    @overload
    async def get(self, decode: Literal[True]) -> str: ...

    @overload
    async def get(self, decode: Literal[False]) -> bytes: ...

    @overload
    async def get(self, decode: bool | None = None) -> Data: ...

    async def get(self, decode: bool | None = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        If :meth:`get` is canceled while waiting for the following frames of a
        fragmented message, the frames already received are put back into the
        queue so that a later call can return the complete message.

        Args:
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If a text frame contains invalid UTF-8.
            ConcurrencyError: If two coroutines run :meth:`get` or
                :meth:`get_iter` concurrently.

        """
        if self.get_in_progress:
            raise ConcurrencyError("get() or get_iter() is already running")
        self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get() fetches a complete message or is canceled.

        try:
            # First frame
            try:
                frame = await self.recv_frames.receive()
            except trio.EndOfChannel:
                # Hide the channel exception: it is an implementation detail.
                raise EOFError("stream of frames ended") from None
            self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            frames = [frame]

            # Following frames, for fragmented messages
            while not frame.fin:
                try:
                    frame = await self.recv_frames.receive()
                except trio.Cancelled:
                    # Put frames already received back into the queue
                    # so that future calls to get() can return them.
                    assert not self.send_frames._state.receive_tasks, (
                        "no task should be waiting on receive()"
                    )
                    assert not self.send_frames._state.data, "queue should be empty"
                    for received in frames:
                        self.send_frames.send_nowait(received)
                    raise
                except trio.EndOfChannel:
                    raise EOFError("stream of frames ended") from None
                self.maybe_resume()
                assert frame.opcode is OP_CONT
                frames.append(frame)

        finally:
            self.get_in_progress = False

        data = b"".join(frame.data for frame in frames)
        if decode:
            return data.decode()
        else:
            return data

    @overload
    def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...

    @overload
    def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...

    @overload
    def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

    async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Args:
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If a text frame contains invalid UTF-8.
            ConcurrencyError: If two coroutines run :meth:`get` or
                :meth:`get_iter` concurrently.

        """
        if self.get_in_progress:
            raise ConcurrencyError("get() or get_iter() is already running")
        self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get_iter() fetches a complete message, is canceled while
        # waiting for the first frame, or reaches the end of the stream.

        # If get_iter() raises another exception e.g. in decoder.decode(),
        # get_in_progress remains set and the connection becomes unusable.

        # First frame
        try:
            frame = await self.recv_frames.receive()
        except trio.Cancelled:
            self.get_in_progress = False
            raise
        except trio.EndOfChannel:
            # Release the lock, like get() does, so that later calls raise
            # EOFError rather than a misleading ConcurrencyError.
            self.get_in_progress = False
            raise EOFError("stream of frames ended") from None
        self.maybe_resume()
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
        if decode is None:
            decode = frame.opcode is OP_TEXT
        if decode:
            decoder = UTF8Decoder()
            yield decoder.decode(frame.data, frame.fin)
        else:
            yield frame.data

        # Following frames, for fragmented messages
        while not frame.fin:
            # We cannot handle trio.Cancelled because we don't buffer
            # previous fragments — we're streaming them. Canceling get_iter()
            # here will leave the assembler in a stuck state. Future calls to
            # get() or get_iter() will raise ConcurrencyError.
            try:
                frame = await self.recv_frames.receive()
            except trio.EndOfChannel:
                # The stream is over: there is nothing left to protect.
                self.get_in_progress = False
                raise EOFError("stream of frames ended") from None
            self.maybe_resume()
            assert frame.opcode is OP_CONT
            if decode:
                yield decoder.decode(frame.data, frame.fin)
            else:
                yield frame.data

        self.get_in_progress = False

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        Raises:
            EOFError: If the stream of frames has ended.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        self.send_frames.send_nowait(frame)
        self.maybe_pause()

    def maybe_pause(self) -> None:
        """Pause the writer if queue is above the high water mark."""
        # Skip if flow control is disabled
        if self.high is None:
            return

        # Bypass the statistics() method for performance reasons.
        # Check for "> high" to support high = 0
        if len(self.send_frames._state.data) > self.high and not self.paused:
            self.paused = True
            self.pause()

    def maybe_resume(self) -> None:
        """Resume the writer if queue is below the low water mark."""
        # Skip if flow control is disabled
        if self.low is None:
            return

        # Bypass the statistics() method for performance reasons.
        # Check for "<= low" to support low = 0
        if len(self.send_frames._state.data) <= self.low and self.paused:
            self.paused = False
            self.resume()

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. :meth:`put` will raise :exc:`EOFError`;
        :meth:`get` and :meth:`get_iter` will raise :exc:`EOFError` when they
        find no more frames in the queue.

        Calling :meth:`close` more than once has no effect.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get() or get_iter().
        self.send_frames.close()

tests/asyncio/test_messages.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ async def test_get_iter_fragmented_text_message_not_received_yet(self):
267267
self.assertEqual(await anext(iterator), "f")
268268
self.assembler.put(Frame(OP_CONT, b"\xa9"))
269269
self.assertEqual(await anext(iterator), "é")
270+
await iterator.aclose()
270271

271272
async def test_get_iter_fragmented_binary_message_not_received_yet(self):
272273
"""get_iter yields a fragmented binary message when it is received."""
@@ -277,6 +278,7 @@ async def test_get_iter_fragmented_binary_message_not_received_yet(self):
277278
self.assertEqual(await anext(iterator), b"e")
278279
self.assembler.put(Frame(OP_CONT, b"a"))
279280
self.assertEqual(await anext(iterator), b"a")
281+
await iterator.aclose()
280282

281283
async def test_get_iter_fragmented_text_message_being_received(self):
282284
"""get_iter yields a fragmented text message that is partially received."""
@@ -287,6 +289,7 @@ async def test_get_iter_fragmented_text_message_being_received(self):
287289
self.assertEqual(await anext(iterator), "f")
288290
self.assembler.put(Frame(OP_CONT, b"\xa9"))
289291
self.assertEqual(await anext(iterator), "é")
292+
await iterator.aclose()
290293

291294
async def test_get_iter_fragmented_binary_message_being_received(self):
292295
"""get_iter yields a fragmented binary message that is partially received."""
@@ -297,6 +300,7 @@ async def test_get_iter_fragmented_binary_message_being_received(self):
297300
self.assertEqual(await anext(iterator), b"e")
298301
self.assembler.put(Frame(OP_CONT, b"a"))
299302
self.assertEqual(await anext(iterator), b"a")
303+
await iterator.aclose()
300304

301305
async def test_get_iter_encoded_text_message(self):
302306
"""get_iter yields a text message without UTF-8 decoding."""
@@ -334,6 +338,8 @@ async def test_get_iter_resumes_reading(self):
334338
await anext(iterator)
335339
self.resume.assert_called_once_with()
336340

341+
await iterator.aclose()
342+
337343
async def test_get_iter_does_not_resume_reading(self):
338344
"""get_iter does not resume reading when the low-water mark is unset."""
339345
self.assembler.low = None
@@ -345,6 +351,7 @@ async def test_get_iter_does_not_resume_reading(self):
345351
await anext(iterator)
346352
await anext(iterator)
347353
await anext(iterator)
354+
await iterator.aclose()
348355

349356
self.resume.assert_not_called()
350357

@@ -467,7 +474,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self):
467474
self.assertEqual(fragments, [b"t", b"e", b"a"])
468475

469476
async def test_get_partially_queued_fragmented_message_after_close(self):
470-
"""get raises EOF on a partial fragmented message after close is called."""
477+
"""get raises EOFError on a partial fragmented message after close is called."""
471478
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
472479
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
473480
self.assembler.close()

0 commit comments

Comments
 (0)