diff --git a/index.d.ts b/index.d.ts index 2b183d0..678d7a0 100644 --- a/index.d.ts +++ b/index.d.ts @@ -173,6 +173,7 @@ declare module "replicate" { webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; + useFileOutput?: boolean; } ): AsyncGenerator; diff --git a/index.js b/index.js index b1248e7..5dbfc12 100644 --- a/index.js +++ b/index.js @@ -315,7 +315,7 @@ class Replicate { * @yields {ServerSentEvent} Each streamed event from the prediction */ async *stream(ref, options) { - const { wait, signal, ...data } = options; + const { wait, signal, useFileOutput = this.useFileOutput, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); @@ -338,7 +338,10 @@ class Replicate { const stream = createReadableStream({ url: prediction.urls.stream, fetch: this.fetch, - ...(signal ? { options: { signal } } : {}), + options: { + useFileOutput, + ...(signal ? { signal } : {}), + }, }); yield* streamAsyncIterator(stream); diff --git a/lib/stream.js b/lib/stream.js index 2c899bd..7fcee11 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -93,7 +93,7 @@ function createReadableStream({ url, fetch, options = {} }) { typeof data === "string" && (data.startsWith("https:") || data.startsWith("data:")) ) { - data = createFileOutput({ data, fetch }); + data = createFileOutput({ url: data, fetch }); } controller.enqueue(new ServerSentEvent(event.event, data, event.id));