@@ -23,29 +23,6 @@ export class Stream<Item> implements AsyncIterable<Item> {
23
23
24
24
static fromSSEResponse < Item > ( response : Response , controller : AbortController ) {
25
25
let consumed = false ;
26
- const decoder = new SSEDecoder ( ) ;
27
-
28
- async function * iterMessages ( ) : AsyncGenerator < ServerSentEvent , void , unknown > {
29
- if ( ! response . body ) {
30
- controller . abort ( ) ;
31
- throw new OpenAIError ( `Attempted to iterate over a response with no body` ) ;
32
- }
33
-
34
- const lineDecoder = new LineDecoder ( ) ;
35
-
36
- const iter = readableStreamAsyncIterable < Bytes > ( response . body ) ;
37
- for await ( const chunk of iter ) {
38
- for ( const line of lineDecoder . decode ( chunk ) ) {
39
- const sse = decoder . decode ( line ) ;
40
- if ( sse ) yield sse ;
41
- }
42
- }
43
-
44
- for ( const line of lineDecoder . flush ( ) ) {
45
- const sse = decoder . decode ( line ) ;
46
- if ( sse ) yield sse ;
47
- }
48
- }
49
26
50
27
async function * iterator ( ) : AsyncIterator < Item , any , undefined > {
51
28
if ( consumed ) {
@@ -54,7 +31,7 @@ export class Stream<Item> implements AsyncIterable<Item> {
54
31
consumed = true ;
55
32
let done = false ;
56
33
try {
57
- for await ( const sse of iterMessages ( ) ) {
34
+ for await ( const sse of _iterSSEMessages ( response , controller ) ) {
58
35
if ( done ) continue ;
59
36
60
37
if ( sse . data . startsWith ( '[DONE]' ) ) {
@@ -220,6 +197,97 @@ export class Stream<Item> implements AsyncIterable<Item> {
220
197
}
221
198
}
222
199
200
+ export async function * _iterSSEMessages (
201
+ response : Response ,
202
+ controller : AbortController ,
203
+ ) : AsyncGenerator < ServerSentEvent , void , unknown > {
204
+ if ( ! response . body ) {
205
+ controller . abort ( ) ;
206
+ throw new OpenAIError ( `Attempted to iterate over a response with no body` ) ;
207
+ }
208
+
209
+ const sseDecoder = new SSEDecoder ( ) ;
210
+ const lineDecoder = new LineDecoder ( ) ;
211
+
212
+ const iter = readableStreamAsyncIterable < Bytes > ( response . body ) ;
213
+ for await ( const sseChunk of iterSSEChunks ( iter ) ) {
214
+ for ( const line of lineDecoder . decode ( sseChunk ) ) {
215
+ const sse = sseDecoder . decode ( line ) ;
216
+ if ( sse ) yield sse ;
217
+ }
218
+ }
219
+
220
+ for ( const line of lineDecoder . flush ( ) ) {
221
+ const sse = sseDecoder . decode ( line ) ;
222
+ if ( sse ) yield sse ;
223
+ }
224
+ }
225
+
226
+ /**
227
+ * Given an async iterable iterator, iterates over it and yields full
228
+ * SSE chunks, i.e. yields when a double new-line is encountered.
229
+ */
230
+ async function * iterSSEChunks ( iterator : AsyncIterableIterator < Bytes > ) : AsyncGenerator < Uint8Array > {
231
+ let data = new Uint8Array ( ) ;
232
+
233
+ for await ( const chunk of iterator ) {
234
+ if ( chunk == null ) {
235
+ continue ;
236
+ }
237
+
238
+ const binaryChunk =
239
+ chunk instanceof ArrayBuffer ? new Uint8Array ( chunk )
240
+ : typeof chunk === 'string' ? new TextEncoder ( ) . encode ( chunk )
241
+ : chunk ;
242
+
243
+ let newData = new Uint8Array ( data . length + binaryChunk . length ) ;
244
+ newData . set ( data ) ;
245
+ newData . set ( binaryChunk , data . length ) ;
246
+ data = newData ;
247
+
248
+ let patternIndex ;
249
+ while ( ( patternIndex = findDoubleNewlineIndex ( data ) ) !== - 1 ) {
250
+ yield data . slice ( 0 , patternIndex ) ;
251
+ data = data . slice ( patternIndex ) ;
252
+ }
253
+ }
254
+
255
+ if ( data . length > 0 ) {
256
+ yield data ;
257
+ }
258
+ }
259
+
260
+ function findDoubleNewlineIndex ( buffer : Uint8Array ) : number {
261
+ // This function searches the buffer for the end patterns (\r\r, \n\n, \r\n\r\n)
262
+ // and returns the index right after the first occurrence of any pattern,
263
+ // or -1 if none of the patterns are found.
264
+ const newline = 0x0a ; // \n
265
+ const carriage = 0x0d ; // \r
266
+
267
+ for ( let i = 0 ; i < buffer . length - 2 ; i ++ ) {
268
+ if ( buffer [ i ] === newline && buffer [ i + 1 ] === newline ) {
269
+ // \n\n
270
+ return i + 2 ;
271
+ }
272
+ if ( buffer [ i ] === carriage && buffer [ i + 1 ] === carriage ) {
273
+ // \r\r
274
+ return i + 2 ;
275
+ }
276
+ if (
277
+ buffer [ i ] === carriage &&
278
+ buffer [ i + 1 ] === newline &&
279
+ i + 3 < buffer . length &&
280
+ buffer [ i + 2 ] === carriage &&
281
+ buffer [ i + 3 ] === newline
282
+ ) {
283
+ // \r\n\r\n
284
+ return i + 4 ;
285
+ }
286
+ }
287
+
288
+ return - 1 ;
289
+ }
290
+
223
291
class SSEDecoder {
224
292
private data : string [ ] ;
225
293
private event : string | null ;
@@ -283,8 +351,8 @@ class SSEDecoder {
283
351
*/
284
352
class LineDecoder {
285
353
// prettier-ignore
286
- static NEWLINE_CHARS = new Set ( [ '\n' , '\r' , '\x0b' , '\x0c' , '\x1c' , '\x1d' , '\x1e' , '\x85' , '\u2028' , '\u2029' ] ) ;
287
- static NEWLINE_REGEXP = / \r \n | [ \n \r \x0b \x0c \x1c \x1d \x1e \x85 \u2028 \u2029 ] / g;
354
+ static NEWLINE_CHARS = new Set ( [ '\n' , '\r' ] ) ;
355
+ static NEWLINE_REGEXP = / \r \n | [ \n \r ] / g;
288
356
289
357
buffer : string [ ] ;
290
358
trailingCR : boolean ;
0 commit comments