diff --git a/packages/vertexai/src/methods/chrome-adapter.test.ts b/packages/vertexai/src/methods/chrome-adapter.test.ts
index b11fb9c937e..cce97b25f5a 100644
--- a/packages/vertexai/src/methods/chrome-adapter.test.ts
+++ b/packages/vertexai/src/methods/chrome-adapter.test.ts
@@ -307,4 +307,56 @@ describe('ChromeAdapter', () => {
       });
     });
   });
+  describe('countTokens', () => {
+    it('counts tokens from a singular input', async () => {
+      const inputText = 'first';
+      const expectedCount = 10;
+      const onDeviceParams = {
+        systemPrompt: 'be yourself'
+      } as LanguageModelCreateOptions;
+
+      // setting up stubs
+      const languageModelProvider = {
+        create: () => Promise.resolve({})
+      } as LanguageModel;
+      const languageModel = {
+        measureInputUsage: _i => Promise.resolve(123)
+      } as LanguageModel;
+      const createStub = stub(languageModelProvider, 'create').resolves(
+        languageModel
+      );
+      // overrides impl with stub method
+      const measureInputUsageStub = stub(
+        languageModel,
+        'measureInputUsage'
+      ).resolves(expectedCount);
+
+      const adapter = new ChromeAdapter(
+        languageModelProvider,
+        'prefer_on_device',
+        onDeviceParams
+      );
+      const countTokenRequest = {
+        contents: [{ role: 'user', parts: [{ text: inputText }] }]
+      } as GenerateContentRequest;
+      const response = await adapter.countTokens(countTokenRequest);
+      // Asserts initialization params are proxied.
+      expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
+      // Asserts Vertex input type is mapped to Chrome type.
+      expect(measureInputUsageStub).to.have.been.calledOnceWith([
+        {
+          role: 'user',
+          content: [
+            {
+              type: 'text',
+              content: inputText
+            }
+          ]
+        }
+      ]);
+      expect(await response.json()).to.deep.equal({
+        totalTokens: expectedCount
+      });
+    });
+  });
 });
diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts
index 10844079c03..225d2bd581d 100644
--- a/packages/vertexai/src/methods/chrome-adapter.ts
+++ b/packages/vertexai/src/methods/chrome-adapter.ts
@@ -17,6 +17,7 @@
 
 import {
   Content,
+  CountTokensRequest,
   GenerateContentRequest,
   InferenceMode,
   Part,
@@ -117,6 +118,21 @@ export class ChromeAdapter {
     } as Response;
   }
 
+  async countTokens(request: CountTokensRequest): Promise<Response> {
+    // TODO: Check if the request contains an image, and if so, throw.
+    const session = await this.createSession(
+      // TODO: normalize on-device params during construction.
+      this.onDeviceParams || {}
+    );
+    const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
+    const tokenCount = await session.measureInputUsage(messages);
+    return {
+      json: async () => ({
+        totalTokens: tokenCount
+      })
+    } as Response;
+  }
+
   /**
    * Asserts inference for the given request can be performed by an on-device model.
    */
diff --git a/packages/vertexai/src/methods/count-tokens.test.ts b/packages/vertexai/src/methods/count-tokens.test.ts
index 9eccbf702fe..f77d2a36c2f 100644
--- a/packages/vertexai/src/methods/count-tokens.test.ts
+++ b/packages/vertexai/src/methods/count-tokens.test.ts
@@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
 import { CountTokensRequest } from '../types';
 import { ApiSettings } from '../types/internal';
 import { Task } from '../requests/request';
+import { ChromeAdapter } from './chrome-adapter';
 
 use(sinonChai);
 use(chaiAsPromised);
@@ -55,7 +56,8 @@ describe('countTokens()', () => {
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(6);
     expect(result.totalBillableCharacters).to.equal(16);
@@ -81,7 +83,8 @@
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(1837);
     expect(result.totalBillableCharacters).to.equal(117);
@@ -109,7 +112,8 @@
     const result = await countTokens(
       fakeApiSettings,
       'model',
-      fakeRequestParams
+      fakeRequestParams,
+      new ChromeAdapter()
     );
     expect(result.totalTokens).to.equal(258);
     expect(result).to.not.have.property('totalBillableCharacters');
@@ -135,8 +139,33 @@
       json: mockResponse.json
     } as Response);
     await expect(
-      countTokens(fakeApiSettings, 'model', fakeRequestParams)
+      countTokens(
+        fakeApiSettings,
+        'model',
+        fakeRequestParams,
+        new ChromeAdapter()
+      )
     ).to.be.rejectedWith(/404.*not found/);
     expect(mockFetch).to.be.called;
   });
+  it('on-device', async () => {
+    const chromeAdapter = new ChromeAdapter();
+    const isAvailableStub = stub(chromeAdapter, 'isAvailable').resolves(true);
+    const mockResponse = getMockResponse(
+      'vertexAI',
+      'unary-success-total-tokens.json'
+    );
+    const countTokensStub = stub(chromeAdapter, 'countTokens').resolves(
+      mockResponse as Response
+    );
+    const result = await countTokens(
+      fakeApiSettings,
+      'model',
+      fakeRequestParams,
+      chromeAdapter
+    );
+    expect(result.totalTokens).eq(6);
+    expect(isAvailableStub).to.be.called;
+    expect(countTokensStub).to.be.calledWith(fakeRequestParams);
+  });
 });
diff --git a/packages/vertexai/src/methods/count-tokens.ts b/packages/vertexai/src/methods/count-tokens.ts
index c9d43a5b6fd..108a4b9ee6c 100644
--- a/packages/vertexai/src/methods/count-tokens.ts
+++ b/packages/vertexai/src/methods/count-tokens.ts
@@ -22,8 +22,9 @@
 } from '../types';
 import { Task, makeRequest } from '../requests/request';
 import { ApiSettings } from '../types/internal';
+import { ChromeAdapter } from './chrome-adapter';
 
-export async function countTokens(
+export async function countTokensOnCloud(
   apiSettings: ApiSettings,
   model: string,
   params: CountTokensRequest,
@@ -39,3 +40,17 @@
   );
   return response.json();
 }
+
+export async function countTokens(
+  apiSettings: ApiSettings,
+  model: string,
+  params: CountTokensRequest,
+  chromeAdapter: ChromeAdapter,
+  requestOptions?: RequestOptions
+): Promise<CountTokensResponse> {
+  if (await chromeAdapter.isAvailable(params)) {
+    return (await chromeAdapter.countTokens(params)).json();
+  }
+
+  return countTokensOnCloud(apiSettings, model, params, requestOptions);
+}
diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts
index c58eb3a1497..0f7e408282c 100644
--- a/packages/vertexai/src/models/generative-model.ts
+++ b/packages/vertexai/src/models/generative-model.ts
@@ -153,6 +153,11 @@ export class GenerativeModel extends VertexAIModel {
     request: CountTokensRequest | string | Array<string | Part>
   ): Promise<CountTokensResponse> {
     const formattedParams = formatGenerateContentInput(request);
-    return countTokens(this._apiSettings, this.model, formattedParams);
+    return countTokens(
+      this._apiSettings,
+      this.model,
+      formattedParams,
+      this.chromeAdapter
+    );
   }
 }
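
A minimal standalone sketch of the dispatch pattern the new countTokens() wrapper in count-tokens.ts introduces: consult the on-device adapter first, and fall back to the cloud endpoint only when the adapter declines the request. The names below are hypothetical stand-ins for illustration; they are not exports of this package.

// Illustrative sketch only (hypothetical names, not part of this diff).
interface OnDeviceTokenCounter {
  isAvailable(request: unknown): Promise<boolean>;
  countTokens(request: unknown): Promise<Response>;
}

async function countTokensHybrid(
  request: unknown,
  onDevice: OnDeviceTokenCounter,
  countOnCloud: (request: unknown) => Promise<{ totalTokens: number }>
): Promise<{ totalTokens: number }> {
  if (await onDevice.isAvailable(request)) {
    // The adapter wraps its result in a Response-like object, so both
    // branches resolve to the same JSON shape for the caller.
    return (await onDevice.countTokens(request)).json();
  }
  return countOnCloud(request);
}

This mirrors how countTokens() above delegates to ChromeAdapter.isAvailable() and ChromeAdapter.countTokens() before reaching countTokensOnCloud(), and how GenerativeModel.countTokens() now threads its chromeAdapter through that wrapper.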