Skip to content

Commit 4145b01

Browse files
committed
Refactor Stream module
This is now built on web standards based `EventSource` and `ReadableStream` primatives. The latter is available in Node >= 18 and the former is imported as a dependency from the `eventsource` module. If available the implementation will use the native implementation falling back to the module implementation which means this will work in Browser environments and Deno.
1 parent c3a498a commit 4145b01

File tree

5 files changed

+144
-173
lines changed

5 files changed

+144
-173
lines changed

index.js

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
3-
const { Stream } = require("./lib/stream");
3+
const { createReadableStream } = require("./lib/stream");
44
const {
55
withAutomaticRetries,
66
validateWebhook,
@@ -45,13 +45,17 @@ class Replicate {
4545
* @param {string} options.userAgent - Identifier of your app
4646
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
4747
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
48+
* @param {Function} [options.EventSource] - Custom EventSource implementation function to use.
49+
* @param {Function} [options.ReadableStream] - Custom ReadableStream implementation function to use.
4850
*/
4951
constructor(options = {}) {
5052
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;
5153
this.userAgent =
5254
options.userAgent || `replicate-javascript/${packageJSON.version}`;
5355
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5456
this.fetch = options.fetch || globalThis.fetch;
57+
this.EventSource = options.EventSource;
58+
this.ReadableStream = options.ReadableStream;
5559

5660
this.accounts = {
5761
current: accounts.current.bind(this),
@@ -289,9 +293,10 @@ class Replicate {
289293

290294
if (prediction.urls && prediction.urls.stream) {
291295
const { signal } = options;
292-
const stream = new Stream({
296+
const stream = await createReadableStream({
293297
url: prediction.urls.stream,
294-
fetch: this.fetch,
298+
EventSource: this.EventSource,
299+
ReadableStream: this.ReadableStream,
295300
options: { signal },
296301
});
297302
yield* stream;

index.test.ts

+34-42
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Replicate, {
88
} from "replicate";
99
import nock from "nock";
1010
import fetch from "cross-fetch";
11-
import { Stream } from "./lib/stream";
11+
import { createReadableStream } from "./lib/stream";
1212
import { PassThrough } from "node:stream";
1313

1414
let client: Replicate;
@@ -1188,19 +1188,24 @@ describe("Replicate client", () => {
11881188

11891189
// Continue with tests for other methods
11901190

1191-
describe("Stream", () => {
1192-
function createStream(body: string | NodeJS.ReadableStream, status = 200) {
1191+
describe("createReadableStream", () => {
1192+
async function createStream(
1193+
body: string | NodeJS.ReadableStream,
1194+
status = 200
1195+
) {
11931196
const streamEndpoint = "https://stream.replicate.com";
11941197
nock(streamEndpoint)
11951198
.get("/fake_stream")
11961199
.matchHeader("Accept", "text/event-stream")
11971200
.reply(status, body);
11981201

1199-
return new Stream({ url: `${streamEndpoint}/fake_stream`, fetch });
1202+
return await createReadableStream({
1203+
url: `${streamEndpoint}/fake_stream`,
1204+
});
12001205
}
12011206

12021207
test("consumes a server sent event stream", async () => {
1203-
const stream = createStream(
1208+
const stream = await createStream(
12041209
`
12051210
event: output
12061211
id: EVENT_1
@@ -1209,13 +1214,11 @@ describe("Replicate client", () => {
12091214
event: done
12101215
id: EVENT_2
12111216
data: {}
1212-
`
1213-
.trim()
1214-
.replace(/^[ ]+/gm, "")
1217+
1218+
`.replace(/^[ ]+/gm, "")
12151219
);
12161220

12171221
const iterator = stream[Symbol.asyncIterator]();
1218-
12191222
expect(await iterator.next()).toEqual({
12201223
done: false,
12211224
value: { event: "output", id: "EVENT_1", data: "hello world" },
@@ -1229,7 +1232,7 @@ describe("Replicate client", () => {
12291232
});
12301233

12311234
test("consumes multiple events", async () => {
1232-
const stream = createStream(
1235+
const stream = await createStream(
12331236
`
12341237
event: output
12351238
id: EVENT_1
@@ -1242,9 +1245,8 @@ describe("Replicate client", () => {
12421245
event: done
12431246
id: EVENT_3
12441247
data: {}
1245-
`
1246-
.trim()
1247-
.replace(/^[ ]+/gm, "")
1248+
1249+
`.replace(/^[ ]+/gm, "")
12481250
);
12491251

12501252
const iterator = stream[Symbol.asyncIterator]();
@@ -1266,7 +1268,7 @@ describe("Replicate client", () => {
12661268
});
12671269

12681270
test("ignores unexpected characters", async () => {
1269-
const stream = createStream(
1271+
const stream = await createStream(
12701272
`
12711273
: hi
12721274
@@ -1277,9 +1279,8 @@ describe("Replicate client", () => {
12771279
event: done
12781280
id: EVENT_2
12791281
data: {}
1280-
`
1281-
.trim()
1282-
.replace(/^[ ]+/gm, "")
1282+
1283+
`.replace(/^[ ]+/gm, "")
12831284
);
12841285

12851286
const iterator = stream[Symbol.asyncIterator]();
@@ -1297,7 +1298,7 @@ describe("Replicate client", () => {
12971298
});
12981299

12991300
test("supports multiple lines of output in a single event", async () => {
1300-
const stream = createStream(
1301+
const stream = await createStream(
13011302
`
13021303
: hi
13031304
@@ -1310,9 +1311,8 @@ describe("Replicate client", () => {
13101311
event: done
13111312
id: EVENT_2
13121313
data: {}
1313-
`
1314-
.trim()
1315-
.replace(/^[ ]+/gm, "")
1314+
1315+
`.replace(/^[ ]+/gm, "")
13161316
);
13171317

13181318
const iterator = stream[Symbol.asyncIterator]();
@@ -1335,7 +1335,7 @@ describe("Replicate client", () => {
13351335

13361336
test("supports the server writing data lines in multiple chunks", async () => {
13371337
const body = new PassThrough();
1338-
const stream = createStream(body);
1338+
const stream = await createStream(body);
13391339

13401340
// Create a stream of data chunks split on the pipe character for readability.
13411341
const data = `
@@ -1348,9 +1348,8 @@ describe("Replicate client", () => {
13481348
event: done
13491349
id: EVENT_2
13501350
data: {}
1351-
`
1352-
.trim()
1353-
.replace(/^[ ]+/gm, "");
1351+
1352+
`.replace(/^[ ]+/gm, "");
13541353

13551354
const chunks = data.split("|");
13561355

@@ -1392,7 +1391,7 @@ describe("Replicate client", () => {
13921391

13931392
test("supports the server writing data in a complete mess", async () => {
13941393
const body = new PassThrough();
1395-
const stream = createStream(body);
1394+
const stream = await createStream(body);
13961395

13971396
// Create a stream of data chunks split on the pipe character for readability.
13981397
const data = `
@@ -1407,9 +1406,8 @@ describe("Replicate client", () => {
14071406
event: d|one
14081407
id: EVENT|_2
14091408
data: {}
1410-
`
1411-
.trim()
1412-
.replace(/^[ ]+/gm, "");
1409+
1410+
`.replace(/^[ ]+/gm, "");
14131411

14141412
const chunks = data.split("|");
14151413

@@ -1450,15 +1448,13 @@ describe("Replicate client", () => {
14501448
});
14511449

14521450
test("supports ending without a done", async () => {
1453-
const stream = createStream(
1451+
const stream = await createStream(
14541452
`
14551453
event: output
14561454
id: EVENT_1
14571455
data: hello world
14581456
1459-
`
1460-
.trim()
1461-
.replace(/^[ ]+/gm, "")
1457+
`.replace(/^[ ]+/gm, "")
14621458
);
14631459

14641460
const iterator = stream[Symbol.asyncIterator]();
@@ -1470,7 +1466,7 @@ describe("Replicate client", () => {
14701466
});
14711467

14721468
test("an error event in the stream raises an exception", async () => {
1473-
const stream = createStream(
1469+
const stream = await createStream(
14741470
`
14751471
event: output
14761472
id: EVENT_1
@@ -1480,27 +1476,23 @@ describe("Replicate client", () => {
14801476
id: EVENT_2
14811477
data: An unexpected error occurred
14821478
1483-
`
1484-
.trim()
1485-
.replace(/^[ ]+/gm, "")
1479+
`.replace(/^[ ]+/gm, "")
14861480
);
14871481

14881482
const iterator = stream[Symbol.asyncIterator]();
14891483
expect(await iterator.next()).toEqual({
14901484
done: false,
14911485
value: { event: "output", id: "EVENT_1", data: "hello world" },
14921486
});
1493-
await expect(iterator.next()).rejects.toThrowError(
1494-
"An unexpected error occurred"
1495-
);
1487+
await expect(iterator.next()).rejects.toThrowError("Unexpected Error");
14961488
expect(await iterator.next()).toEqual({ done: true });
14971489
});
14981490

14991491
test("an error when fetching the stream raises an exception", async () => {
1500-
const stream = createStream("{}", 500);
1492+
const stream = await createStream("{}", 500);
15011493
const iterator = stream[Symbol.asyncIterator]();
15021494
await expect(iterator.next()).rejects.toThrowError(
1503-
"Request to https://stream.replicate.com/fake_stream failed with status 500 Internal Server Error: {}."
1495+
"Request to https://stream.replicate.com/fake_stream failed with status 500"
15041496
);
15051497
expect(await iterator.next()).toEqual({ done: true });
15061498
});

0 commit comments

Comments
 (0)