@@ -24,7 +24,7 @@ class Stream(Generic[_T]):
24
24
25
25
response : httpx .Response
26
26
27
- _decoder : SSEDecoder | SSEBytesDecoder
27
+ _decoder : SSEBytesDecoder
28
28
29
29
def __init__ (
30
30
self ,
@@ -47,10 +47,7 @@ def __iter__(self) -> Iterator[_T]:
47
47
yield item
48
48
49
49
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
50
- if isinstance (self ._decoder , SSEBytesDecoder ):
51
- yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
52
- else :
53
- yield from self ._decoder .iter (self .response .iter_lines ())
50
+ yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
54
51
55
52
def __stream__ (self ) -> Iterator [_T ]:
56
53
cast_to = cast (Any , self ._cast_to )
@@ -151,12 +148,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
151
148
yield item
152
149
153
150
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
154
- if isinstance (self ._decoder , SSEBytesDecoder ):
155
- async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
156
- yield sse
157
- else :
158
- async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
159
- yield sse
151
+ async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
152
+ yield sse
160
153
161
154
async def __stream__ (self ) -> AsyncIterator [_T ]:
162
155
cast_to = cast (Any , self ._cast_to )
@@ -282,21 +275,49 @@ def __init__(self) -> None:
282
275
self ._last_event_id = None
283
276
self ._retry = None
284
277
285
- def iter (self , iterator : Iterator [str ]) -> Iterator [ServerSentEvent ]:
286
- """Given an iterator that yields lines, iterate over it & yield every event encountered"""
287
- for line in iterator :
288
- line = line .rstrip ("\n " )
289
- sse = self .decode (line )
290
- if sse is not None :
291
- yield sse
292
-
293
- async def aiter (self , iterator : AsyncIterator [str ]) -> AsyncIterator [ServerSentEvent ]:
294
- """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
295
- async for line in iterator :
296
- line = line .rstrip ("\n " )
297
- sse = self .decode (line )
298
- if sse is not None :
299
- yield sse
278
+ def iter_bytes (self , iterator : Iterator [bytes ]) -> Iterator [ServerSentEvent ]:
279
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
280
+ for chunk in self ._iter_chunks (iterator ):
281
+ # Split before decoding so splitlines() only uses \r and \n
282
+ for raw_line in chunk .splitlines ():
283
+ line = raw_line .decode ("utf-8" )
284
+ sse = self .decode (line )
285
+ if sse :
286
+ yield sse
287
+
288
+ def _iter_chunks (self , iterator : Iterator [bytes ]) -> Iterator [bytes ]:
289
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
290
+ data = b""
291
+ for chunk in iterator :
292
+ for line in chunk .splitlines (keepends = True ):
293
+ data += line
294
+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
295
+ yield data
296
+ data = b""
297
+ if data :
298
+ yield data
299
+
300
+ async def aiter_bytes (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [ServerSentEvent ]:
301
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
302
+ async for chunk in self ._aiter_chunks (iterator ):
303
+ # Split before decoding so splitlines() only uses \r and \n
304
+ for raw_line in chunk .splitlines ():
305
+ line = raw_line .decode ("utf-8" )
306
+ sse = self .decode (line )
307
+ if sse :
308
+ yield sse
309
+
310
+ async def _aiter_chunks (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [bytes ]:
311
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
312
+ data = b""
313
+ async for chunk in iterator :
314
+ for line in chunk .splitlines (keepends = True ):
315
+ data += line
316
+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
317
+ yield data
318
+ data = b""
319
+ if data :
320
+ yield data
300
321
301
322
def decode (self , line : str ) -> ServerSentEvent | None :
302
323
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
0 commit comments