Skip to content

Commit 62e304c

Browse files
committed
Add trio message assembler.
1 parent 25c5c07 commit 62e304c

File tree

4 files changed

+993
-16
lines changed

4 files changed

+993
-16
lines changed

src/websockets/asyncio/messages.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ class Assembler:
8181
8282
"""
8383

84-
# coverage reports incorrectly: "line NN didn't jump to the function exit"
85-
def __init__( # pragma: no cover
84+
def __init__(
8685
self,
8786
high: int | None = None,
8887
low: int | None = None,

src/websockets/trio/messages.py

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
from __future__ import annotations
2+
3+
import codecs
4+
import math
5+
from collections.abc import AsyncIterator
6+
from typing import Any, Callable, Literal, TypeVar, overload
7+
8+
import trio
9+
10+
from ..exceptions import ConcurrencyError
11+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
12+
from ..typing import Data
13+
14+
15+
# Public API of this module: only the Assembler class is exported.
__all__ = ["Assembler"]

# Factory for incremental UTF-8 decoders. get_iter() creates one decoder per
# decoded message and feeds it each frame's payload in turn.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")

# Generic type variable.
# NOTE(review): not referenced in the visible part of this module — confirm
# whether it is still needed.
T = TypeVar("T")
20+
21+
22+
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames. The stream of frames must
    respect the protocol; if it doesn't, the behavior is undefined.

    Args:
        high: High water mark for the buffer of frames. ``pause`` is called
            when the number of buffered frames goes above it. When only
            ``low`` is given, ``high`` defaults to four times ``low``.
        low: Low water mark for the buffer of frames. ``resume`` is called
            when the number of buffered frames drops to it or below. When
            only ``high`` is given, ``low`` defaults to a quarter of ``high``.
            When neither is given, flow control is disabled.
        pause: Called when the buffer of frames goes above the high water mark;
            should pause reading from the network.
        resume: Called when the buffer of frames goes below the low water mark;
            should resume reading from the network.

    """
36+
37+
def __init__(
    self,
    high: int | None = None,
    low: int | None = None,
    pause: Callable[[], Any] = lambda: None,
    resume: Callable[[], Any] = lambda: None,
) -> None:
    """
    Create the frame queue and configure flow control.

    Args:
        high: High water mark; derived as ``low * 4`` when only ``low``
            is provided.
        low: Low water mark; derived as ``high // 4`` when only ``high``
            is provided.
        pause: Callback invoked when the queue grows above ``high``.
        resume: Callback invoked when the queue shrinks to ``low`` or less.

    Raises:
        ValueError: If ``low`` is negative or ``high`` is lower than ``low``.

    """
    # Incoming frames are queued in an unbounded in-memory channel.
    self.send_frames: trio.MemorySendChannel[Frame]
    self.recv_frames: trio.MemoryReceiveChannel[Frame]
    self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf)

    # The queue has no hard size limit because one call to
    # Protocol.data_received() may yield thousands of frames that all need
    # buffering. Flow control relies on water marks instead: reading is
    # paused above the high mark and resumed at or under the low mark.
    # When a single mark is provided, the other one is derived from it.
    if high is None:
        if low is not None:
            high = low * 4
    elif low is None:
        low = high // 4

    if high is not None and low is not None:
        if low < 0:
            raise ValueError("low must be positive or equal to zero")
        if high < low:
            raise ValueError("high must be greater than or equal to low")

    self.high, self.low = high, low
    self.pause, self.resume = pause, resume
    self.paused = False

    # Guards against concurrent calls to get() or get_iter() by user code.
    self.get_in_progress = False

    # Set once the connection has ended and no more frames are expected.
    self.closed = False
72+
73+
@overload
async def get(self, decode: Literal[True]) -> str: ...


@overload
async def get(self, decode: Literal[False]) -> bytes: ...


@overload
async def get(self, decode: bool | None = None) -> Data: ...


async def get(self, decode: bool | None = None) -> Data:
    """
    Read the next message.

    :meth:`get` returns a single :class:`str` or :class:`bytes`.

    If the message is fragmented, :meth:`get` waits until the last frame is
    received, then it reassembles the message and returns it. To receive
    messages frame by frame, use :meth:`get_iter` instead.

    If :meth:`get` is canceled while waiting for a continuation frame, the
    frames already received are put back into the queue so that a later
    call can return the complete message.

    Args:
        decode: :obj:`False` disables UTF-8 decoding of text frames and
            returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
            binary frames and returns :class:`str`. :obj:`None`, the
            default, decodes only when the first frame is a text frame.

    Raises:
        EOFError: If the stream of frames has ended.
        UnicodeDecodeError: If a text frame contains invalid UTF-8.
        ConcurrencyError: If two coroutines run :meth:`get` or
            :meth:`get_iter` concurrently.

    """
    if self.get_in_progress:
        raise ConcurrencyError("get() or get_iter() is already running")
    self.get_in_progress = True

    # Locking with get_in_progress prevents concurrent execution
    # until get() fetches a complete message or is canceled.

    try:
        # First frame
        try:
            frame = await self.recv_frames.receive()
        except trio.EndOfChannel:
            # The channel exception is an implementation detail; raise the
            # documented EOFError without chaining it.
            raise EOFError("stream of frames ended") from None
        self.maybe_resume()
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
        if decode is None:
            decode = frame.opcode is OP_TEXT
        frames = [frame]

        # Following frames, for fragmented messages
        while not frame.fin:
            try:
                frame = await self.recv_frames.receive()
            except trio.Cancelled:
                # Put frames already received back into the queue
                # so that future calls to get() can return them.
                assert not self.send_frames._state.receive_tasks, (
                    "no task should be waiting on receive()"
                )
                assert not self.send_frames._state.data, "queue should be empty"
                for buffered in frames:
                    self.send_frames.send_nowait(buffered)
                raise
            except trio.EndOfChannel:
                raise EOFError("stream of frames ended") from None
            self.maybe_resume()
            assert frame.opcode is OP_CONT
            frames.append(frame)

    finally:
        # Release the lock on every exit path: success, EOF, or cancellation.
        self.get_in_progress = False

    data = b"".join(frame.data for frame in frames)
    if decode:
        return data.decode()
    else:
        return data
151+
152+
@overload
def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...


@overload
def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...


@overload
def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...


async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
    """
    Stream the next message.

    Iterating the return value of :meth:`get_iter` asynchronously yields a
    :class:`str` or :class:`bytes` for each frame in the message.

    The iterator must be fully consumed before calling :meth:`get_iter` or
    :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.

    This method only makes sense for fragmented messages. If messages aren't
    fragmented, use :meth:`get` instead.

    Cancellation while waiting for the first frame is safe. Cancellation or
    any other exception after that point leaves the assembler locked:
    later calls to :meth:`get` or :meth:`get_iter` raise
    :exc:`ConcurrencyError`.

    Args:
        decode: :obj:`False` disables UTF-8 decoding of text frames and
            returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
            binary frames and returns :class:`str`. :obj:`None`, the
            default, decodes only when the first frame is a text frame.

    Raises:
        EOFError: If the stream of frames has ended.
        UnicodeDecodeError: If a text frame contains invalid UTF-8.
        ConcurrencyError: If two coroutines run :meth:`get` or
            :meth:`get_iter` concurrently.

    """
    if self.get_in_progress:
        raise ConcurrencyError("get() or get_iter() is already running")
    self.get_in_progress = True

    # Locking with get_in_progress prevents concurrent execution
    # until get_iter() fetches a complete message or is canceled.

    # If get_iter() raises an exception e.g. in decoder.decode(),
    # get_in_progress remains set and the connection becomes unusable.

    # First frame
    try:
        frame = await self.recv_frames.receive()
    except trio.Cancelled:
        # Nothing was consumed yet, so the lock can be released safely.
        self.get_in_progress = False
        raise
    except trio.EndOfChannel:
        # The channel exception is an implementation detail; raise the
        # documented EOFError without chaining it.
        raise EOFError("stream of frames ended") from None
    self.maybe_resume()
    assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
    if decode is None:
        decode = frame.opcode is OP_TEXT
    if decode:
        # The incremental decoder is fed one frame at a time; the second
        # argument tells it whether this frame is the final one.
        decoder = UTF8Decoder()
        yield decoder.decode(frame.data, frame.fin)
    else:
        yield frame.data

    # Following frames, for fragmented messages
    while not frame.fin:
        # We cannot handle trio.Cancelled because we don't buffer
        # previous fragments — we're streaming them. Canceling get_iter()
        # here will leave the assembler in a stuck state. Future calls to
        # get() or get_iter() will raise ConcurrencyError.
        try:
            frame = await self.recv_frames.receive()
        except trio.EndOfChannel:
            raise EOFError("stream of frames ended") from None
        self.maybe_resume()
        assert frame.opcode is OP_CONT
        if decode:
            yield decoder.decode(frame.data, frame.fin)
        else:
            yield frame.data

    # The message was streamed completely; release the lock.
    self.get_in_progress = False
232+
233+
def put(self, frame: Frame) -> None:
    """
    Add ``frame`` to the next message.

    The frame is appended to the queue without blocking, then the high
    water mark is checked so that reading can be paused if needed.

    Raises:
        EOFError: If the stream of frames has ended.

    """
    if self.closed:
        raise EOFError("stream of frames ended")

    channel = self.send_frames
    channel.send_nowait(frame)
    self.maybe_pause()
246+
247+
def maybe_pause(self) -> None:
    """Pause the writer if queue is above the high water mark."""
    # Nothing to do when flow control is disabled or reading is already paused.
    if self.high is None or self.paused:
        return

    # Read the channel's internal buffer directly instead of calling
    # statistics(), for performance reasons.
    buffered = len(self.send_frames._state.data)

    # Strictly greater than, so that high = 0 is supported.
    if buffered > self.high:
        self.paused = True
        self.pause()
258+
259+
def maybe_resume(self) -> None:
    """Resume the writer if queue is below the low water mark."""
    # Nothing to do when flow control is disabled or reading isn't paused.
    if self.low is None or not self.paused:
        return

    # Read the channel's internal buffer directly instead of calling
    # statistics(), for performance reasons.
    buffered = len(self.send_frames._state.data)

    # Less than or equal, so that low = 0 is supported.
    if buffered <= self.low:
        self.paused = False
        self.resume()
270+
271+
def close(self) -> None:
    """
    End the stream of frames.

    Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
    or :meth:`put` is safe. They will raise :exc:`EOFError`.

    Calling :meth:`close` more than once is a no-op.

    """
    if self.closed:
        return

    self.closed = True

    # Closing the send side of the channel unblocks get() or get_iter():
    # a pending receive() raises trio.EndOfChannel, which they convert
    # to EOFError.
    self.send_frames.close()

tests/asyncio/test_messages.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ async def test_get_iter_fragmented_binary_message_already_received(self):
261261
async def test_get_iter_fragmented_text_message_not_received_yet(self):
262262
"""get_iter yields a fragmented text message when it is received."""
263263
iterator = aiter(self.assembler.get_iter())
264+
self.addAsyncCleanup(iterator.aclose)
264265
self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
265266
self.assertEqual(await anext(iterator), "ca")
266267
self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
@@ -271,6 +272,7 @@ async def test_get_iter_fragmented_text_message_not_received_yet(self):
271272
async def test_get_iter_fragmented_binary_message_not_received_yet(self):
272273
"""get_iter yields a fragmented binary message when it is received."""
273274
iterator = aiter(self.assembler.get_iter())
275+
self.addAsyncCleanup(iterator.aclose)
274276
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
275277
self.assertEqual(await anext(iterator), b"t")
276278
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
@@ -282,6 +284,7 @@ async def test_get_iter_fragmented_text_message_being_received(self):
282284
"""get_iter yields a fragmented text message that is partially received."""
283285
self.assembler.put(Frame(OP_TEXT, b"ca", fin=False))
284286
iterator = aiter(self.assembler.get_iter())
287+
self.addAsyncCleanup(iterator.aclose)
285288
self.assertEqual(await anext(iterator), "ca")
286289
self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False))
287290
self.assertEqual(await anext(iterator), "f")
@@ -292,6 +295,7 @@ async def test_get_iter_fragmented_binary_message_being_received(self):
292295
"""get_iter yields a fragmented binary message that is partially received."""
293296
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
294297
iterator = aiter(self.assembler.get_iter())
298+
self.addAsyncCleanup(iterator.aclose)
295299
self.assertEqual(await anext(iterator), b"t")
296300
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
297301
self.assertEqual(await anext(iterator), b"e")
@@ -321,18 +325,18 @@ async def test_get_iter_resumes_reading(self):
321325
self.assembler.put(Frame(OP_CONT, b"a"))
322326

323327
iterator = aiter(self.assembler.get_iter())
328+
async with contextlib.aclosing(iterator):
329+
# queue is above the low-water mark
330+
await anext(iterator)
331+
self.resume.assert_not_called()
324332

325-
# queue is above the low-water mark
326-
await anext(iterator)
327-
self.resume.assert_not_called()
333+
# queue is at the low-water mark
334+
await anext(iterator)
335+
self.resume.assert_called_once_with()
328336

329-
# queue is at the low-water mark
330-
await anext(iterator)
331-
self.resume.assert_called_once_with()
332-
333-
# queue is below the low-water mark
334-
await anext(iterator)
335-
self.resume.assert_called_once_with()
337+
# queue is below the low-water mark
338+
await anext(iterator)
339+
self.resume.assert_called_once_with()
336340

337341
async def test_get_iter_does_not_resume_reading(self):
338342
"""get_iter does not resume reading when the low-water mark is unset."""
@@ -342,9 +346,10 @@ async def test_get_iter_does_not_resume_reading(self):
342346
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
343347
self.assembler.put(Frame(OP_CONT, b"a"))
344348
iterator = aiter(self.assembler.get_iter())
345-
await anext(iterator)
346-
await anext(iterator)
347-
await anext(iterator)
349+
async with contextlib.aclosing(iterator):
350+
await anext(iterator)
351+
await anext(iterator)
352+
await anext(iterator)
348353

349354
self.resume.assert_not_called()
350355

@@ -467,7 +472,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self):
467472
self.assertEqual(fragments, [b"t", b"e", b"a"])
468473

469474
async def test_get_partially_queued_fragmented_message_after_close(self):
470-
"""get raises EOF on a partial fragmented message after close is called."""
475+
"""get raises EOFError on a partial fragmented message after close is called."""
471476
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
472477
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
473478
self.assembler.close()

0 commit comments

Comments
 (0)