From ff0b156df01c8f3dde9ca6f34ef07d6d4921e6c9 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 5 Mar 2024 13:02:48 +0000 Subject: [PATCH 1/5] Use the fetch() function provided in the Stream implementation --- index.js | 6 +++++- lib/stream.js | 11 +++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/index.js b/index.js index 24376fe..a1b74b2 100644 --- a/index.js +++ b/index.js @@ -289,7 +289,11 @@ class Replicate { if (prediction.urls && prediction.urls.stream) { const { signal } = options; - const stream = new Stream(prediction.urls.stream, { signal }); + const stream = new Stream({ + url: prediction.urls.stream, + fetch: this.fetch, + options: { signal }, + }); yield* stream; } else { throw new Error("Prediction does not support streaming"); diff --git a/lib/stream.js b/lib/stream.js index 012d6d0..d962ce1 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -48,10 +48,12 @@ class Stream extends Readable { /** * Create a new stream of server-sent events. * - * @param {string} url The URL to connect to. - * @param {object} options The fetch options. + * @param {object} config + * @param {string} config.url The URL to connect to. + * @param {Function} [config.fetch] The fetch implemention to use. + * @param {object} [config.options] The fetch options. */ - constructor(url, options) { + constructor({ url, fetch = globalThis.fetch, options = {} }) { if (!Readable) { throw new Error( "Readable streams are not supported. Please use Node.js 18 or later, or install the readable-stream package." @@ -60,6 +62,7 @@ class Stream extends Readable { super(); this.url = url; + this.fetch = fetch; this.options = options; this.event = null; @@ -104,7 +107,7 @@ class Stream extends Readable { } async *[Symbol.asyncIterator]() { - const response = await fetch(this.url, { + const response = await this.fetch(this.url, { ...this.options, headers: { Accept: "text/event-stream", From 913f524b5833aff65a48ecc127d39f08f94b60fd Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 5 Mar 2024 13:03:04 +0000 Subject: [PATCH 2/5] Allow js files to be imported into tests --- tsconfig.json | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tsconfig.json b/tsconfig.json index 7a564ee..b699d79 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -2,9 +2,10 @@ "compilerOptions": { "esModuleInterop": true, "noEmit": true, - "strict": true + "strict": true, + "allowJs": true }, "exclude": [ "**/node_modules" ] -} +} \ No newline at end of file From 3efcc3e8d8c1e72722fb0bf0141731f6cdea0145 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 5 Mar 2024 13:04:04 +0000 Subject: [PATCH 3/5] Handle processing chunked event streams --- index.test.ts | 290 +++++++++++++++++++++++++++++++++++++++++++++++++- lib/stream.js | 26 ++++- 2 files changed, 310 insertions(+), 6 deletions(-) diff --git a/index.test.ts b/index.test.ts index 97abc6f..f90a4ae 100644 --- a/index.test.ts +++ b/index.test.ts @@ -8,6 +8,8 @@ import Replicate, { } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; +import { Stream } from "./lib/stream"; +import { PassThrough } from "node:stream"; let client: Replicate; const BASE_URL = "https://api.replicate.com/v1"; @@ -251,7 +253,7 @@ describe("Replicate client", () => { let actual: Record | undefined; nock(BASE_URL) .post("/predictions") - .reply(201, (uri: string, body: Record) => { + .reply(201, (_uri: string, body: Record) => { actual = body; return body; }); @@ -1010,8 +1012,6 @@ describe("Replicate client", () => { }); test("Calls the correct API routes for a model", async () => { - const firstPollingRequest = true; - nock(BASE_URL) .post("/models/replicate/hello-world/predictions") .reply(201, { @@ -1179,7 +1179,7 @@ describe("Replicate client", () => { // This is a test secret and should not be used in production const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw"; - const isValid = await validateWebhook(request, secret); + const isValid = validateWebhook(request, secret); expect(isValid).toBe(true); }); @@ -1187,4 +1187,286 @@ describe("Replicate client", () => { }); // Continue with tests for other methods + + describe("Stream", () => { + function createStream(body: string | NodeJS.ReadableStream) { + const streamEndpoint = "https://stream.replicate.com"; + nock(streamEndpoint) + .get("/fake_stream") + .matchHeader("Accept", "text/event-stream") + .reply(200, body); + + return new Stream({ url: `${streamEndpoint}/fake_stream`, fetch }); + } + + test("consumes a server sent event stream", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + event: done + id: EVENT_2 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("consumes multiple events", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + event: output + id: EVENT_2 + data: hello dave + + event: done + id: EVENT_3 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_2", data: "hello dave" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_3", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("ignores unexpected characters", async () => { + const stream = createStream( + ` + : hi + + event: output + id: EVENT_1 + data: hello world + + event: done + id: EVENT_2 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("supports multiple lines of output in a single event", async () => { + const stream = createStream( + ` + : hi + + event: output + id: EVENT_1 + data: hello, + data: this is a new line, + data: and this is a new line too + + event: done + id: EVENT_2 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("supports the server writing data lines in multiple chunks", async () => { + const body = new PassThrough(); + const stream = createStream(body); + + // Create a stream of data chunks split on the pipe character for readability. + const data = ` + event: output + id: EVENT_1 + data: hello,| + data: this is a new line,| + data: and this is a new line too + + event: done + id: EVENT_2 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, ""); + + const chunks = data.split("|"); + + // Consume the iterator in parallel to writing it. + const reading = new Promise((resolve, reject) => { + (async () => { + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + })().then(resolve, reject); + }); + + // Write the chunks to the stream at an interval. + const writing = new Promise((resolve, reject) => { + (async () => { + for await (const chunk of chunks) { + body.write(chunk); + await new Promise((resolve) => setTimeout(resolve, 1)); + } + body.end(); + resolve(null); + })().then(resolve, reject); + }); + + // Wait for both promises to resolve. + await Promise.all([reading, writing]); + }); + + test("supports the server writing data in a complete mess", async () => { + const body = new PassThrough(); + const stream = createStream(body); + + // Create a stream of data chunks split on the pipe character for readability. + const data = ` + : hi + + ev|ent: output + id: EVENT_1 + data: hello, + data: this |is a new line,| + data: and this is |a new line too + + event: d|one + id: EVENT|_2 + data: {} + ` + .trim() + .replace(/^[ ]+/gm, ""); + + const chunks = data.split("|"); + + // Consume the iterator in parallel to writing it. + const reading = new Promise((resolve, reject) => { + (async () => { + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + })().then(resolve, reject); + }); + + // Write the chunks to the stream at an interval. + const writing = new Promise((resolve, reject) => { + (async () => { + for await (const chunk of chunks) { + body.write(chunk); + await new Promise((resolve) => setTimeout(resolve, 1)); + } + body.end(); + resolve(null); + })().then(resolve, reject); + }); + + // Wait for both promises to resolve. + await Promise.all([reading, writing]); + }); + + test("supports ending without a done", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + }); + }); }); diff --git a/lib/stream.js b/lib/stream.js index d962ce1..8bc5623 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -50,7 +50,7 @@ class Stream extends Readable { * * @param {object} config * @param {string} config.url The URL to connect to. - * @param {Function} [config.fetch] The fetch implemention to use. + * @param {typeof fetch} [config.fetch] The fetch implementation to use. * @param {object} [config.options] The fetch options. */ constructor({ url, fetch = globalThis.fetch, options = {} }) { @@ -114,10 +114,21 @@ class Stream extends Readable { }, }); + if (!response.ok) { + throw new Error(); + } + + let partialChunk = ""; for await (const chunk of response.body) { const decoder = new TextDecoder("utf-8"); - const text = decoder.decode(chunk); + const text = partialChunk + decoder.decode(chunk); const lines = text.split("\n"); + + // We want to ensure that the last line is not a fragment + // so we keep it and append it to the start of the next + // chunk. + partialChunk = lines.pop(); + for (const line of lines) { const sse = this.decode(line); if (sse) { @@ -133,6 +144,17 @@ class Stream extends Readable { } } } + + // Process the final line and ensure we have captured the final event. + this.decode(partialChunk); + const sse = this.decode(""); + if (sse) { + if (sse.event === "error") { + throw new Error(sse.data); + } + + yield sse; + } } } From fc62230ed57927fbd3c43e160587d9e31824ba13 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 5 Mar 2024 14:44:21 +0000 Subject: [PATCH 4/5] Fix types for validateWebhook --- index.d.ts | 2 +- index.test.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/index.d.ts b/index.d.ts index 8dc998a..abf68dc 100644 --- a/index.d.ts +++ b/index.d.ts @@ -279,7 +279,7 @@ declare module "replicate" { signature?: string; }, secret: string - ): boolean; + ): Promise; export function parseProgressFromLogs(logs: Prediction | string): { percentage: number; diff --git a/index.test.ts b/index.test.ts index f90a4ae..b40ff5e 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1179,7 +1179,7 @@ describe("Replicate client", () => { // This is a test secret and should not be used in production const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw"; - const isValid = validateWebhook(request, secret); + const isValid = await validateWebhook(request, secret); expect(isValid).toBe(true); }); From c3a498a7841412ab19c19049d4e642f2929c1ce0 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 5 Mar 2024 21:27:07 +0000 Subject: [PATCH 5/5] Stream iterator throws an ApiError if fetch fails --- index.test.ts | 40 ++++++++++++++++++++++++++++++++++++++-- lib/stream.js | 17 ++++++++++++++--- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/index.test.ts b/index.test.ts index b40ff5e..2e7ddd4 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1189,12 +1189,12 @@ describe("Replicate client", () => { // Continue with tests for other methods describe("Stream", () => { - function createStream(body: string | NodeJS.ReadableStream) { + function createStream(body: string | NodeJS.ReadableStream, status = 200) { const streamEndpoint = "https://stream.replicate.com"; nock(streamEndpoint) .get("/fake_stream") .matchHeader("Accept", "text/event-stream") - .reply(200, body); + .reply(status, body); return new Stream({ url: `${streamEndpoint}/fake_stream`, fetch }); } @@ -1468,5 +1468,41 @@ describe("Replicate client", () => { }); expect(await iterator.next()).toEqual({ done: true }); }); + + test("an error event in the stream raises an exception", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + event: error + id: EVENT_2 + data: An unexpected error occurred + + ` + .trim() + .replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + await expect(iterator.next()).rejects.toThrowError( + "An unexpected error occurred" + ); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("an error when fetching the stream raises an exception", async () => { + const stream = createStream("{}", 500); + const iterator = stream[Symbol.asyncIterator](); + await expect(iterator.next()).rejects.toThrowError( + "Request to https://stream.replicate.com/fake_stream failed with status 500 Internal Server Error: {}." + ); + expect(await iterator.next()).toEqual({ done: true }); + }); }); }); diff --git a/lib/stream.js b/lib/stream.js index 8bc5623..c4ffc70 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -1,4 +1,7 @@ // Attempt to use readable-stream if available, attempt to use the built-in stream module. + +const ApiError = require("./error"); + let Readable; try { Readable = require("readable-stream").Readable; @@ -107,15 +110,23 @@ class Stream extends Readable { } async *[Symbol.asyncIterator]() { - const response = await this.fetch(this.url, { + const init = { ...this.options, headers: { Accept: "text/event-stream", }, - }); + }; + const response = await this.fetch(this.url, init); if (!response.ok) { - throw new Error(); + // The cross-fetch shim doesn't accept Request objects so we create one here. + const request = new Request(this.url, init); + const text = await response.text(); + throw new ApiError( + `Request to ${request.url} failed with status ${response.status} ${response.statusText}: ${text}.`, + request, + response + ); } let partialChunk = "";