Skip to content

Fix stream() support for chunked event streams #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ declare module "replicate" {
signature?: string;
},
secret: string
): boolean;
): Promise<boolean>;

export function parseProgressFromLogs(logs: Prediction | string): {
percentage: number;
Expand Down
6 changes: 5 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,11 @@ class Replicate {

if (prediction.urls && prediction.urls.stream) {
const { signal } = options;
const stream = new Stream(prediction.urls.stream, { signal });
const stream = new Stream({
url: prediction.urls.stream,
fetch: this.fetch,
options: { signal },
});
yield* stream;
} else {
throw new Error("Prediction does not support streaming");
Expand Down
324 changes: 321 additions & 3 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import Replicate, {
} from "replicate";
import nock from "nock";
import fetch from "cross-fetch";
import { Stream } from "./lib/stream";
import { PassThrough } from "node:stream";

let client: Replicate;
const BASE_URL = "https://api.replicate.com/v1";
Expand Down Expand Up @@ -251,7 +253,7 @@ describe("Replicate client", () => {
let actual: Record<string, any> | undefined;
nock(BASE_URL)
.post("/predictions")
.reply(201, (uri: string, body: Record<string, any>) => {
.reply(201, (_uri: string, body: Record<string, any>) => {
actual = body;
return body;
});
Expand Down Expand Up @@ -1010,8 +1012,6 @@ describe("Replicate client", () => {
});

test("Calls the correct API routes for a model", async () => {
const firstPollingRequest = true;

nock(BASE_URL)
.post("/models/replicate/hello-world/predictions")
.reply(201, {
Expand Down Expand Up @@ -1187,4 +1187,322 @@ describe("Replicate client", () => {
});

// Continue with tests for other methods

describe("Stream", () => {
function createStream(body: string | NodeJS.ReadableStream, status = 200) {
const streamEndpoint = "https://stream.replicate.com";
nock(streamEndpoint)
.get("/fake_stream")
.matchHeader("Accept", "text/event-stream")
.reply(status, body);

return new Stream({ url: `${streamEndpoint}/fake_stream`, fetch });
}

test("consumes a server sent event stream", async () => {
const stream = createStream(
`
event: output
id: EVENT_1
data: hello world

event: done
id: EVENT_2
data: {}
`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();

expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_1", data: "hello world" },
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
Comment on lines +1223 to +1226
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bug, rather than a feature because it means that we always need to check event.event === "output" before accessing event.data. The toString function kind of hides this but it doesn't work with JSON.stringify().

I think actually we want to return the done event from the iterator rather than yielding it. Then the done iterator will look like:

expect(await iterator.next()).toEqual({ done: true, value: { event: "done", ... } });

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable to me. Any reason not to make that change to return for done?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's technically a breaking change, but unclear how much of an issue it would cause. There's nothing you can do with it except listen for it and ignore it.

expect(await iterator.next()).toEqual({ done: true });
expect(await iterator.next()).toEqual({ done: true });
});

test("consumes multiple events", async () => {
const stream = createStream(
`
event: output
id: EVENT_1
data: hello world

event: output
id: EVENT_2
data: hello dave

event: done
id: EVENT_3
data: {}
`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();

expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_1", data: "hello world" },
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_2", data: "hello dave" },
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_3", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
expect(await iterator.next()).toEqual({ done: true });
});

test("ignores unexpected characters", async () => {
const stream = createStream(
`
: hi

event: output
id: EVENT_1
data: hello world

event: done
id: EVENT_2
data: {}
`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();

expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_1", data: "hello world" },
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
expect(await iterator.next()).toEqual({ done: true });
});

test("supports multiple lines of output in a single event", async () => {
const stream = createStream(
`
: hi

event: output
id: EVENT_1
data: hello,
data: this is a new line,
data: and this is a new line too

event: done
id: EVENT_2
data: {}
`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();

expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
expect(await iterator.next()).toEqual({ done: true });
});

test("supports the server writing data lines in multiple chunks", async () => {
const body = new PassThrough();
const stream = createStream(body);

// Create a stream of data chunks split on the pipe character for readability.
const data = `
event: output
id: EVENT_1
data: hello,|
data: this is a new line,|
data: and this is a new line too

event: done
id: EVENT_2
data: {}
`
.trim()
.replace(/^[ ]+/gm, "");

const chunks = data.split("|");

// Consume the iterator in parallel to writing it.
const reading = new Promise((resolve, reject) => {
(async () => {
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
})().then(resolve, reject);
});

// Write the chunks to the stream at an interval.
const writing = new Promise((resolve, reject) => {
(async () => {
for await (const chunk of chunks) {
body.write(chunk);
await new Promise((resolve) => setTimeout(resolve, 1));
}
body.end();
resolve(null);
})().then(resolve, reject);
});

// Wait for both promises to resolve.
await Promise.all([reading, writing]);
});

test("supports the server writing data in a complete mess", async () => {
const body = new PassThrough();
const stream = createStream(body);

// Create a stream of data chunks split on the pipe character for readability.
const data = `
: hi

ev|ent: output
id: EVENT_1
data: hello,
data: this |is a new line,|
data: and this is |a new line too

event: d|one
id: EVENT|_2
data: {}
`
.trim()
.replace(/^[ ]+/gm, "");

const chunks = data.split("|");

// Consume the iterator in parallel to writing it.
const reading = new Promise((resolve, reject) => {
(async () => {
const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: {
event: "output",
id: "EVENT_1",
data: "hello,\nthis is a new line,\nand this is a new line too",
},
});
expect(await iterator.next()).toEqual({
done: false,
value: { event: "done", id: "EVENT_2", data: "{}" },
});
expect(await iterator.next()).toEqual({ done: true });
})().then(resolve, reject);
});

// Write the chunks to the stream at an interval.
const writing = new Promise((resolve, reject) => {
(async () => {
for await (const chunk of chunks) {
body.write(chunk);
await new Promise((resolve) => setTimeout(resolve, 1));
}
body.end();
resolve(null);
})().then(resolve, reject);
});

// Wait for both promises to resolve.
await Promise.all([reading, writing]);
});

test("supports ending without a done", async () => {
const stream = createStream(
`
event: output
id: EVENT_1
data: hello world

`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_1", data: "hello world" },
});
expect(await iterator.next()).toEqual({ done: true });
});

test("an error event in the stream raises an exception", async () => {
const stream = createStream(
`
event: output
id: EVENT_1
data: hello world

event: error
id: EVENT_2
data: An unexpected error occurred

`
.trim()
.replace(/^[ ]+/gm, "")
);

const iterator = stream[Symbol.asyncIterator]();
expect(await iterator.next()).toEqual({
done: false,
value: { event: "output", id: "EVENT_1", data: "hello world" },
});
await expect(iterator.next()).rejects.toThrowError(
"An unexpected error occurred"
);
expect(await iterator.next()).toEqual({ done: true });
});

test("an error when fetching the stream raises an exception", async () => {
const stream = createStream("{}", 500);
const iterator = stream[Symbol.asyncIterator]();
await expect(iterator.next()).rejects.toThrowError(
"Request to https://stream.replicate.com/fake_stream failed with status 500 Internal Server Error: {}."
);
expect(await iterator.next()).toEqual({ done: true });
});
});
});
Loading