@@ -23,7 +23,7 @@ class Stream(Generic[_T]):
23
23
24
24
response : httpx .Response
25
25
26
- _decoder : SSEDecoder | SSEBytesDecoder
26
+ _decoder : SSEBytesDecoder
27
27
28
28
def __init__ (
29
29
self ,
@@ -46,10 +46,7 @@ def __iter__(self) -> Iterator[_T]:
46
46
yield item
47
47
48
48
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
49
- if isinstance (self ._decoder , SSEBytesDecoder ):
50
- yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
51
- else :
52
- yield from self ._decoder .iter (self .response .iter_lines ())
49
+ yield from self ._decoder .iter_bytes (self .response .iter_bytes ())
53
50
54
51
def __stream__ (self ) -> Iterator [_T ]:
55
52
cast_to = cast (Any , self ._cast_to )
@@ -112,12 +109,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:
112
109
yield item
113
110
114
111
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
115
- if isinstance (self ._decoder , SSEBytesDecoder ):
116
- async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
117
- yield sse
118
- else :
119
- async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
120
- yield sse
112
+ async for sse in self ._decoder .aiter_bytes (self .response .aiter_bytes ()):
113
+ yield sse
121
114
122
115
async def __stream__ (self ) -> AsyncIterator [_T ]:
123
116
cast_to = cast (Any , self ._cast_to )
@@ -205,21 +198,49 @@ def __init__(self) -> None:
205
198
self ._last_event_id = None
206
199
self ._retry = None
207
200
208
- def iter (self , iterator : Iterator [str ]) -> Iterator [ServerSentEvent ]:
209
- """Given an iterator that yields lines, iterate over it & yield every event encountered"""
210
- for line in iterator :
211
- line = line .rstrip ("\n " )
212
- sse = self .decode (line )
213
- if sse is not None :
214
- yield sse
215
-
216
- async def aiter (self , iterator : AsyncIterator [str ]) -> AsyncIterator [ServerSentEvent ]:
217
- """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
218
- async for line in iterator :
219
- line = line .rstrip ("\n " )
220
- sse = self .decode (line )
221
- if sse is not None :
222
- yield sse
201
+ def iter_bytes (self , iterator : Iterator [bytes ]) -> Iterator [ServerSentEvent ]:
202
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
203
+ for chunk in self ._iter_chunks (iterator ):
204
+ # Split before decoding so splitlines() only uses \r and \n
205
+ for raw_line in chunk .splitlines ():
206
+ line = raw_line .decode ("utf-8" )
207
+ sse = self .decode (line )
208
+ if sse :
209
+ yield sse
210
+
211
+ def _iter_chunks (self , iterator : Iterator [bytes ]) -> Iterator [bytes ]:
212
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
213
+ data = b""
214
+ for chunk in iterator :
215
+ for line in chunk .splitlines (keepends = True ):
216
+ data += line
217
+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
218
+ yield data
219
+ data = b""
220
+ if data :
221
+ yield data
222
+
223
+ async def aiter_bytes (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [ServerSentEvent ]:
224
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
225
+ async for chunk in self ._aiter_chunks (iterator ):
226
+ # Split before decoding so splitlines() only uses \r and \n
227
+ for raw_line in chunk .splitlines ():
228
+ line = raw_line .decode ("utf-8" )
229
+ sse = self .decode (line )
230
+ if sse :
231
+ yield sse
232
+
233
+ async def _aiter_chunks (self , iterator : AsyncIterator [bytes ]) -> AsyncIterator [bytes ]:
234
+ """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks"""
235
+ data = b""
236
+ async for chunk in iterator :
237
+ for line in chunk .splitlines (keepends = True ):
238
+ data += line
239
+ if data .endswith ((b"\r \r " , b"\n \n " , b"\r \n \r \n " )):
240
+ yield data
241
+ data = b""
242
+ if data :
243
+ yield data
223
244
224
245
def decode (self , line : str ) -> ServerSentEvent | None :
225
246
# See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501
0 commit comments