
Commit 98b0e22

Support structured output
1 parent e859c03 commit 98b0e22
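
In short, this change threads a responseSchema from the public Vertex AI API down to on-device (Chrome Prompt API) inference, so generateContent can return JSON constrained to a caller-supplied schema. A minimal usage sketch, condensed from the e2e sample change below (schema shape and prompt text come from that diff; app initialization and error handling are omitted):

import { getGenerativeModel, getVertexAI, Schema } from 'firebase/vertexai';

// Describe the JSON shape we want back.
const responseSchema = Schema.object({
  properties: {
    characters: Schema.array({
      items: Schema.object({
        properties: {
          name: Schema.string(),
          age: Schema.number(),
          species: Schema.string()
        }
      })
    })
  }
});

const vertexAI = getVertexAI(app);
const model = getGenerativeModel(vertexAI, {
  mode: 'prefer_on_device',
  inCloudParams: { generationConfig: { responseSchema } }
});

const result = await model.generateContent({
  generationConfig: { responseSchema },
  contents: [
    { role: 'user', parts: [{ text: 'Generate 10 animal-based characters.' }] }
  ]
});
// The constrained output is JSON text, so it can be parsed directly.
console.log(JSON.parse(result.response.text()));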


5 files changed: +105 −20 lines


e2e/sample-apps/modular.js

Lines changed: 40 additions & 15 deletions
@@ -58,7 +58,7 @@ import {
   onValue,
   off
 } from 'firebase/database';
-import { getGenerativeModel, getVertexAI } from 'firebase/vertexai';
+import { getGenerativeModel, getVertexAI, Schema } from 'firebase/vertexai';
 import { getDataConnect, DataConnect } from 'firebase/data-connect';
 
 /**
@@ -313,23 +313,48 @@ function callPerformance(app) {
 async function callVertexAI(app) {
   console.log('[VERTEXAI] start');
   const vertexAI = getVertexAI(app);
-  const model = getGenerativeModel(vertexAI, {
-    mode: 'prefer_on_device'
+
+  const responseSchema = Schema.object({
+    properties: {
+      characters: Schema.array({
+        items: Schema.object({
+          properties: {
+            name: Schema.string(),
+            accessory: Schema.string(),
+            age: Schema.number(),
+            species: Schema.string()
+          },
+          optionalProperties: ['accessory']
+        })
+      })
+    }
   });
-  const singleResult = await model.generateContent([
-    { text: 'describe this 20 x 20 px image in two words' },
-    {
-      inlineData: {
-        mimeType: 'image/heic',
-        data: 'AAAAGGZ0eXBoZWljAAAAAGhlaWNtaWYxAAAB7G1ldGEAAAAAAAAAIWhkbHIAAAAAAAAAAHBpY3QAAAAAAAAAAAAAAAAAAAAAJGRpbmYAAAAcZHJlZgAAAAAAAAABAAAADHVybCAAAAABAAAADnBpdG0AAAAAAAEAAAA4aWluZgAAAAAAAgAAABVpbmZlAgAAAAABAABodmMxAAAAABVpbmZlAgAAAQACAABFeGlmAAAAABppcmVmAAAAAAAAAA5jZHNjAAIAAQABAAABD2lwcnAAAADtaXBjbwAAABNjb2xybmNseAACAAIABoAAAAAMY2xsaQDLAEAAAAAUaXNwZQAAAAAAAAAUAAAADgAAAChjbGFwAAAAFAAAAAEAAAANAAAAAQAAAAAAAAAB/8AAAACAAAAAAAAJaXJvdAAAAAAQcGl4aQAAAAADCAgIAAAAcWh2Y0MBA3AAAACwAAAAAAAe8AD8/fj4AAALA6AAAQAXQAEMAf//A3AAAAMAsAAAAwAAAwAecCShAAEAI0IBAQNwAAADALAAAAMAAAMAHqAUIEHAjw1iHuRZVNwICBgCogABAAlEAcBhcshAUyQAAAAaaXBtYQAAAAAAAAABAAEHgQIDhIUGhwAAACxpbG9jAAAAAEQAAAIAAQAAAAEAAAJsAAABDAACAAAAAQAAAhQAAABYAAAAAW1kYXQAAAAAAAABdAAAAAZFeGlmAABNTQAqAAAACAAEARIAAwAAAAEAAQAAARoABQAAAAEAAAA+ARsABQAAAAEAAABGASgAAwAAAAEAAgAAAAAAAAAAAEgAAAABAAAASAAAAAEAAAEIKAGvoR8wDimTiRYUbALiHkU3ZdZ8DXAcSrRB9GARtVQHvnCE0LEyBGAyb5P4eYr6JAK5UxNX10WNlARq3ZpcGeVD+Xom6LodYasuZKKtDHCz/xnswOtC/ksZzVKhtWQqGvkXcsJnLYqWevNkacnccQ95jbHJBg9nXub69jAAN3xhNOXxjGSxaG9QvES5R7sYICEojRjLF5OB5K3v+okQAwfgWpz/u21ayideOgOZQLAyBkKOv7ymLNCagiPWTlHAuy/3qR1Q7m2ERFaxKIAbLSkIVO/P8m8+anKxhzhC//L8NMAUoF+Sf3aEH9O41fwLc+PlcbrDrjgY2EboD3cn9DyN32Rum2Ym'
+
+  const model = getGenerativeModel(vertexAI, {
+    mode: 'prefer_on_device',
+    inCloudParams: {
+      generationConfig: {
+        responseSchema
       }
     }
-  ]);
-  console.log(`Generated text: ${singleResult.response.text()}`);
-  const chat = model.startChat();
-  let chatResult = await chat.sendMessage('describe red in two words');
-  chatResult = await chat.sendMessage('describe blue');
-  console.log('Chat history:', await chat.getHistory());
+  });
+
+  const singleResult = await model.generateContent({
+    generationConfig: {
+      responseSchema
+    },
+    contents: [
+      {
+        role: 'user',
+        parts: [
+          {
+            text: "For use in a children's card game, generate 10 animal-based characters."
+          }
+        ]
+      }
+    ]
+  });
+  console.log(`Generated text:`, JSON.parse(singleResult.response.text()));
   console.log(`[VERTEXAI] end`);
 }
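
Given the schema above, the parsed result should be an object whose characters array contains entries with name, age, and species, plus an optional accessory. A purely illustrative example of the expected shape (values invented, not actual model output):

// Illustrative only — real output varies from run to run.
{
  characters: [
    { name: 'Captain Whiskers', species: 'cat', age: 7, accessory: 'eye patch' },
    { name: 'Daisy', species: 'rabbit', age: 3 } // 'accessory' may be omitted
  ]
}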

e2e/webpack.config.js

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ module.exports = [
     stats: {
       colors: true
     },
-    devtool: 'source-map',
+    devtool: 'eval-source-map',
     devServer: {
       static: './build'
     }
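
For reference, webpack's 'eval-source-map' devtool builds per-module source maps via eval(), which typically gives faster incremental rebuilds than 'source-map' while keeping line-accurate mappings — a common choice for a local dev server like this one.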

packages/vertexai/src/methods/chrome-adapter.test.ts

Lines changed: 46 additions & 0 deletions
@@ -28,6 +28,7 @@ import {
 } from '../types/language-model';
 import { match, stub } from 'sinon';
 import { GenerateContentRequest, AIErrorCode } from '../types';
+import { Schema } from '../api';
 
 use(sinonChai);
 use(chaiAsPromised);
@@ -406,6 +407,51 @@ describe('ChromeAdapter', () => {
         ]
       });
     });
+    it('honors response constraint', async () => {
+      const languageModel = {
+        // eslint-disable-next-line @typescript-eslint/no-unused-vars
+        prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
+      } as LanguageModel;
+      const languageModelProvider = {
+        create: () => Promise.resolve(languageModel)
+      } as LanguageModel;
+      const promptOutput = '{}';
+      const promptStub = stub(languageModel, 'prompt').resolves(promptOutput);
+      const adapter = new ChromeAdapter(
+        languageModelProvider,
+        'prefer_on_device'
+      );
+      const responseSchema = Schema.object({
+        properties: {}
+      });
+      const request = {
+        generationConfig: {
+          responseSchema
+        },
+        contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
+      } as GenerateContentRequest;
+      const response = await adapter.generateContent(request);
+      expect(promptStub).to.have.been.calledOnceWith(
+        [
+          {
+            type: 'text',
+            content: request.contents[0].parts[0].text
+          }
+        ],
+        {
+          responseConstraint: responseSchema
+        }
+      );
+      expect(await response.json()).to.deep.equal({
+        candidates: [
+          {
+            content: {
+              parts: [{ text: promptOutput }]
+            }
+          }
+        ]
+      });
+    });
   });
   describe('countTokens', () => {
     it('counts tokens is not yet available', async () => {

packages/vertexai/src/methods/chrome-adapter.ts

Lines changed: 15 additions & 2 deletions
@@ -28,7 +28,8 @@ import {
   Availability,
   LanguageModel,
   LanguageModelCreateOptions,
-  LanguageModelMessageContent
+  LanguageModelMessageContent,
+  LanguageModelPromptOptions
 } from '../types/language-model';
 
 /**
@@ -106,15 +107,27 @@ export class ChromeAdapter {
    */
   async generateContent(request: GenerateContentRequest): Promise<Response> {
     const session = await this.createSession();
+
     // TODO: support multiple content objects when Chrome supports
     // sequence<LanguageModelMessage>
     const contents = await Promise.all(
       request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
     );
-    const text = await session.prompt(contents);
+
+    const options = ChromeAdapter.extractLanguageModelPromptOptions(request);
+
+    const text = await session.prompt(contents, options);
     return ChromeAdapter.toResponse(text);
   }
 
+  private static extractLanguageModelPromptOptions(
+    request: GenerateContentRequest
+  ): LanguageModelPromptOptions {
+    return {
+      responseConstraint: request.generationConfig?.responseSchema
+    };
+  }
+
   /**
    * Generates content stream on device.
    *
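
A sketch of the new path end to end, using a hypothetical request (names mirror the unit test above; this is illustration, not code from the commit):

// A request carrying a responseSchema...
const request: GenerateContentRequest = {
  generationConfig: { responseSchema: Schema.object({ properties: {} }) },
  contents: [{ role: 'user', parts: [{ text: 'anything' }] }]
};
// ...is mapped by extractLanguageModelPromptOptions to
//   { responseConstraint: request.generationConfig?.responseSchema }
// so the on-device session is prompted roughly as:
//   session.prompt(
//     [{ type: 'text', content: 'anything' }],
//     { responseConstraint: responseSchema }
//   );
// adapter is a ChromeAdapter in 'prefer_on_device' mode, as in the test.
const response = await adapter.generateContent(request);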

packages/vertexai/src/types/language-model.ts

Lines changed: 3 additions & 2 deletions
@@ -49,8 +49,9 @@ export interface LanguageModelCreateOptions
   systemPrompt?: string;
   initialPrompts?: LanguageModelInitialPrompts;
 }
-interface LanguageModelPromptOptions {
-  signal?: AbortSignal;
+export interface LanguageModelPromptOptions {
+  responseConstraint?: object;
+  // TODO: Restore AbortSignal once the API is defined.
 }
 interface LanguageModelExpectedInput {
   type: LanguageModelMessageType;
