Add replicate.stream method #169

Merged 2 commits on Dec 11, 2023
43 changes: 43 additions & 0 deletions README.md
@@ -172,6 +172,49 @@ const input = { prompt: "a 19th century portrait of a raccoon gentleman wearing
const output = await replicate.run(model, { input });
```

### `replicate.stream`

Run a model and stream its output. Unlike [`replicate.predictions.create`](#replicatepredictionscreate), this method returns only the prediction output rather than the entire prediction object.

```js
for await (const event of replicate.stream(identifier, options)) { /* ... */ }
```

| name | type | description |
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `identifier` | string | **Required**. The model version identifier in the format `{owner}/{name}` or `{owner}/{name}:{version}`, for example `meta/llama-2-70b-chat` |
| `options.input` | object | **Required**. An object with the model inputs. |
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
| `options.webhook_events_filter` | string[] | An array of events which should trigger [webhooks](https://replicate.com/docs/webhooks). Allowable values are `start`, `output`, `logs`, and `completed` |
| `options.signal` | object | An [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) to cancel the prediction |

Throws `Error` if the prediction failed.

Returns `AsyncGenerator<ServerSentEvent>` which yields the events of running the model.

Example:

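```js
const model = "meta/llama-2-70b-chat";
// `options.input` is required; the input keys depend on the model (`prompt` is assumed here).
const options = { input: { prompt: "Write a haiku about llamas" } };
for await (const event of replicate.stream(model, options)) {
  process.stdout.write(`${event}`);
}
```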
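
To collect the whole output instead of printing it as it arrives, the events can be concatenated. The following is a minimal sketch that relies only on the `toString()` behavior documented for server-sent events in this README. `collectOutput` and `fakeStream` are hypothetical names, and `fakeStream` stands in for `replicate.stream(...)` so the snippet runs without network access.

```js
// Hypothetical helper: concatenates the text of every event in a stream.
// `events` is any async iterable of events, such as the value returned
// by replicate.stream(...).
async function collectOutput(events) {
  let text = "";
  for await (const event of events) {
    // Only `output` events stringify to their data;
    // every other event type stringifies to "".
    text += `${event}`;
  }
  return text;
}

// Stand-in for replicate.stream(...) (an assumption for this sketch).
// Each fake event mimics the documented toString() behavior.
async function* fakeStream() {
  const make = (event, data) => ({
    event,
    data,
    toString() {
      return this.event === "output" ? this.data : "";
    },
  });
  yield make("output", "Hello");
  yield make("logs", "step 1");
  yield make("output", ", world");
  yield make("done", "{}");
}

collectOutput(fakeStream()).then((text) => console.log(text)); // prints "Hello, world"
```

With the real client, `fakeStream()` would be replaced by `replicate.stream(model, options)`.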

### Server-sent events

A stream generates server-sent events with the following properties:

| name | type | description |
| ------- | ------ | ---------------------------------------------------------------------------- |
| `event` | string | The type of event. Possible values are `output`, `logs`, `error`, and `done` |
| `data` | string | The event data |
| `id` | string | The event id |
| `retry` | number | The number of milliseconds to wait before reconnecting to the server |

As the prediction runs, the generator yields `output` and `logs` events. If an error occurs, the server sends an `error` event whose `data` property holds a JSON object containing the error message, and the generator throws an `Error` with that data as its message. When the prediction is done, the generator yields a `done` event whose `data` property holds an empty JSON object, and then stops.

Events with the `output` event type have their `toString()` method overridden to return the event data as a string. Other event types return an empty string.
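
When logs and output need to be handled separately, the `event` property can be used to route each event. The sketch below assumes a failed prediction surfaces as a thrown `Error`, as the "Throws" note for `replicate.stream` states. `consume`, `fakeSuccess`, and `fakeFailure` are hypothetical names, and the two generators stand in for `replicate.stream(...)` so the snippet runs offline.

```js
// Hypothetical helper: sorts streamed events into output text and log lines.
async function consume(events) {
  const result = { output: "", logs: [], error: null, done: false };
  try {
    for await (const event of events) {
      switch (event.event) {
        case "output":
          result.output += event.data;
          break;
        case "logs":
          result.logs.push(event.data);
          break;
        case "done":
          result.done = true;
          break;
        default:
          // Unknown event types are ignored in this sketch.
          break;
      }
    }
  } catch (err) {
    // A failed prediction is assumed to surface here as a thrown Error.
    result.error = err.message;
  }
  return result;
}

// Stand-ins for replicate.stream(...) (assumptions for this sketch).
async function* fakeSuccess() {
  yield { event: "output", data: "Hi" };
  yield { event: "logs", data: "loaded model" };
  yield { event: "output", data: "!" };
  yield { event: "done", data: "{}" };
}

async function* fakeFailure() {
  yield { event: "output", data: "Hi" };
  throw new Error('{"detail":"boom"}');
}

consume(fakeSuccess()).then((result) => console.log(result));
consume(fakeFailure()).then((result) => console.log(result));
```

Returning a result object rather than rethrowing keeps any partial output available to the caller when a prediction fails midway.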

### `replicate.models.get`

Get metadata for a public model or a private model that you own.
17 changes: 17 additions & 0 deletions index.d.ts
@@ -75,6 +75,13 @@ declare module "replicate" {
    results: T[];
  }

  export interface ServerSentEvent {
    event: string;
    data: string;
    id?: string;
    retry?: number;
  }

  export default class Replicate {
    constructor(options?: {
      auth?: string;
@@ -103,6 +110,16 @@ declare module "replicate" {
      progress?: (prediction: Prediction) => void
    ): Promise<object>;

    stream(
      identifier: `${string}/${string}` | `${string}/${string}:${string}`,
      options: {
        input: object;
        webhook?: string;
        webhook_events_filter?: WebhookEventType[];
        signal?: AbortSignal;
      }
    ): AsyncGenerator<ServerSentEvent>;

    request(
      route: string | URL,
      options: {
42 changes: 42 additions & 0 deletions index.js
@@ -1,5 +1,6 @@
const ApiError = require("./lib/error");
const ModelVersionIdentifier = require("./lib/identifier");
const { Stream } = require("./lib/stream");
const { withAutomaticRetries } = require("./lib/util");

const collections = require("./lib/collections");
@@ -235,6 +236,47 @@ class Replicate {
    return response;
  }

  /**
   * Run a model and stream its output.
   *
   * @param {string} ref - Required. The model identifier in the format "{owner}/{name}" or "{owner}/{name}:{version}"
   * @param {object} options
   * @param {object} options.input - Required. An object with the model inputs
   * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
   * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
   * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
   * @throws {Error} If the prediction failed
   * @yields {ServerSentEvent} Each streamed event from the prediction
   */
  async *stream(ref, options) {
    // `signal` is only used for the event stream, so keep it out of the request body.
    const { wait, signal, ...data } = options;

    const identifier = ModelVersionIdentifier.parse(ref);

    let prediction;
    if (identifier.version) {
      prediction = await this.predictions.create({
        ...data,
        version: identifier.version,
        stream: true,
      });
    } else {
      prediction = await this.models.predictions.create(
        identifier.owner,
        identifier.name,
        { ...data, stream: true }
      );
    }

    if (prediction.urls && prediction.urls.stream) {
      const stream = new Stream(prediction.urls.stream, { signal });
      yield* stream;
    } else {
      throw new Error("Prediction does not support streaming");
    }
  }

  /**
   * Paginate through a list of results.
   *
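
The `stream` method above picks an endpoint based on whether the parsed identifier carries a version. `ModelVersionIdentifier.parse` lives in `./lib/identifier` and is not part of this diff, so the following is only a sketch of how such parsing might work. `parseRef`, `chooseEndpoint`, and the regular expression are assumptions for illustration, not the library's code.

```js
// Hypothetical parser for "{owner}/{name}" or "{owner}/{name}:{version}".
// The library's own ModelVersionIdentifier.parse is not shown in this diff.
function parseRef(ref) {
  const match = /^([^/:]+)\/([^/:]+)(?::(.+))?$/.exec(ref);
  if (!match) {
    throw new Error(`Invalid model reference: ${ref}`);
  }
  // The version group is undefined when absent, so default it to null.
  const [, owner, name, version = null] = match;
  return { owner, name, version };
}

// Mirrors the branch in `stream`: a version selects the predictions
// endpoint, otherwise the model-scoped endpoint is used.
function chooseEndpoint(ref) {
  const { version } = parseRef(ref);
  return version ? "predictions.create" : "models.predictions.create";
}

console.log(parseRef("meta/llama-2-70b-chat"));
console.log(chooseEndpoint("meta/llama-2-70b-chat"));
console.log(chooseEndpoint("owner/name:abc123"));
```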
123 changes: 123 additions & 0 deletions lib/stream.js
@@ -0,0 +1,123 @@
const { Readable } = require("stream");

/**
 * A server-sent event.
 */
class ServerSentEvent {
  /**
   * Create a new server-sent event.
   *
   * @param {string} event The event name.
   * @param {string} data The event data.
   * @param {string} id The event ID.
   * @param {number} retry The retry time.
   */
  constructor(event, data, id, retry) {
    this.event = event;
    this.data = data;
    this.id = id;
    this.retry = retry;
  }

  /**
   * Convert the event to a string.
   */
  toString() {
    if (this.event === "output") {
      return this.data;
    }

    return "";
  }
}

/**
 * A stream of server-sent events.
 */
class Stream extends Readable {
  /**
   * Create a new stream of server-sent events.
   *
   * @param {string} url The URL to connect to.
   * @param {object} options The fetch options.
   */
  constructor(url, options) {
    super();
    this.url = url;
    this.options = options;

    this.event = null;
    this.data = [];
    this.lastEventId = null;
    this.retry = null;
  }

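  decode(line) {
    // A blank line ends the current event and dispatches it.
    if (!line) {
      if (!this.event && !this.data.length && !this.lastEventId) {
        return null;
      }

      const sse = new ServerSentEvent(
        this.event,
        this.data.join("\n"),
        this.lastEventId
      );

      this.event = null;
      this.data = [];
      this.retry = null;

      return sse;
    }

    // Lines that start with a colon are comments.
    if (line.startsWith(":")) {
      return null;
    }

    // Split on the first colon only, so a value that itself contains
    // ": " is not truncated. A single leading space in the value is dropped.
    const index = line.indexOf(":");
    const field = index === -1 ? line : line.slice(0, index);
    let value = index === -1 ? "" : line.slice(index + 1);
    if (value.startsWith(" ")) {
      value = value.slice(1);
    }

    if (field === "event") {
      this.event = value;
    } else if (field === "data") {
      this.data.push(value);
    } else if (field === "id") {
      this.lastEventId = value;
    }

    return null;
  }

  async *[Symbol.asyncIterator]() {
    const response = await fetch(this.url, {
      ...this.options,
      headers: {
        Accept: "text/event-stream",
      },
    });

    // Reuse one decoder in streaming mode so multi-byte characters split
    // across chunks decode correctly, and buffer any trailing partial line
    // until the rest of it arrives.
    const decoder = new TextDecoder("utf-8");
    let buffer = "";

    for await (const chunk of response.body) {
      buffer += decoder.decode(chunk, { stream: true });
      const lines = buffer.split(/\r?\n/);
      buffer = lines.pop();
      for (const line of lines) {
        const sse = this.decode(line);
        if (sse) {
          if (sse.event === "error") {
            throw new Error(sse.data);
          }

          yield sse;

          if (sse.event === "done") {
            return;
          }
        }
      }
    }
  }
}

module.exports = {
  Stream,
  ServerSentEvent,
};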
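
For readers unfamiliar with the `text/event-stream` wire format that `Stream.decode` consumes, the sketch below parses a complete payload in one pass. It is a simplified illustration rather than the code in `lib/stream.js`: `parseEvents` and the sample payload are hypothetical, and the parser ignores the `retry` field and chunked delivery.

```js
// Simplified parser for a complete text/event-stream payload.
// Hypothetical helper for illustration; not part of this change.
function parseEvents(text) {
  const events = [];
  let event = null;
  let data = [];
  let id = null;

  for (const line of text.split(/\r?\n/)) {
    if (line === "") {
      // Blank line: dispatch the buffered event, if there is one.
      if (event !== null || data.length > 0) {
        events.push({ event, data: data.join("\n"), id });
      }
      event = null;
      data = [];
      continue;
    }
    if (line.startsWith(":")) {
      continue; // comment line
    }

    // Field name is everything before the first colon; the value is the
    // rest, minus one optional leading space.
    const index = line.indexOf(":");
    const field = index === -1 ? line : line.slice(0, index);
    let value = index === -1 ? "" : line.slice(index + 1);
    if (value.startsWith(" ")) {
      value = value.slice(1);
    }

    if (field === "event") {
      event = value;
    } else if (field === "data") {
      data.push(value);
    } else if (field === "id") {
      id = value;
    }
  }
  return events;
}

// Sample payload (made up for this sketch).
const payload = [
  ": keep-alive comment",
  "event: output",
  "id: 1",
  "data: time: 12:30",
  "",
  "event: done",
  "id: 2",
  "data: {}",
  "",
  "",
].join("\n");

console.log(parseEvents(payload));
```

The `data: time: 12:30` line shows why the field is split at the first colon only: the value keeps its own colons.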