Skip to content

Commit c93c299

Browse files
chore(internal): streaming updates (#340)
1 parent 405e0fb commit c93c299

File tree

2 files changed

+253
-88
lines changed

2 files changed

+253
-88
lines changed

src/finch/_streaming.py

+47-26
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Stream(Generic[_T]):
2323

2424
response: httpx.Response
2525

26-
_decoder: SSEDecoder | SSEBytesDecoder
26+
_decoder: SSEBytesDecoder
2727

2828
def __init__(
2929
self,
@@ -46,10 +46,7 @@ def __iter__(self) -> Iterator[_T]:
4646
yield item
4747

4848
def _iter_events(self) -> Iterator[ServerSentEvent]:
49-
if isinstance(self._decoder, SSEBytesDecoder):
50-
yield from self._decoder.iter_bytes(self.response.iter_bytes())
51-
else:
52-
yield from self._decoder.iter(self.response.iter_lines())
49+
yield from self._decoder.iter_bytes(self.response.iter_bytes())
5350

5451
def __stream__(self) -> Iterator[_T]:
5552
cast_to = cast(Any, self._cast_to)
@@ -112,12 +109,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
112109
yield item
113110

114111
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
115-
if isinstance(self._decoder, SSEBytesDecoder):
116-
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
117-
yield sse
118-
else:
119-
async for sse in self._decoder.aiter(self.response.aiter_lines()):
120-
yield sse
112+
async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
113+
yield sse
121114

122115
async def __stream__(self) -> AsyncIterator[_T]:
123116
cast_to = cast(Any, self._cast_to)
@@ -205,21 +198,49 @@ def __init__(self) -> None:
205198
self._last_event_id = None
206199
self._retry = None
207200

208-
def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
209-
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
210-
for line in iterator:
211-
line = line.rstrip("\n")
212-
sse = self.decode(line)
213-
if sse is not None:
214-
yield sse
215-
216-
async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
217-
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
218-
async for line in iterator:
219-
line = line.rstrip("\n")
220-
sse = self.decode(line)
221-
if sse is not None:
222-
yield sse
201+
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
202+
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
203+
for chunk in self._iter_chunks(iterator):
204+
# Split before decoding so splitlines() only uses \r and \n
205+
for raw_line in chunk.splitlines():
206+
line = raw_line.decode("utf-8")
207+
sse = self.decode(line)
208+
if sse:
209+
yield sse
210+
211+
def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
212+
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
213+
data = b""
214+
for chunk in iterator:
215+
for line in chunk.splitlines(keepends=True):
216+
data += line
217+
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
218+
yield data
219+
data = b""
220+
if data:
221+
yield data
222+
223+
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
224+
"""Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
225+
async for chunk in self._aiter_chunks(iterator):
226+
# Split before decoding so splitlines() only uses \r and \n
227+
for raw_line in chunk.splitlines():
228+
line = raw_line.decode("utf-8")
229+
sse = self.decode(line)
230+
if sse:
231+
yield sse
232+
233+
async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
234+
"""Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
235+
data = b""
236+
async for chunk in iterator:
237+
for line in chunk.splitlines(keepends=True):
238+
data += line
239+
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
240+
yield data
241+
data = b""
242+
if data:
243+
yield data
223244

224245
def decode(self, line: str) -> ServerSentEvent | None:
225246
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501

0 commit comments

Comments
 (0)