Skip to content

Commit b3d55df

Browse files
erikeldridgegsiddh
authored andcommitted
Define ChromeAdapter class
1 parent 82838c1 commit b3d55df

File tree

9 files changed

+233
-73
lines changed

9 files changed

+233
-73
lines changed

common/api-review/vertexai.api.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ export class BooleanSchema extends Schema {
4242
// @public
4343
export class ChatSession {
4444
// Warning: (ae-forgotten-export) The symbol "ApiSettings" needs to be exported by the entry point index.d.ts
45-
constructor(apiSettings: ApiSettings, model: string, params?: StartChatParams | undefined, requestOptions?: RequestOptions | undefined);
45+
// Warning: (ae-forgotten-export) The symbol "ChromeAdapter" needs to be exported by the entry point index.d.ts
46+
constructor(apiSettings: ApiSettings, model: string, chromeAdapter: ChromeAdapter, params?: StartChatParams | undefined, requestOptions?: RequestOptions | undefined);
4647
getHistory(): Promise<Content[]>;
4748
// (undocumented)
4849
model: string;
@@ -324,7 +325,7 @@ export interface GenerativeContentBlob {
324325

325326
// @public
326327
export class GenerativeModel extends VertexAIModel {
327-
constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions);
328+
constructor(vertexAI: VertexAI, modelParams: ModelParams, chromeAdapter: ChromeAdapter, requestOptions?: RequestOptions);
328329
countTokens(request: CountTokensRequest | string | Array<string | Part>): Promise<CountTokensResponse>;
329330
static DEFAULT_HYBRID_IN_CLOUD_MODEL: string;
330331
generateContent(request: GenerateContentRequest | string | Array<string | Part>): Promise<GenerateContentResult>;

packages/vertexai/src/api.ts

+10-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import {
3030
} from './types';
3131
import { VertexAIError } from './errors';
3232
import { VertexAIModel, GenerativeModel, ImagenModel } from './models';
33+
import { ChromeAdapter } from './methods/chrome-adapter';
3334

3435
export { ChatSession } from './methods/chat-session';
3536
export * from './requests/schema-builder';
@@ -91,7 +92,15 @@ export function getGenerativeModel(
9192
`Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })`
9293
);
9394
}
94-
return new GenerativeModel(vertexAI, inCloudParams, requestOptions);
95+
return new GenerativeModel(
96+
vertexAI,
97+
inCloudParams,
98+
new ChromeAdapter(
99+
hybridParams.mode,
100+
hybridParams.onDeviceParams
101+
),
102+
requestOptions
103+
);
95104
}
96105

97106
/**

packages/vertexai/src/methods/chat-session.test.ts

+16-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import * as generateContentMethods from './generate-content';
2323
import { GenerateContentStreamResult } from '../types';
2424
import { ChatSession } from './chat-session';
2525
import { ApiSettings } from '../types/internal';
26+
import { ChromeAdapter } from './chrome-adapter';
2627

2728
use(sinonChai);
2829
use(chaiAsPromised);
@@ -44,7 +45,11 @@ describe('ChatSession', () => {
4445
generateContentMethods,
4546
'generateContent'
4647
).rejects('generateContent failed');
47-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
48+
const chatSession = new ChatSession(
49+
fakeApiSettings,
50+
'a-model',
51+
new ChromeAdapter()
52+
);
4853
await expect(chatSession.sendMessage('hello')).to.be.rejected;
4954
expect(generateContentStub).to.be.calledWith(
5055
fakeApiSettings,
@@ -61,7 +66,11 @@ describe('ChatSession', () => {
6166
generateContentMethods,
6267
'generateContentStream'
6368
).rejects('generateContentStream failed');
64-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
69+
const chatSession = new ChatSession(
70+
fakeApiSettings,
71+
'a-model',
72+
new ChromeAdapter()
73+
);
6574
await expect(chatSession.sendMessageStream('hello')).to.be.rejected;
6675
expect(generateContentStreamStub).to.be.calledWith(
6776
fakeApiSettings,
@@ -80,7 +89,11 @@ describe('ChatSession', () => {
8089
generateContentMethods,
8190
'generateContentStream'
8291
).resolves({} as unknown as GenerateContentStreamResult);
83-
const chatSession = new ChatSession(fakeApiSettings, 'a-model');
92+
const chatSession = new ChatSession(
93+
fakeApiSettings,
94+
'a-model',
95+
new ChromeAdapter()
96+
);
8497
await chatSession.sendMessageStream('hello');
8598
expect(generateContentStreamStub).to.be.calledWith(
8699
fakeApiSettings,

packages/vertexai/src/methods/chat-session.ts

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import { validateChatHistory } from './chat-session-helpers';
3030
import { generateContent, generateContentStream } from './generate-content';
3131
import { ApiSettings } from '../types/internal';
3232
import { logger } from '../logger';
33+
import { ChromeAdapter } from './chrome-adapter';
3334

3435
/**
3536
* Do not log a message for this error.
@@ -50,6 +51,7 @@ export class ChatSession {
5051
constructor(
5152
apiSettings: ApiSettings,
5253
public model: string,
54+
private chromeAdapter: ChromeAdapter,
5355
public params?: StartChatParams,
5456
public requestOptions?: RequestOptions
5557
) {
@@ -95,6 +97,7 @@ export class ChatSession {
9597
this._apiSettings,
9698
this.model,
9799
generateContentRequest,
100+
this.chromeAdapter,
98101
this.requestOptions
99102
)
100103
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import {
2+
EnhancedGenerateContentResponse,
3+
GenerateContentRequest,
4+
InferenceMode
5+
} from '../types';
6+
import { LanguageModelCreateOptions } from '../types/language-model';
7+
8+
/**
9+
* Defines an inference "backend" that uses Chrome's on-device model,
10+
* and encapsulates logic for detecting when on-device is possible.
11+
*/
12+
export class ChromeAdapter {
13+
constructor(
14+
private mode?: InferenceMode,
15+
private onDeviceParams?: LanguageModelCreateOptions
16+
) {}
17+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
18+
async isAvailable(request: GenerateContentRequest): Promise<boolean> {
19+
return false;
20+
}
21+
async generateContentOnDevice(
22+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
23+
request: GenerateContentRequest
24+
): Promise<EnhancedGenerateContentResponse> {
25+
return {
26+
text: () => '',
27+
functionCalls: () => undefined
28+
};
29+
}
30+
}

packages/vertexai/src/methods/generate-content.test.ts

+50-10
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import {
3030
} from '../types';
3131
import { ApiSettings } from '../types/internal';
3232
import { Task } from '../requests/request';
33+
import { ChromeAdapter } from './chrome-adapter';
3334

3435
use(sinonChai);
3536
use(chaiAsPromised);
@@ -70,7 +71,8 @@ describe('generateContent()', () => {
7071
const result = await generateContent(
7172
fakeApiSettings,
7273
'model',
73-
fakeRequestParams
74+
fakeRequestParams,
75+
new ChromeAdapter()
7476
);
7577
expect(result.response.text()).to.include('Mountain View, California');
7678
expect(makeRequestStub).to.be.calledWith(
@@ -95,7 +97,8 @@ describe('generateContent()', () => {
9597
const result = await generateContent(
9698
fakeApiSettings,
9799
'model',
98-
fakeRequestParams
100+
fakeRequestParams,
101+
new ChromeAdapter()
99102
);
100103
expect(result.response.text()).to.include('Use Freshly Ground Coffee');
101104
expect(result.response.text()).to.include('30 minutes of brewing');
@@ -118,7 +121,8 @@ describe('generateContent()', () => {
118121
const result = await generateContent(
119122
fakeApiSettings,
120123
'model',
121-
fakeRequestParams
124+
fakeRequestParams,
125+
new ChromeAdapter()
122126
);
123127
expect(result.response.usageMetadata?.totalTokenCount).to.equal(1913);
124128
expect(result.response.usageMetadata?.candidatesTokenCount).to.equal(76);
@@ -153,7 +157,8 @@ describe('generateContent()', () => {
153157
const result = await generateContent(
154158
fakeApiSettings,
155159
'model',
156-
fakeRequestParams
160+
fakeRequestParams,
161+
new ChromeAdapter()
157162
);
158163
expect(result.response.text()).to.include(
159164
'Some information cited from an external source'
@@ -180,7 +185,8 @@ describe('generateContent()', () => {
180185
const result = await generateContent(
181186
fakeApiSettings,
182187
'model',
183-
fakeRequestParams
188+
fakeRequestParams,
189+
new ChromeAdapter()
184190
);
185191
expect(result.response.text).to.throw('SAFETY');
186192
expect(makeRequestStub).to.be.calledWith(
@@ -202,7 +208,8 @@ describe('generateContent()', () => {
202208
const result = await generateContent(
203209
fakeApiSettings,
204210
'model',
205-
fakeRequestParams
211+
fakeRequestParams,
212+
new ChromeAdapter()
206213
);
207214
expect(result.response.text).to.throw('SAFETY');
208215
expect(makeRequestStub).to.be.calledWith(
@@ -224,7 +231,8 @@ describe('generateContent()', () => {
224231
const result = await generateContent(
225232
fakeApiSettings,
226233
'model',
227-
fakeRequestParams
234+
fakeRequestParams,
235+
new ChromeAdapter()
228236
);
229237
expect(result.response.text()).to.equal('');
230238
expect(makeRequestStub).to.be.calledWith(
@@ -246,7 +254,8 @@ describe('generateContent()', () => {
246254
const result = await generateContent(
247255
fakeApiSettings,
248256
'model',
249-
fakeRequestParams
257+
fakeRequestParams,
258+
new ChromeAdapter()
250259
);
251260
expect(result.response.text()).to.include('Some text');
252261
expect(makeRequestStub).to.be.calledWith(
@@ -268,7 +277,12 @@ describe('generateContent()', () => {
268277
json: mockResponse.json
269278
} as Response);
270279
await expect(
271-
generateContent(fakeApiSettings, 'model', fakeRequestParams)
280+
generateContent(
281+
fakeApiSettings,
282+
'model',
283+
fakeRequestParams,
284+
new ChromeAdapter()
285+
)
272286
).to.be.rejectedWith(/400.*invalid argument/);
273287
expect(mockFetch).to.be.called;
274288
});
@@ -283,10 +297,36 @@ describe('generateContent()', () => {
283297
json: mockResponse.json
284298
} as Response);
285299
await expect(
286-
generateContent(fakeApiSettings, 'model', fakeRequestParams)
300+
generateContent(
301+
fakeApiSettings,
302+
'model',
303+
fakeRequestParams,
304+
new ChromeAdapter()
305+
)
287306
).to.be.rejectedWith(
288307
/firebasevertexai\.googleapis[\s\S]*my-project[\s\S]*api-not-enabled/
289308
);
290309
expect(mockFetch).to.be.called;
291310
});
311+
it('on-device', async () => {
312+
const expectedText = 'hi';
313+
const chromeAdapter = new ChromeAdapter();
314+
const mockIsAvailable = stub(chromeAdapter, 'isAvailable').resolves(true);
315+
const mockGenerateContent = stub(
316+
chromeAdapter,
317+
'generateContentOnDevice'
318+
).resolves({
319+
text: () => expectedText,
320+
functionCalls: () => undefined
321+
});
322+
const result = await generateContent(
323+
fakeApiSettings,
324+
'model',
325+
fakeRequestParams,
326+
chromeAdapter
327+
);
328+
expect(result.response.text()).to.equal(expectedText);
329+
expect(mockIsAvailable).to.be.called;
330+
expect(mockGenerateContent).to.be.calledWith(fakeRequestParams);
331+
});
292332
});

packages/vertexai/src/methods/generate-content.ts

+25-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
import {
19+
EnhancedGenerateContentResponse,
1920
GenerateContentRequest,
2021
GenerateContentResponse,
2122
GenerateContentResult,
@@ -26,6 +27,7 @@ import { Task, makeRequest } from '../requests/request';
2627
import { createEnhancedContentResponse } from '../requests/response-helpers';
2728
import { processStream } from '../requests/stream-reader';
2829
import { ApiSettings } from '../types/internal';
30+
import { ChromeAdapter } from './chrome-adapter';
2931

3032
export async function generateContentStream(
3133
apiSettings: ApiSettings,
@@ -44,12 +46,12 @@ export async function generateContentStream(
4446
return processStream(response);
4547
}
4648

47-
export async function generateContent(
49+
async function generateContentOnCloud(
4850
apiSettings: ApiSettings,
4951
model: string,
5052
params: GenerateContentRequest,
5153
requestOptions?: RequestOptions
52-
): Promise<GenerateContentResult> {
54+
): Promise<EnhancedGenerateContentResponse> {
5355
const response = await makeRequest(
5456
model,
5557
Task.GENERATE_CONTENT,
@@ -60,6 +62,27 @@ export async function generateContent(
6062
);
6163
const responseJson: GenerateContentResponse = await response.json();
6264
const enhancedResponse = createEnhancedContentResponse(responseJson);
65+
return enhancedResponse;
66+
}
67+
68+
export async function generateContent(
69+
apiSettings: ApiSettings,
70+
model: string,
71+
params: GenerateContentRequest,
72+
chromeAdapter: ChromeAdapter,
73+
requestOptions?: RequestOptions
74+
): Promise<GenerateContentResult> {
75+
let enhancedResponse;
76+
if (await chromeAdapter.isAvailable(params)) {
77+
enhancedResponse = await chromeAdapter.generateContentOnDevice(params);
78+
} else {
79+
enhancedResponse = await generateContentOnCloud(
80+
apiSettings,
81+
model,
82+
params,
83+
requestOptions
84+
);
85+
}
6386
return {
6487
response: enhancedResponse
6588
};

0 commit comments

Comments
 (0)