Skip to content

Commit e2d5d2b

Browse files
committed
fix(streaming): handle special line characters and fix multi-byte character decoding (#757)
1 parent c5eb4ea commit e2d5d2b

File tree

2 files changed

+338
-27
lines changed

2 files changed

+338
-27
lines changed

src/streaming.ts

Lines changed: 94 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -23,29 +23,6 @@ export class Stream<Item> implements AsyncIterable<Item> {
2323

2424
static fromSSEResponse<Item>(response: Response, controller: AbortController) {
2525
let consumed = false;
26-
const decoder = new SSEDecoder();
27-
28-
async function* iterMessages(): AsyncGenerator<ServerSentEvent, void, unknown> {
29-
if (!response.body) {
30-
controller.abort();
31-
throw new OpenAIError(`Attempted to iterate over a response with no body`);
32-
}
33-
34-
const lineDecoder = new LineDecoder();
35-
36-
const iter = readableStreamAsyncIterable<Bytes>(response.body);
37-
for await (const chunk of iter) {
38-
for (const line of lineDecoder.decode(chunk)) {
39-
const sse = decoder.decode(line);
40-
if (sse) yield sse;
41-
}
42-
}
43-
44-
for (const line of lineDecoder.flush()) {
45-
const sse = decoder.decode(line);
46-
if (sse) yield sse;
47-
}
48-
}
4926

5027
async function* iterator(): AsyncIterator<Item, any, undefined> {
5128
if (consumed) {
@@ -54,7 +31,7 @@ export class Stream<Item> implements AsyncIterable<Item> {
5431
consumed = true;
5532
let done = false;
5633
try {
57-
for await (const sse of iterMessages()) {
34+
for await (const sse of _iterSSEMessages(response, controller)) {
5835
if (done) continue;
5936

6037
if (sse.data.startsWith('[DONE]')) {
@@ -220,6 +197,97 @@ export class Stream<Item> implements AsyncIterable<Item> {
220197
}
221198
}
222199

200+
/**
 * Reads the body of `response` and yields each server-sent event decoded from it.
 *
 * The raw byte stream is first regrouped into event-sized chunks by
 * `iterSSEChunks`, each chunk is split into lines by a `LineDecoder`, and every
 * line is fed to an `SSEDecoder`; whenever the decoder produces an event it is
 * yielded. Once the stream ends, any lines still buffered in the line decoder
 * are flushed through the same path.
 *
 * @param response - The HTTP response whose body is iterated.
 * @param controller - Aborted when the response has no body.
 * @throws OpenAIError if `response.body` is missing.
 */
export async function* _iterSSEMessages(
  response: Response,
  controller: AbortController,
): AsyncGenerator<ServerSentEvent, void, unknown> {
  const body = response.body;
  if (!body) {
    controller.abort();
    throw new OpenAIError(`Attempted to iterate over a response with no body`);
  }

  const eventDecoder = new SSEDecoder();
  const lines = new LineDecoder();

  const byteStream = readableStreamAsyncIterable<Bytes>(body);
  for await (const eventBytes of iterSSEChunks(byteStream)) {
    for (const text of lines.decode(eventBytes)) {
      const event = eventDecoder.decode(text);
      if (event) {
        yield event;
      }
    }
  }

  // The stream is finished: drain whatever the line decoder is still holding.
  for (const text of lines.flush()) {
    const event = eventDecoder.decode(text);
    if (event) {
      yield event;
    }
  }
}
225+
226+
/**
 * Re-chunks an async byte iterator on SSE event boundaries.
 *
 * Incoming chunks are appended to an internal buffer. Each time the buffer
 * contains a double new-line (as located by `findDoubleNewlineIndex`), the
 * bytes up to and including that terminator are yielded and removed from the
 * buffer. Whatever is left when the source is exhausted is yielded as a final
 * chunk, provided it is non-empty.
 *
 * `null`/`undefined` chunks are skipped; `ArrayBuffer` chunks are wrapped in a
 * `Uint8Array`, string chunks are encoded with `TextEncoder`, and anything else
 * is used as-is.
 */
async function* iterSSEChunks(iterator: AsyncIterableIterator<Bytes>): AsyncGenerator<Uint8Array> {
  let pending = new Uint8Array();

  for await (const chunk of iterator) {
    if (chunk == null) {
      continue;
    }

    // Normalise every supported chunk representation to raw bytes.
    let bytes;
    if (chunk instanceof ArrayBuffer) {
      bytes = new Uint8Array(chunk);
    } else if (typeof chunk === 'string') {
      bytes = new TextEncoder().encode(chunk);
    } else {
      bytes = chunk;
    }

    // Append the new bytes to whatever was left over from earlier chunks.
    const merged = new Uint8Array(pending.length + bytes.length);
    merged.set(pending);
    merged.set(bytes, pending.length);
    pending = merged;

    // Emit every complete event currently sitting in the buffer.
    for (
      let end = findDoubleNewlineIndex(pending);
      end !== -1;
      end = findDoubleNewlineIndex(pending)
    ) {
      yield pending.slice(0, end);
      pending = pending.slice(end);
    }
  }

  if (pending.length > 0) {
    yield pending;
  }
}
259+
260+
/**
 * Searches `buffer` for the first SSE event terminator, i.e. one of the byte
 * sequences `\n\n`, `\r\r` or `\r\n\r\n`.
 *
 * @param buffer - Raw bytes to scan.
 * @returns The index immediately after the first terminator found, or `-1`
 *   when the buffer contains none of the patterns.
 */
function findDoubleNewlineIndex(buffer: Uint8Array): number {
  const newline = 0x0a; // \n
  const carriage = 0x0d; // \r

  // A two-byte terminator can start as late as `buffer.length - 2`, so the scan
  // must include that index. Stopping one position earlier would miss a
  // terminator that sits at the very end of the buffer and hold back a complete
  // event until more data arrives.
  for (let i = 0; i < buffer.length - 1; i++) {
    if (buffer[i] === newline && buffer[i + 1] === newline) {
      // \n\n
      return i + 2;
    }
    if (buffer[i] === carriage && buffer[i + 1] === carriage) {
      // \r\r
      return i + 2;
    }
    if (
      buffer[i] === carriage &&
      buffer[i + 1] === newline &&
      // The four-byte pattern needs two more bytes after the first \r\n.
      i + 3 < buffer.length &&
      buffer[i + 2] === carriage &&
      buffer[i + 3] === newline
    ) {
      // \r\n\r\n
      return i + 4;
    }
  }

  return -1;
}
290+
223291
class SSEDecoder {
224292
private data: string[];
225293
private event: string | null;
@@ -283,8 +351,8 @@ class SSEDecoder {
283351
*/
284352
class LineDecoder {
285353
// prettier-ignore
286-
static NEWLINE_CHARS = new Set(['\n', '\r', '\x0b', '\x0c', '\x1c', '\x1d', '\x1e', '\x85', '\u2028', '\u2029']);
287-
static NEWLINE_REGEXP = /\r\n|[\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029]/g;
354+
static NEWLINE_CHARS = new Set(['\n', '\r']);
355+
static NEWLINE_REGEXP = /\r\n|[\n\r]/g;
288356

289357
buffer: string[];
290358
trailingCR: boolean;

0 commit comments

Comments
 (0)