Skip to content

Sid hybrid count token #8925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: erikeldridge-vertex-stream
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions packages/vertexai/src/methods/chrome-adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ describe('ChromeAdapter', () => {
Promise.resolve({
available: 'after-download'
}),
create: () => {}
create: () => { }
}
} as AI;
const downloadPromise = new Promise<AILanguageModel>(() => {
Expand All @@ -182,7 +182,7 @@ describe('ChromeAdapter', () => {
Promise.resolve({
available: 'after-download'
}),
create: () => {}
create: () => { }
}
} as AI;
let resolveDownload;
Expand Down Expand Up @@ -298,4 +298,87 @@ describe('ChromeAdapter', () => {
});
});
});
describe('countTokens', () => {
// Single-entry request against an adapter constructed without on-device
// params: the session factory is expected to receive an empty
// `initialPrompts` list, and the lone entry is the prompt that is counted.
it('With no initial prompts', async () => {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: describing this as a test for a singular input, eg it('counts tokens from a singular input', would be more self-explanatory from my perspective.

// Minimal provider fake; `create` is replaced by a sinon stub further down.
const aiProvider = {
languageModel: {
create: () => Promise.resolve({})
}
} as AI;
const inputText = "first";
const expectedCount = 10;
// The placeholder resolves to 123, which differs from expectedCount, so the
// final assertion can only pass when the stub below supplies the value.
const model = {
countPromptTokens: _s => Promise.resolve(123),
} as AILanguageModel;
// Replace the placeholder with a stub so the call arguments can be inspected.
const countPromptTokensStub = stub(model, 'countPromptTokens').resolves(expectedCount);
// Session creation is stubbed to resolve to the fake model above.
const factoryStub = stub(aiProvider.languageModel, 'create').resolves(model);
// No third constructor argument: the adapter has no configured on-device params.
const adapter = new ChromeAdapter(
aiProvider,
'prefer_on_device',
);
const response = await adapter.countTokens({
contents: [
{ role: 'user', parts: [{ text: inputText }] }
]
});
expect(factoryStub).to.have.been.calledOnceWith({
// With one content entry and no configured prompts, nothing is left over
// to seed the session, so initialPrompts must be empty.
initialPrompts: []
});
// The last (here: only) entry of the input is the prompt that gets counted,
// converted to the { role, content } shape.
expect(countPromptTokensStub).to.have.been.calledOnceWith({
role: 'user',
content: inputText
});
// The stubbed count is surfaced as `totalTokens` in the response JSON.
expect(await response.json()).to.deep.equal({
totalTokens: expectedCount
});
});
// Multi-entry request against an adapter that already has configured initial
// prompts: every entry but the last should be appended to the configured
// prompts when the session is created, and only the last entry is counted.
it('Extracts initial prompts and then does counts tokens', async () => {
  const configuredText = 'first';
  const historyText = 'second';
  const currentText = 'third';
  const expectedCount = 10;

  // Fake model whose placeholder result (123) differs from expectedCount, so
  // the final assertion only passes when the stub supplies the value.
  const fakeModel = {
    countPromptTokens: _s => Promise.resolve(123),
  } as AILanguageModel;
  const countStub = stub(fakeModel, 'countPromptTokens').resolves(
    expectedCount
  );

  // Fake provider whose session factory is stubbed to hand back the fake model.
  const fakeProvider = {
    languageModel: {
      create: () => Promise.resolve({})
    }
  } as AI;
  const createStub = stub(fakeProvider.languageModel, 'create').resolves(
    fakeModel
  );

  // Adapter configured up front with a single initial prompt.
  const configuredParams = {
    initialPrompts: [{ role: 'user', content: configuredText }]
  } as AILanguageModelCreateOptionsWithSystemPrompt;
  const adapter = new ChromeAdapter(
    fakeProvider,
    'prefer_on_device',
    configuredParams
  );

  const request = {
    contents: [
      { role: 'model', parts: [{ text: historyText }] },
      { role: 'user', parts: [{ text: currentText }] }
    ]
  };
  const response = await adapter.countTokens(request);

  // The session must be created with the configured prompt followed by the
  // request history, with the `model` role rewritten to `assistant`.
  const expectedSessionOptions = {
    initialPrompts: [
      { role: 'user', content: configuredText },
      { role: 'assistant', content: historyText }
    ]
  };
  expect(createStub).to.have.been.calledOnceWith(expectedSessionOptions);

  // Only the final entry of the request is handed to the token counter.
  const expectedPrompt = { role: 'user', content: currentText };
  expect(countStub).to.have.been.calledOnceWith(expectedPrompt);

  const body = await response.json();
  expect(body).to.deep.equal({ totalTokens: expectedCount });
});
});
});
15 changes: 15 additions & 0 deletions packages/vertexai/src/methods/chrome-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import { isChrome } from '@firebase/util';
import {
Content,
CountTokensRequest,
GenerateContentRequest,
InferenceMode,
Role
Expand Down Expand Up @@ -102,6 +103,20 @@ export class ChromeAdapter {
const stream = await session.promptStreaming(prompt.content);
return ChromeAdapter.toStreamResponse(stream);
}
/**
 * Counts the tokens of the last entry in `request.contents` using an
 * on-device session.
 *
 * All entries before the last one are converted with `toInitialPrompts` and
 * passed to the session as initial prompts, after any initial prompts already
 * present in `onDeviceParams`.
 *
 * @param request - Request whose `contents` supply the history and the prompt
 *   to count.
 * @returns A `Response`-shaped object whose `json()` resolves to
 *   `{ totalTokens }`.
 * @throws Error if `request.contents` yields no prompt to count.
 */
async countTokens(request: CountTokensRequest): Promise<Response> {
  const extractedInitialPrompts = ChromeAdapter.toInitialPrompts(
    request.contents
  );
  // The final entry is the prompt to count; everything before it is history.
  const currentPrompt = extractedInitialPrompts.pop();
  if (currentPrompt === undefined) {
    throw new Error(
      'countTokens requires at least one entry in request.contents.'
    );
  }
  // Build fresh options per call. Writing into `this.onDeviceParams` (or its
  // `initialPrompts` array) would leak this request's history into every
  // later call made through the same adapter.
  const options = {
    ...this.onDeviceParams,
    initialPrompts: [
      ...(this.onDeviceParams?.initialPrompts ?? []),
      ...extractedInitialPrompts
    ]
  };
  const session = await this.session(options);
  const tokenCount = await session.countPromptTokens(currentPrompt);
  // Only `json()` is provided; callers are expected to read the count from it.
  return {
    json: async () => ({
      totalTokens: tokenCount
    })
  } as Response;
}
private static isOnDeviceRequest(request: GenerateContentRequest): boolean {
// Returns false if the prompt is empty.
if (request.contents.length === 0) {
Expand Down
12 changes: 8 additions & 4 deletions packages/vertexai/src/methods/count-tokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { countTokens } from './count-tokens';
import { CountTokensRequest } from '../types';
import { ApiSettings } from '../types/internal';
import { Task } from '../requests/request';
import { ChromeAdapter } from './chrome-adapter';

use(sinonChai);
use(chaiAsPromised);
Expand Down Expand Up @@ -52,7 +53,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(6);
expect(result.totalBillableCharacters).to.equal(16);
Expand All @@ -77,7 +79,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(1837);
expect(result.totalBillableCharacters).to.equal(117);
Expand All @@ -104,7 +107,8 @@ describe('countTokens()', () => {
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
fakeRequestParams,
new ChromeAdapter()
);
expect(result.totalTokens).to.equal(258);
expect(result).to.not.have.property('totalBillableCharacters');
Expand All @@ -127,7 +131,7 @@ describe('countTokens()', () => {
json: mockResponse.json
} as Response);
await expect(
countTokens(fakeApiSettings, 'model', fakeRequestParams)
countTokens(fakeApiSettings, 'model', fakeRequestParams, new ChromeAdapter())
).to.be.rejectedWith(/404.*not found/);
expect(mockFetch).to.be.called;
});
Expand Down
17 changes: 16 additions & 1 deletion packages/vertexai/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import {
} from '../types';
import { Task, makeRequest } from '../requests/request';
import { ApiSettings } from '../types/internal';
import { ChromeAdapter } from './chrome-adapter';

export async function countTokens(
export async function countTokensOnCloud(
apiSettings: ApiSettings,
model: string,
params: CountTokensRequest,
Expand All @@ -39,3 +40,17 @@ export async function countTokens(
);
return response.json();
}

/**
 * Counts tokens for `params`, using the Chrome adapter when it reports itself
 * available for the request and `countTokensOnCloud` otherwise.
 *
 * @param apiSettings - Forwarded to `countTokensOnCloud` on the fallback path.
 * @param model - Forwarded to `countTokensOnCloud` on the fallback path.
 * @param params - The count-tokens request; passed to the adapter's
 *   availability check and to whichever implementation handles the call.
 * @param chromeAdapter - Adapter consulted first for handling the request.
 * @param requestOptions - Optional; forwarded to `countTokensOnCloud` only.
 * @returns The parsed count-tokens response from whichever path handled it.
 */
export async function countTokens(
apiSettings: ApiSettings,
model: string,
params: CountTokensRequest,
chromeAdapter: ChromeAdapter,
requestOptions?: RequestOptions
): Promise<CountTokensResponse> {
// The availability check is awaited per request before choosing a path.
if (await chromeAdapter.isAvailable(params)) {
// The adapter returns a Response-shaped object; unwrap its JSON body.
return (await chromeAdapter.countTokens(params)).json();
} else {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: unnecessary else, given the return statement above

// Fallback: delegate to countTokensOnCloud with the original arguments.
return countTokensOnCloud(apiSettings, model, params, requestOptions);
}
}
2 changes: 1 addition & 1 deletion packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,6 @@ export class GenerativeModel extends VertexAIModel {
request: CountTokensRequest | string | Array<string | Part>
): Promise<CountTokensResponse> {
const formattedParams = formatGenerateContentInput(request);
return countTokens(this._apiSettings, this.model, formattedParams);
return countTokens(this._apiSettings, this.model, formattedParams, this.chromeAdapter);
}
}