From bece2f8754a6316d64af110c18dd8f54c5719b8a Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 8 Dec 2023 11:09:32 -0800 Subject: [PATCH 1/2] Add replicate.stream method --- index.d.ts | 17 +++++++ index.js | 42 +++++++++++++++++ lib/stream.js | 123 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 lib/stream.js diff --git a/index.d.ts b/index.d.ts index 3b9d80c..05edbe6 100644 --- a/index.d.ts +++ b/index.d.ts @@ -75,6 +75,13 @@ declare module "replicate" { results: T[]; } + export interface ServerSentEvent { + event: string; + data: string; + id?: string; + retry?: number; + } + export default class Replicate { constructor(options?: { auth?: string; @@ -103,6 +110,16 @@ declare module "replicate" { progress?: (prediction: Prediction) => void ): Promise; + stream( + identifier: `${string}/${string}` | `${string}/${string}:${string}`, + options: { + input: object; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + signal?: AbortSignal; + } + ): AsyncGenerator; + request( route: string | URL, options: { diff --git a/index.js b/index.js index b908226..5a6c532 100644 --- a/index.js +++ b/index.js @@ -1,5 +1,6 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); +const { Stream } = require("./lib/stream"); const { withAutomaticRetries } = require("./lib/util"); const collections = require("./lib/collections"); @@ -235,6 +236,47 @@ class Replicate { return response; } + /** + * Stream a model and wait for its output. + * + * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" + * @param {object} options + * @param {object} options.input - Required. An object with the model inputs + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction + * @throws {Error} If the prediction failed + * @yields {ServerSentEvent} Each streamed event from the prediction + */ + async *stream(ref, options) { + const { wait, ...data } = options; + + const identifier = ModelVersionIdentifier.parse(ref); + + let prediction; + if (identifier.version) { + prediction = await this.predictions.create({ + ...data, + version: identifier.version, + stream: true, + }); + } else { + prediction = await this.models.predictions.create( + identifier.owner, + identifier.name, + { ...data, stream: true } + ); + } + + if (prediction.urls && prediction.urls.stream) { + const { signal } = options; + const stream = new Stream(prediction.urls.stream, { signal }); + yield* stream; + } else { + throw new Error("Prediction does not support streaming"); + } + } + /** * Paginate through a list of results. * diff --git a/lib/stream.js b/lib/stream.js new file mode 100644 index 0000000..dfc7c61 --- /dev/null +++ b/lib/stream.js @@ -0,0 +1,123 @@ +const { Readable } = require("stream"); + +/** + * A server-sent event. + */ +class ServerSentEvent { + /** + * Create a new server-sent event. + * + * @param {string} event The event name. + * @param {string} data The event data. + * @param {string} id The event ID. + * @param {number} retry The retry time. + */ + constructor(event, data, id, retry) { + this.event = event; + this.data = data; + this.id = id; + this.retry = retry; + } + + /** + * Convert the event to a string. + */ + toString() { + if (this.event === "output") { + return this.data; + } + + return ""; + } +} + +/** + * A stream of server-sent events. + */ +class Stream extends Readable { + /** + * Create a new stream of server-sent events. + * + * @param {string} url The URL to connect to. + * @param {object} options The fetch options. + */ + constructor(url, options) { + super(); + this.url = url; + this.options = options; + + this.event = null; + this.data = []; + this.lastEventId = null; + this.retry = null; + } + + decode(line) { + if (!line) { + if (!this.event && !this.data.length && !this.lastEventId) { + return null; + } + + const sse = new ServerSentEvent( + this.event, + this.data.join("\n"), + this.lastEventId + ); + + this.event = null; + this.data = []; + this.retry = null; + + return sse; + } + + if (line.startsWith(":")) { + return null; + } + + const [field, value] = line.split(": "); + if (field === "event") { + this.event = value; + } else if (field === "data") { + this.data.push(value); + } else if (field === "id") { + this.lastEventId = value; + } + + return null; + } + + async *[Symbol.asyncIterator]() { + const response = await fetch(this.url, { + ...this.options, + headers: { + Accept: "text/event-stream", + }, + }); + + for await (const chunk of response.body) { + const decoder = new TextDecoder("utf-8"); + const text = decoder.decode(chunk); + const lines = text.split("\n"); + for (const line of lines) { + const sse = this.decode(line); + if (sse) { + if (sse.event === "error") { + throw new Error(sse.data); + } + + yield sse; + + if (sse.event === "done") { + return; + } + } + } + } + } +} + +module.exports = { + Stream, + ServerSentEvent, +}; From 28026f3a6827e674fc8828d742a51fd2b1f85b30 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 11 Dec 2023 03:27:27 -0800 Subject: [PATCH 2/2] Update README --- README.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/README.md b/README.md index 1eb7a2b..fd13c70 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,49 @@ const input = { prompt: "a 19th century portrait of a raccoon gentleman wearing const output = await replicate.run(model, { input }); ``` +### `replicate.stream` + +Run a model and stream its output. Unlike [`replicate.prediction.create`](#replicatepredictionscreate), this method returns only the prediction output rather than the entire prediction object. + +```js +for await (const event of replicate.stream(identifier, options)) { /* ... */ } +``` + +| name | type | description | +| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `identifier` | string | **Required**. The model version identifier in the format `{owner}/{name}` or `{owner}/{name}:{version}`, for example `meta/llama-2-70b-chat` | +| `options.input` | object | **Required**. An object with the model inputs. | +| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output | +| `options.webhook_events_filter` | string[] | An array of events which should trigger [webhooks](https://replicate.com/docs/webhooks). Allowable values are `start`, `output`, `logs`, and `completed` | +| `options.signal` | object | An [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) to cancel the prediction | + +Throws `Error` if the prediction failed. + +Returns `AsyncGenerator` which yields the events of running the model. + +Example: + +```js +for await (const event of replicate.stream("meta/llama-2-70b-chat")) { + process.stdout.write(`${event}`); +} +``` + +### Server-sent events + +A stream generates server-sent events with the following properties: + +| name | type | description | +| ------- | ------ | ---------------------------------------------------------------------------- | +| `event` | string | The type of event. Possible values are `output`, `logs`, `error`, and `done` | +| `data` | string | The event data | +| `id` | string | The event id | +| `retry` | number | The number of milliseconds to wait before reconnecting to the server | + +As the prediction runs, the generator yields `output` and `logs` events. If an error occurs, the generator yields an `error` event with a JSON object containing the error message set to the `data` property. When the prediction is done, the generator yields a `done` event with an empty JSON object set to the `data` property. + +Events with the `output` event type have their `toString()` method overridden to return the event data as a string. Other event types return an empty string. + ### `replicate.models.get` Get metadata for a public model or a private model that you own.