
Commit a46fa4a

VinF Hybrid Inference: narrow Chrome input type (#8953)
1 parent 622916c commit a46fa4a

4 files changed: +52 -97 lines changed

e2e/sample-apps/modular.js

+18 -17
@@ -314,13 +314,14 @@ async function callVertexAI(app) {
   console.log('[VERTEXAI] start');
   const vertexAI = getVertexAI(app);
   const model = getGenerativeModel(vertexAI, {
-    mode: 'prefer_in_cloud'
+    mode: 'prefer_on_device'
   });
-  const result = await model.generateContentStream("What is Roko's Basalisk?");
-  for await (const chunk of result.stream) {
-    console.log(chunk.text());
-  }
-  console.log(`[VERTEXAI] counted tokens: ${result.totalTokens}`);
+  const singleResult = await model.generateContent([
+    { text: 'describe the following:' },
+    { text: 'the mojave desert' }
+  ]);
+  console.log(`Generated text: ${singleResult.response.text()}`);
+  console.log(`[VERTEXAI] end`);
 }
 
 /**
@@ -346,18 +347,18 @@ async function main() {
   const app = initializeApp(config);
   setLogLevel('warn');
 
-  callAppCheck(app);
-  await authLogin(app);
-  await callStorage(app);
-  await callFirestore(app);
-  await callDatabase(app);
-  await callMessaging(app);
-  callAnalytics(app);
-  callPerformance(app);
-  await callFunctions(app);
+  // callAppCheck(app);
+  // await authLogin(app);
+  // await callStorage(app);
+  // await callFirestore(app);
+  // await callDatabase(app);
+  // await callMessaging(app);
+  // callAnalytics(app);
+  // callPerformance(app);
+  // await callFunctions(app);
   await callVertexAI(app);
-  callDataConnect(app);
-  await authLogout(app);
+  // callDataConnect(app);
+  // await authLogout(app);
   console.log('DONE');
 }
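
For context, the call shape the sample now exercises is sketched below. This is a minimal sketch, not part of the commit: the import paths and the config argument are assumptions, while the getVertexAI / getGenerativeModel calls, the 'prefer_on_device' mode, and the generateContent input mirror the updated callVertexAI above.

// Sketch only: mirrors the updated callVertexAI; import paths and config
// are assumed here, not taken from this diff.
import { initializeApp } from 'firebase/app';
import { getVertexAI, getGenerativeModel } from 'firebase/vertexai';

async function describeOnDevice(config: object): Promise<string> {
  const app = initializeApp(config);
  const vertexAI = getVertexAI(app);
  // 'prefer_on_device' routes inference to Chrome's on-device model when available.
  const model = getGenerativeModel(vertexAI, { mode: 'prefer_on_device' });
  const singleResult = await model.generateContent([
    { text: 'describe the following:' },
    { text: 'the mojave desert' }
  ]);
  return singleResult.response.text();
}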

packages/vertexai/src/methods/chrome-adapter.test.ts

+11 -40
@@ -22,7 +22,8 @@ import { ChromeAdapter } from './chrome-adapter';
 import {
   Availability,
   LanguageModel,
-  LanguageModelCreateOptions
+  LanguageModelCreateOptions,
+  LanguageModelMessageContent
 } from '../types/language-model';
 import { stub } from 'sinon';
 import { GenerateContentRequest } from '../types';
@@ -105,22 +106,6 @@ describe('ChromeAdapter', () => {
         })
       ).to.be.false;
     });
-    it('returns false if request content has multiple parts', async () => {
-      const adapter = new ChromeAdapter(
-        {} as LanguageModel,
-        'prefer_on_device'
-      );
-      expect(
-        await adapter.isAvailable({
-          contents: [
-            {
-              role: 'user',
-              parts: [{ text: 'a' }, { text: 'b' }]
-            }
-          ]
-        })
-      ).to.be.false;
-    });
     it('returns false if request content has non-text part', async () => {
       const adapter = new ChromeAdapter(
         {} as LanguageModel,
@@ -281,7 +266,8 @@ describe('ChromeAdapter', () => {
         create: () => Promise.resolve({})
       } as LanguageModel;
       const languageModel = {
-        prompt: i => Promise.resolve(i)
+        // eslint-disable-next-line @typescript-eslint/no-unused-vars
+        prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
       } as LanguageModel;
       const createStub = stub(languageModelProvider, 'create').resolves(
         languageModel
@@ -305,13 +291,8 @@ describe('ChromeAdapter', () => {
       // Asserts Vertex input type is mapped to Chrome type.
       expect(promptStub).to.have.been.calledOnceWith([
         {
-          role: request.contents[0].role,
-          content: [
-            {
-              type: 'text',
-              content: request.contents[0].parts[0].text
-            }
-          ]
+          type: 'text',
+          content: request.contents[0].parts[0].text
         }
       ]);
       // Asserts expected output.
@@ -366,21 +347,16 @@ describe('ChromeAdapter', () => {
       // Asserts Vertex input type is mapped to Chrome type.
       expect(measureInputUsageStub).to.have.been.calledOnceWith([
         {
-          role: 'user',
-          content: [
-            {
-              type: 'text',
-              content: inputText
-            }
-          ]
+          type: 'text',
+          content: inputText
         }
       ]);
       expect(await response.json()).to.deep.equal({
         totalTokens: expectedCount
       });
     });
   });
-  describe('generateContentStreamOnDevice', () => {
+  describe('generateContentStream', () => {
     it('generates content stream', async () => {
       const languageModelProvider = {
         create: () => Promise.resolve({})
@@ -413,13 +389,8 @@ describe('ChromeAdapter', () => {
       expect(createStub).to.have.been.calledOnceWith(onDeviceParams);
       expect(promptStub).to.have.been.calledOnceWith([
         {
-          role: request.contents[0].role,
-          content: [
-            {
-              type: 'text',
-              content: request.contents[0].parts[0].text
-            }
-          ]
+          type: 'text',
+          content: request.contents[0].parts[0].text
         }
       ]);
       const actual = await toStringArray(response.body!);

packages/vertexai/src/methods/chrome-adapter.ts

+19 -34
@@ -16,19 +16,15 @@
  */
 
 import {
-  Content,
   CountTokensRequest,
   GenerateContentRequest,
   InferenceMode,
-  Part,
-  Role
+  Part
 } from '../types';
 import {
   Availability,
   LanguageModel,
   LanguageModelCreateOptions,
-  LanguageModelMessage,
-  LanguageModelMessageRole,
   LanguageModelMessageContent
 } from '../types/language-model';
 
@@ -100,8 +96,12 @@ export class ChromeAdapter {
       // TODO: normalize on-device params during construction.
       this.onDeviceParams || {}
     );
-    const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
-    const text = await session.prompt(messages);
+    // TODO: support multiple content objects when Chrome supports
+    // sequence<LanguageModelMessage>
+    const contents = request.contents[0].parts.map(
+      ChromeAdapter.toLanguageModelMessageContent
+    );
+    const text = await session.prompt(contents);
     return ChromeAdapter.toResponse(text);
   }
 
@@ -120,8 +120,12 @@ export class ChromeAdapter {
       // TODO: normalize on-device params during construction.
       this.onDeviceParams || {}
     );
-    const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
-    const stream = await session.promptStreaming(messages);
+    // TODO: support multiple content objects when Chrome supports
+    // sequence<LanguageModelMessage>
+    const contents = request.contents[0].parts.map(
+      ChromeAdapter.toLanguageModelMessageContent
+    );
+    const stream = await session.promptStreaming(contents);
     return ChromeAdapter.toStreamResponse(stream);
   }
 
@@ -131,8 +135,12 @@ export class ChromeAdapter {
       // TODO: normalize on-device params during construction.
       this.onDeviceParams || {}
     );
-    const messages = ChromeAdapter.toLanguageModelMessages(request.contents);
-    const tokenCount = await session.measureInputUsage(messages);
+    // TODO: support multiple content objects when Chrome supports
+    // sequence<LanguageModelMessage>
+    const contents = request.contents[0].parts.map(
+      ChromeAdapter.toLanguageModelMessageContent
+    );
+    const tokenCount = await session.measureInputUsage(contents);
     return {
       json: async () => ({
         totalTokens: tokenCount
@@ -155,10 +163,6 @@ export class ChromeAdapter {
       return false;
     }
 
-    if (content.parts.length > 1) {
-      return false;
-    }
-
     if (!content.parts[0].text) {
       return false;
     }
@@ -188,25 +192,6 @@ export class ChromeAdapter {
     });
   }
 
-  /**
-   * Converts a Vertex role string to a Chrome role string.
-   */
-  private static toOnDeviceRole(role: Role): LanguageModelMessageRole {
-    return role === 'model' ? 'assistant' : 'user';
-  }
-
-  /**
-   * Converts a Vertex Content object to a Chrome LanguageModelMessage object.
-   */
-  private static toLanguageModelMessages(
-    contents: Content[]
-  ): LanguageModelMessage[] {
-    return contents.map(c => ({
-      role: ChromeAdapter.toOnDeviceRole(c.role),
-      content: c.parts.map(ChromeAdapter.toLanguageModelMessageContent)
-    }));
-  }
-
   /**
    * Converts a Vertex Part object to a Chrome LanguageModelMessageContent object.
    */
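
Taken on its own, the narrowing above means the adapter no longer builds role-tagged LanguageModelMessage objects; it maps each Part of the first Content to a bare LanguageModelMessageContent and passes that flat list to prompt(), promptStreaming(), and measureInputUsage(). A standalone sketch of the text-only mapping, using illustrative local type and function names rather than the adapter's real ones:

// Text-only reduction of the Vertex Part -> Chrome LanguageModelMessageContent
// mapping; the real types live in ../types and ../types/language-model.
interface TextPart {
  text: string;
}
interface MessageContent {
  type: 'text' | 'image' | 'audio';
  content: unknown;
}

// Produces the output shape the tests above assert for toLanguageModelMessageContent.
function toMessageContent(part: TextPart): MessageContent {
  return { type: 'text', content: part.text };
}

// Only the first Content object's parts are used until Chrome accepts
// sequence<LanguageModelMessage>; prompt() receives this flat array.
const parts: TextPart[] = [{ text: 'describe the following:' }, { text: 'the mojave desert' }];
const chromePrompt: MessageContent[] = parts.map(toMessageContent);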

packages/vertexai/src/types/language-model.ts

+4 -6
@@ -56,14 +56,12 @@ interface LanguageModelExpectedInput {
   type: LanguageModelMessageType;
   languages?: string[];
 }
-export type LanguageModelPrompt =
-  | LanguageModelMessage[]
-  | LanguageModelMessageShorthand[]
-  | string;
+// TODO: revert to type from Prompt API explainer once it's supported.
+export type LanguageModelPrompt = LanguageModelMessageContent[];
 type LanguageModelInitialPrompts =
   | LanguageModelMessage[]
   | LanguageModelMessageShorthand[];
-export interface LanguageModelMessage {
+interface LanguageModelMessage {
   role: LanguageModelMessageRole;
   content: LanguageModelMessageContent[];
 }
@@ -75,7 +73,7 @@ export interface LanguageModelMessageContent {
   type: LanguageModelMessageType;
   content: LanguageModelMessageContentValue;
 }
-export type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
+type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
 type LanguageModelMessageType = 'text' | 'image' | 'audio';
 type LanguageModelMessageContentValue =
   | ImageBitmapSource
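
With LanguageModelPrompt narrowed to LanguageModelMessageContent[], a prompt is now a flat array of typed content entries rather than a string or role-tagged messages. A minimal example value, assuming a plain string is among the allowed LanguageModelMessageContentValue variants (which the text mapping in chrome-adapter.ts relies on):

// Example value for the narrowed prompt type; matches what the tests expect
// prompt() and measureInputUsage() to receive.
const prompt: LanguageModelPrompt = [
  { type: 'text', content: 'describe the following:' },
  { type: 'text', content: 'the mojave desert' }
];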
