diff --git a/packages/vertexai/src/methods/chrome-adapter.test.ts b/packages/vertexai/src/methods/chrome-adapter.test.ts
index c350bd366c0..7e957c23614 100644
--- a/packages/vertexai/src/methods/chrome-adapter.test.ts
+++ b/packages/vertexai/src/methods/chrome-adapter.test.ts
@@ -157,7 +157,7 @@ describe('ChromeAdapter', () => {
           Promise.resolve({
             available: 'after-download'
           }),
-        create: () => {}
+        create: () => { }
       }
     } as AI;
     const downloadPromise = new Promise(() => {
@@ -182,7 +182,7 @@ describe('ChromeAdapter', () => {
           Promise.resolve({
             available: 'after-download'
           }),
-        create: () => {}
+        create: () => { }
       }
     } as AI;
     let resolveDownload;
@@ -298,4 +298,87 @@ describe('ChromeAdapter', () => {
       });
     });
   });
+  describe('countTokens', () => {
+    it('with no initial prompts', async () => {
+      const aiProvider = {
+        languageModel: {
+          create: () => Promise.resolve({})
+        }
+      } as AI;
+      const inputText = 'first';
+      const expectedCount = 10;
+      const model = {
+        countPromptTokens: _s => Promise.resolve(123)
+      } as AILanguageModel;
+      // overrides impl with stub method
+      const countPromptTokensStub = stub(model, 'countPromptTokens').resolves(expectedCount);
+      const factoryStub = stub(aiProvider.languageModel, 'create').resolves(model);
+      const adapter = new ChromeAdapter(
+        aiProvider,
+        'prefer_on_device'
+      );
+      const response = await adapter.countTokens({
+        contents: [
+          { role: 'user', parts: [{ text: inputText }] }
+        ]
+      });
+      expect(factoryStub).to.have.been.calledOnceWith({
+        // initialPrompts must be empty
+        initialPrompts: []
+      });
+      // validates countPromptTokens is called with the last entry from the input
+      expect(countPromptTokensStub).to.have.been.calledOnceWith({
+        role: 'user',
+        content: inputText
+      });
+      expect(await response.json()).to.deep.equal({
+        totalTokens: expectedCount
+      });
+    });
+    it('extracts initial prompts and then counts tokens', async () => {
+      const aiProvider = {
+        languageModel: {
+          create: () => Promise.resolve({})
+        }
+      } as AI;
+      const expectedCount = 10;
+      const model = {
+        countPromptTokens: _s => Promise.resolve(123)
+      } as AILanguageModel;
+      // overrides impl with stub method
+      const countPromptTokensStub = stub(model, 'countPromptTokens').resolves(expectedCount);
+      const factoryStub = stub(aiProvider.languageModel, 'create').resolves(model);
+      const text = ['first', 'second', 'third'];
+      const onDeviceParams = {
+        initialPrompts: [{ role: 'user', content: text[0] }]
+      } as AILanguageModelCreateOptionsWithSystemPrompt;
+      const adapter = new ChromeAdapter(
+        aiProvider,
+        'prefer_on_device',
+        onDeviceParams
+      );
+      const response = await adapter.countTokens({
+        contents: [
+          { role: 'model', parts: [{ text: text[1] }] },
+          { role: 'user', parts: [{ text: text[2] }] }
+        ]
+      });
+      expect(factoryStub).to.have.been.calledOnceWith({
+        initialPrompts: [
+          { role: 'user', content: text[0] },
+          // Asserts tail is passed as initial prompts, and
+          // role is normalized from model to assistant.
+          { role: 'assistant', content: text[1] }
+        ]
+      });
+      // validates countPromptTokens is called with the last entry from the input
+      expect(countPromptTokensStub).to.have.been.calledOnceWith({
+        role: 'user',
+        content: text[2]
+      });
+      expect(await response.json()).to.deep.equal({
+        totalTokens: expectedCount
+      });
+    });
+  });
 });
diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts
index 01ee3ca1bff..3c61743764a 100644
--- a/packages/vertexai/src/methods/chrome-adapter.ts
+++ b/packages/vertexai/src/methods/chrome-adapter.ts
@@ -18,6 +18,7 @@
 import { isChrome } from '@firebase/util';
 import {
   Content,
+  CountTokensRequest,
   GenerateContentRequest,
   InferenceMode,
   Role
@@ -102,6 +103,20 @@ export class ChromeAdapter {
     const stream = await session.promptStreaming(prompt.content);
     return ChromeAdapter.toStreamResponse(stream);
   }
+  async countTokens(request: CountTokensRequest): Promise<Response> {
+    const options = this.onDeviceParams || {};
+    options.initialPrompts ??= [];
+    const extractedInitialPrompts = ChromeAdapter.toInitialPrompts(request.contents);
+    const currentPrompt = extractedInitialPrompts.pop()!;
+    options.initialPrompts.push(...extractedInitialPrompts);
+    const session = await this.session(options);
+    const tokenCount = await session.countPromptTokens(currentPrompt);
+    return {
+      json: async () => ({
+        totalTokens: tokenCount
+      })
+    } as Response;
+  }
   private static isOnDeviceRequest(request: GenerateContentRequest): boolean {
     // Returns false if the prompt is empty.
     if (request.contents.length === 0) {
diff --git a/packages/vertexai/src/methods/count-tokens.test.ts b/packages/vertexai/src/methods/count-tokens.test.ts
index a3d7c99b4ba..7bff618f3e0 100644
--- a/packages/vertexai/src/methods/count-tokens.test.ts
+++ b/packages/vertexai/src/methods/count-tokens.test.ts
@@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
 import { CountTokensRequest } from '../types';
 import { ApiSettings } from '../types/internal';
 import { Task } from '../requests/request';
+import { ChromeAdapter } from './chrome-adapter';
 
 use(sinonChai);
 use(chaiAsPromised);
@@ -52,7 +53,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(6);
     expect(result.totalBillableCharacters).to.equal(16);
@@ -77,7 +79,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(1837);
     expect(result.totalBillableCharacters).to.equal(117);
@@ -104,7 +107,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(258);
     expect(result).to.not.have.property('totalBillableCharacters');
@@ -127,7 +131,7 @@ describe('countTokens()', () => {
       json: mockResponse.json
     } as Response);
     await expect(
-      countTokens(fakeApiSettings, 'model', fakeRequestParams)
+      countTokens(fakeApiSettings, 'model', fakeRequestParams, new ChromeAdapter())
     ).to.be.rejectedWith(/404.*not found/);
     expect(mockFetch).to.be.called;
   });
diff --git a/packages/vertexai/src/methods/count-tokens.ts b/packages/vertexai/src/methods/count-tokens.ts
index c9d43a5b6fd..c6b6fab3934 100644
--- a/packages/vertexai/src/methods/count-tokens.ts
+++ b/packages/vertexai/src/methods/count-tokens.ts
@@ -22,8 +22,9 @@ import {
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { ApiSettings } from '../types/internal';
+import { ChromeAdapter } from './chrome-adapter';
 
-export async function countTokens(
+export async function countTokensOnCloud(
   apiSettings: ApiSettings,
   model: string,
   params: CountTokensRequest,
@@ -39,3 +40,17 @@
   );
   return response.json();
 }
+
+export async function countTokens(
+  apiSettings: ApiSettings,
+  model: string,
+  params: CountTokensRequest,
+  chromeAdapter: ChromeAdapter,
+  requestOptions?: RequestOptions
+): Promise<CountTokensResponse> {
+  if (await chromeAdapter.isAvailable(params)) {
+    return (await chromeAdapter.countTokens(params)).json();
+  } else {
+    return countTokensOnCloud(apiSettings, model, params, requestOptions);
+  }
+}
diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts
index bf72ae0be9f..60ec72be6fe 100644
--- a/packages/vertexai/src/models/generative-model.ts
+++ b/packages/vertexai/src/models/generative-model.ts
@@ -154,6 +154,6 @@ export class GenerativeModel extends VertexAIModel {
     request: CountTokensRequest | string | Array<string | Part>
   ): Promise<CountTokensResponse> {
     const formattedParams = formatGenerateContentInput(request);
-    return countTokens(this._apiSettings, this.model, formattedParams);
+    return countTokens(this._apiSettings, this.model, formattedParams, this.chromeAdapter);
   }
 }
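
Usage note (reviewer aside, not part of the patch): with this change, countTokens() in count-tokens.ts delegates to ChromeAdapter.countTokens() whenever the adapter reports the on-device model as available, and falls back to countTokensOnCloud() otherwise; GenerativeModel.countTokens() simply forwards its chromeAdapter. The sketch below is illustrative only and assumes the internal module paths shown in the diff; the ApiSettings literal and prompt text are placeholders, and 'model' mirrors the placeholder model name the tests use.

    import { countTokens } from './count-tokens';
    import { ChromeAdapter } from './chrome-adapter';
    import { CountTokensRequest } from '../types';
    import { ApiSettings } from '../types/internal';

    // Placeholder settings; real values come from the initialized Firebase app.
    const fakeApiSettings = {
      apiKey: 'key',
      project: 'my-project',
      location: 'us-central1'
    } as ApiSettings;

    const request: CountTokensRequest = {
      contents: [{ role: 'user', parts: [{ text: 'How many tokens is this?' }] }]
    };

    async function demo(): Promise<void> {
      // A default-constructed adapter (as in count-tokens.test.ts) reports the
      // on-device model as unavailable, so this call takes the cloud path via
      // countTokensOnCloud(). Constructing it with a Chrome AI provider and
      // 'prefer_on_device' (as in chrome-adapter.test.ts) routes the same call
      // through ChromeAdapter.countTokens() instead.
      const adapter = new ChromeAdapter();
      const { totalTokens } = await countTokens(
        fakeApiSettings,
        'model',
        request,
        adapter
      );
      console.log(totalTokens);
    }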