Skip to content

Commit 3d11ecc

Browse files
authored
Add systemInstruction and toolConfig (#8146)
1 parent f25ccbb commit 3d11ecc

File tree

8 files changed

+235
-65
lines changed

8 files changed

+235
-65
lines changed

packages/vertexai/src/methods/chat-session-helpers.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ const VALID_PART_FIELDS: Array<keyof Part> = [
3030
const VALID_PARTS_PER_ROLE: { [key in Role]: Array<keyof Part> } = {
3131
user: ['text', 'inlineData'],
3232
function: ['functionResponse'],
33-
model: ['text', 'functionCall']
33+
model: ['text', 'functionCall'],
34+
// System instructions shouldn't be in history anyway.
35+
system: ['text']
3436
};
3537

3638
const VALID_PREVIOUS_CONTENT_ROLES: { [key in Role]: Role[] } = {
3739
user: ['model'],
3840
function: ['model'],
39-
model: ['user', 'function']
41+
model: ['user', 'function'],
42+
// System instructions shouldn't be in history.
43+
system: []
4044
};
4145

4246
export function validateChatHistory(history: Content[]): void {

packages/vertexai/src/methods/chat-session.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ export class ChatSession {
8282
safetySettings: this.params?.safetySettings,
8383
generationConfig: this.params?.generationConfig,
8484
tools: this.params?.tools,
85+
toolConfig: this.params?.toolConfig,
86+
systemInstruction: this.params?.systemInstruction,
8587
contents: [...this._history, newContent]
8688
};
8789
let finalResult = {} as GenerateContentResult;
@@ -135,6 +137,8 @@ export class ChatSession {
135137
safetySettings: this.params?.safetySettings,
136138
generationConfig: this.params?.generationConfig,
137139
tools: this.params?.tools,
140+
toolConfig: this.params?.toolConfig,
141+
systemInstruction: this.params?.systemInstruction,
138142
contents: [...this._history, newContent]
139143
};
140144
const streamPromise = generateContentStream(

packages/vertexai/src/methods/generate-content.test.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import * as request from '../requests/request';
2424
import { generateContent } from './generate-content';
2525
import {
2626
GenerateContentRequest,
27+
HarmBlockMethod,
2728
HarmBlockThreshold,
2829
HarmCategory
2930
} from '../types';
@@ -47,7 +48,8 @@ const fakeRequestParams: GenerateContentRequest = {
4748
safetySettings: [
4849
{
4950
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
50-
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
51+
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
52+
method: HarmBlockMethod.SEVERITY
5153
}
5254
]
5355
};

packages/vertexai/src/models/generative-model.test.ts

Lines changed: 161 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
import { expect } from 'chai';
17+
import { use, expect } from 'chai';
1818
import { GenerativeModel } from './generative-model';
19-
import { VertexAI } from '../public-types';
19+
import { FunctionCallingMode, VertexAI } from '../public-types';
20+
import * as request from '../requests/request';
21+
import { match, restore, stub } from 'sinon';
22+
import { getMockResponse } from '../../test-utils/mock-response';
23+
import sinonChai from 'sinon-chai';
24+
25+
use(sinonChai);
2026

2127
const fakeVertexAI: VertexAI = {
2228
app: {
@@ -53,4 +59,157 @@ describe('GenerativeModel', () => {
5359
});
5460
expect(genModel.model).to.equal('tunedModels/my-model');
5561
});
62+
it('passes params through to generateContent', async () => {
63+
const genModel = new GenerativeModel(fakeVertexAI, {
64+
model: 'my-model',
65+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
66+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
67+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
68+
});
69+
expect(genModel.tools?.length).to.equal(1);
70+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
71+
FunctionCallingMode.NONE
72+
);
73+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
74+
const mockResponse = getMockResponse(
75+
'unary-success-basic-reply-short.json'
76+
);
77+
const makeRequestStub = stub(request, 'makeRequest').resolves(
78+
mockResponse as Response
79+
);
80+
await genModel.generateContent('hello');
81+
expect(makeRequestStub).to.be.calledWith(
82+
'publishers/google/models/my-model',
83+
request.Task.GENERATE_CONTENT,
84+
match.any,
85+
false,
86+
match((value: string) => {
87+
return (
88+
value.includes('myfunc') &&
89+
value.includes(FunctionCallingMode.NONE) &&
90+
value.includes('be friendly')
91+
);
92+
}),
93+
{}
94+
);
95+
restore();
96+
});
97+
it('generateContent overrides model values', async () => {
98+
const genModel = new GenerativeModel(fakeVertexAI, {
99+
model: 'my-model',
100+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
101+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
102+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
103+
});
104+
expect(genModel.tools?.length).to.equal(1);
105+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
106+
FunctionCallingMode.NONE
107+
);
108+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
109+
const mockResponse = getMockResponse(
110+
'unary-success-basic-reply-short.json'
111+
);
112+
const makeRequestStub = stub(request, 'makeRequest').resolves(
113+
mockResponse as Response
114+
);
115+
await genModel.generateContent({
116+
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
117+
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
118+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
119+
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
120+
});
121+
expect(makeRequestStub).to.be.calledWith(
122+
'publishers/google/models/my-model',
123+
request.Task.GENERATE_CONTENT,
124+
match.any,
125+
false,
126+
match((value: string) => {
127+
return (
128+
value.includes('otherfunc') &&
129+
value.includes(FunctionCallingMode.AUTO) &&
130+
value.includes('be formal')
131+
);
132+
}),
133+
{}
134+
);
135+
restore();
136+
});
137+
it('passes params through to chat.sendMessage', async () => {
138+
const genModel = new GenerativeModel(fakeVertexAI, {
139+
model: 'my-model',
140+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
141+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
142+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
143+
});
144+
expect(genModel.tools?.length).to.equal(1);
145+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
146+
FunctionCallingMode.NONE
147+
);
148+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
149+
const mockResponse = getMockResponse(
150+
'unary-success-basic-reply-short.json'
151+
);
152+
const makeRequestStub = stub(request, 'makeRequest').resolves(
153+
mockResponse as Response
154+
);
155+
await genModel.startChat().sendMessage('hello');
156+
expect(makeRequestStub).to.be.calledWith(
157+
'publishers/google/models/my-model',
158+
request.Task.GENERATE_CONTENT,
159+
match.any,
160+
false,
161+
match((value: string) => {
162+
return (
163+
value.includes('myfunc') &&
164+
value.includes(FunctionCallingMode.NONE) &&
165+
value.includes('be friendly')
166+
);
167+
}),
168+
{}
169+
);
170+
restore();
171+
});
172+
it('startChat overrides model values', async () => {
173+
const genModel = new GenerativeModel(fakeVertexAI, {
174+
model: 'my-model',
175+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
176+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
177+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
178+
});
179+
expect(genModel.tools?.length).to.equal(1);
180+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
181+
FunctionCallingMode.NONE
182+
);
183+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
184+
const mockResponse = getMockResponse(
185+
'unary-success-basic-reply-short.json'
186+
);
187+
const makeRequestStub = stub(request, 'makeRequest').resolves(
188+
mockResponse as Response
189+
);
190+
await genModel
191+
.startChat({
192+
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
193+
toolConfig: {
194+
functionCallingConfig: { mode: FunctionCallingMode.AUTO }
195+
},
196+
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
197+
})
198+
.sendMessage('hello');
199+
expect(makeRequestStub).to.be.calledWith(
200+
'publishers/google/models/my-model',
201+
request.Task.GENERATE_CONTENT,
202+
match.any,
203+
false,
204+
match((value: string) => {
205+
return (
206+
value.includes('otherfunc') &&
207+
value.includes(FunctionCallingMode.AUTO) &&
208+
value.includes('be formal')
209+
);
210+
}),
211+
{}
212+
);
213+
restore();
214+
});
56215
});

packages/vertexai/src/models/generative-model.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
generateContentStream
2121
} from '../methods/generate-content';
2222
import {
23+
Content,
2324
CountTokensRequest,
2425
CountTokensResponse,
2526
GenerateContentRequest,
@@ -31,7 +32,8 @@ import {
3132
RequestOptions,
3233
SafetySetting,
3334
StartChatParams,
34-
Tool
35+
Tool,
36+
ToolConfig
3537
} from '../types';
3638
import { ChatSession } from '../methods/chat-session';
3739
import { countTokens } from '../methods/count-tokens';
@@ -52,6 +54,8 @@ export class GenerativeModel {
5254
safetySettings: SafetySetting[];
5355
requestOptions?: RequestOptions;
5456
tools?: Tool[];
57+
toolConfig?: ToolConfig;
58+
systemInstruction?: Content;
5559

5660
constructor(
5761
vertexAI: VertexAI,
@@ -88,6 +92,8 @@ export class GenerativeModel {
8892
this.generationConfig = modelParams.generationConfig || {};
8993
this.safetySettings = modelParams.safetySettings || [];
9094
this.tools = modelParams.tools;
95+
this.toolConfig = modelParams.toolConfig;
96+
this.systemInstruction = modelParams.systemInstruction;
9197
this.requestOptions = requestOptions || {};
9298
}
9399

@@ -106,6 +112,8 @@ export class GenerativeModel {
106112
generationConfig: this.generationConfig,
107113
safetySettings: this.safetySettings,
108114
tools: this.tools,
115+
toolConfig: this.toolConfig,
116+
systemInstruction: this.systemInstruction,
109117
...formattedParams
110118
},
111119
this.requestOptions
@@ -129,6 +137,8 @@ export class GenerativeModel {
129137
generationConfig: this.generationConfig,
130138
safetySettings: this.safetySettings,
131139
tools: this.tools,
140+
toolConfig: this.toolConfig,
141+
systemInstruction: this.systemInstruction,
132142
...formattedParams
133143
},
134144
this.requestOptions
@@ -145,6 +155,8 @@ export class GenerativeModel {
145155
this.model,
146156
{
147157
tools: this.tools,
158+
toolConfig: this.toolConfig,
159+
systemInstruction: this.systemInstruction,
148160
...startChatParams
149161
},
150162
this.requestOptions

packages/vertexai/src/types/enums.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ export type Role = (typeof POSSIBLE_ROLES)[number];
2525
* Possible roles.
2626
* @public
2727
*/
28-
export const POSSIBLE_ROLES = ['user', 'model', 'function'] as const;
28+
export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const;
2929

3030
/**
3131
* Harm categories that would cause prompts or candidates to be blocked.
@@ -133,3 +133,22 @@ export enum FinishReason {
133133
// Unknown reason.
134134
OTHER = 'OTHER'
135135
}
136+
137+
/**
138+
* @public
139+
*/
140+
export enum FunctionCallingMode {
141+
// Unspecified function calling mode. This value should not be used.
142+
MODE_UNSPECIFIED = 'MODE_UNSPECIFIED',
143+
// Default model behavior, model decides to predict either a function call
144+
// or a natural language repspose.
145+
AUTO = 'AUTO',
146+
// Model is constrained to always predicting a function call only.
147+
// If "allowed_function_names" is set, the predicted function call will be
148+
// limited to any one of "allowed_function_names", else the predicted
149+
// function call will be any one of the provided "function_declarations".
150+
ANY = 'ANY',
151+
// Model will not predict any function call. Model behavior is same as when
152+
// not passing any function declarations.
153+
NONE = 'NONE'
154+
}

packages/vertexai/src/types/requests.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
*/
1717

1818
import { Content } from './content';
19-
import { HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums';
19+
import {
20+
FunctionCallingMode,
21+
HarmBlockMethod,
22+
HarmBlockThreshold,
23+
HarmCategory
24+
} from './enums';
2025

2126
/**
2227
* Base parameters for a number of methods.
@@ -34,6 +39,8 @@ export interface BaseParams {
3439
export interface ModelParams extends BaseParams {
3540
model: string;
3641
tools?: Tool[];
42+
toolConfig?: ToolConfig;
43+
systemInstruction?: Content;
3744
}
3845

3946
/**
@@ -43,6 +50,8 @@ export interface ModelParams extends BaseParams {
4350
export interface GenerateContentRequest extends BaseParams {
4451
contents: Content[];
4552
tools?: Tool[];
53+
toolConfig?: ToolConfig;
54+
systemInstruction?: Content;
4655
}
4756

4857
/**
@@ -77,6 +86,8 @@ export interface GenerationConfig {
7786
export interface StartChatParams extends BaseParams {
7887
history?: Content[];
7988
tools?: Tool[];
89+
toolConfig?: ToolConfig;
90+
systemInstruction?: Content;
8091
}
8192

8293
/**
@@ -220,3 +231,19 @@ export interface FunctionDeclarationSchemaProperty {
220231
/** Optional. The example of the property. */
221232
example?: unknown;
222233
}
234+
235+
/**
236+
* Tool config. This config is shared for all tools provided in the request.
237+
* @public
238+
*/
239+
export interface ToolConfig {
240+
functionCallingConfig: FunctionCallingConfig;
241+
}
242+
243+
/**
244+
* @public
245+
*/
246+
export interface FunctionCallingConfig {
247+
mode?: FunctionCallingMode;
248+
allowedFunctionNames?: string[];
249+
}

0 commit comments

Comments
 (0)