diff --git a/index.js b/index.js index ce407f9..f1ecbd0 100644 --- a/index.js +++ b/index.js @@ -1,6 +1,6 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); -const { Stream } = require("./lib/stream"); +const { createStream } = require("./lib/stream"); const { withAutomaticRetries } = require("./lib/util"); const collections = require("./lib/collections"); @@ -270,7 +270,7 @@ class Replicate { if (prediction.urls && prediction.urls.stream) { const { signal } = options; - const stream = new Stream(prediction.urls.stream, { signal }); + const stream = createStream(prediction.urls.stream, { signal }); yield* stream; } else { throw new Error("Prediction does not support streaming"); diff --git a/lib/stream.js b/lib/stream.js index 012d6d0..11cd460 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -1,15 +1,3 @@ -// Attempt to use readable-stream if available, attempt to use the built-in stream module. -let Readable; -try { - Readable = require("readable-stream").Readable; -} catch (e) { - try { - Readable = require("stream").Readable; - } catch (e) { - Readable = null; - } -} - /** * A server-sent event. */ @@ -41,99 +29,119 @@ class ServerSentEvent { } } -/** - * A stream of server-sent events. - */ -class Stream extends Readable { - /** - * Create a new stream of server-sent events. - * - * @param {string} url The URL to connect to. - * @param {object} options The fetch options. - */ - constructor(url, options) { - if (!Readable) { +async function createStream(url, options) { + // Attempt to use readable-stream if available, attempt to use the built-in stream module. + let Readable; + try { + Readable = await import("readable-stream").then( + (module) => module.Readable + ); + } catch (e) { + try { + Readable = await import("node:stream").then((module) => module.Readable); + } catch (e) { throw new Error( "Readable streams are not supported. Please use Node.js 18 or later, or install the readable-stream package." ); } - - super(); - this.url = url; - this.options = options; - - this.event = null; - this.data = []; - this.lastEventId = null; - this.retry = null; } - decode(line) { - if (!line) { - if (!this.event && !this.data.length && !this.lastEventId) { - return null; + /** + * A stream of server-sent events. + */ + class Stream extends Readable { + /** + * Create a new stream of server-sent events. + * + * @param {string} url The URL to connect to. + * @param {object} options The fetch options. + */ + constructor(url, options) { + if (!Readable) { + throw new Error( + "Readable streams are not supported. Please use Node.js 18 or later, or install the readable-stream package." + ); } - const sse = new ServerSentEvent( - this.event, - this.data.join("\n"), - this.lastEventId - ); + super(); + this.url = url; + this.options = options; this.event = null; this.data = []; + this.lastEventId = null; this.retry = null; - - return sse; } - if (line.startsWith(":")) { - return null; - } + decode(line) { + if (!line) { + if (!this.event && !this.data.length && !this.lastEventId) { + return null; + } - const [field, value] = line.split(": "); - if (field === "event") { - this.event = value; - } else if (field === "data") { - this.data.push(value); - } else if (field === "id") { - this.lastEventId = value; - } + const sse = new ServerSentEvent( + this.event, + this.data.join("\n"), + this.lastEventId + ); - return null; - } + this.event = null; + this.data = []; + this.retry = null; - async *[Symbol.asyncIterator]() { - const response = await fetch(this.url, { - ...this.options, - headers: { - Accept: "text/event-stream", - }, - }); - - for await (const chunk of response.body) { - const decoder = new TextDecoder("utf-8"); - const text = decoder.decode(chunk); - const lines = text.split("\n"); - for (const line of lines) { - const sse = this.decode(line); - if (sse) { - if (sse.event === "error") { - throw new Error(sse.data); - } + return sse; + } - yield sse; + if (line.startsWith(":")) { + return null; + } + + const [field, value] = line.split(": "); + if (field === "event") { + this.event = value; + } else if (field === "data") { + this.data.push(value); + } else if (field === "id") { + this.lastEventId = value; + } - if (sse.event === "done") { - return; + return null; + } + + async *[Symbol.asyncIterator]() { + const response = await fetch(this.url, { + ...this.options, + headers: { + Accept: "text/event-stream", + }, + }); + + for await (const chunk of response.body) { + const decoder = new TextDecoder("utf-8"); + const text = decoder.decode(chunk); + const lines = text.split("\n"); + for (const line of lines) { + const sse = this.decode(line); + if (sse) { + if (sse.event === "error") { + throw new Error(sse.data); + } + + yield sse; + + if (sse.event === "done") { + return; + } } } } } } + + return new Stream(url, options); } module.exports = { - Stream, ServerSentEvent, + createStream, };