Skip to content

Commit 1b5f74f

Browse files
committed
add systemInstruction and toolConfig
1 parent f25ccbb commit 1b5f74f

File tree

7 files changed

+230
-63
lines changed

7 files changed

+230
-63
lines changed

packages/vertexai/src/methods/chat-session.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ export class ChatSession {
8282
safetySettings: this.params?.safetySettings,
8383
generationConfig: this.params?.generationConfig,
8484
tools: this.params?.tools,
85+
toolConfig: this.params?.toolConfig,
86+
systemInstruction: this.params?.systemInstruction,
8587
contents: [...this._history, newContent]
8688
};
8789
let finalResult = {} as GenerateContentResult;
@@ -135,6 +137,8 @@ export class ChatSession {
135137
safetySettings: this.params?.safetySettings,
136138
generationConfig: this.params?.generationConfig,
137139
tools: this.params?.tools,
140+
toolConfig: this.params?.toolConfig,
141+
systemInstruction: this.params?.systemInstruction,
138142
contents: [...this._history, newContent]
139143
};
140144
const streamPromise = generateContentStream(

packages/vertexai/src/methods/generate-content.test.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import * as request from '../requests/request';
2424
import { generateContent } from './generate-content';
2525
import {
2626
GenerateContentRequest,
27+
HarmBlockMethod,
2728
HarmBlockThreshold,
2829
HarmCategory
2930
} from '../types';
@@ -47,7 +48,8 @@ const fakeRequestParams: GenerateContentRequest = {
4748
safetySettings: [
4849
{
4950
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
50-
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
51+
threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
52+
method: HarmBlockMethod.SEVERITY
5153
}
5254
]
5355
};

packages/vertexai/src/models/generative-model.test.ts

Lines changed: 162 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,16 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
import { expect } from 'chai';
17+
import { use, expect } from 'chai';
1818
import { GenerativeModel } from './generative-model';
19-
import { VertexAI } from '../public-types';
19+
import { FunctionCallingMode, VertexAI } from '../public-types';
20+
import * as request from '../requests/request';
21+
import { match, restore, stub } from 'sinon';
22+
import { getMockResponse } from '../../test-utils/mock-response';
23+
import { Task } from '../requests/request';
24+
import sinonChai from 'sinon-chai';
25+
26+
use(sinonChai);
2027

2128
const fakeVertexAI: VertexAI = {
2229
app: {
@@ -53,4 +60,157 @@ describe('GenerativeModel', () => {
5360
});
5461
expect(genModel.model).to.equal('tunedModels/my-model');
5562
});
63+
it('passes params through to generateContent', async () => {
64+
const genModel = new GenerativeModel(fakeVertexAI, {
65+
model: 'my-model',
66+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
67+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
68+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
69+
});
70+
expect(genModel.tools?.length).to.equal(1);
71+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
72+
FunctionCallingMode.NONE
73+
);
74+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
75+
const mockResponse = getMockResponse(
76+
'unary-success-basic-reply-short.json'
77+
);
78+
const makeRequestStub = stub(request, 'makeRequest').resolves(
79+
mockResponse as Response
80+
);
81+
await genModel.generateContent('hello');
82+
expect(makeRequestStub).to.be.calledWith(
83+
'publishers/google/models/my-model',
84+
Task.GENERATE_CONTENT,
85+
match.any,
86+
false,
87+
match((value: string) => {
88+
return (
89+
value.includes('myfunc') &&
90+
value.includes(FunctionCallingMode.NONE) &&
91+
value.includes('be friendly')
92+
);
93+
}),
94+
{}
95+
);
96+
restore();
97+
});
98+
it('generateContent overrides model values', async () => {
99+
const genModel = new GenerativeModel(fakeVertexAI, {
100+
model: 'my-model',
101+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
102+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
103+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
104+
});
105+
expect(genModel.tools?.length).to.equal(1);
106+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
107+
FunctionCallingMode.NONE
108+
);
109+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
110+
const mockResponse = getMockResponse(
111+
'unary-success-basic-reply-short.json'
112+
);
113+
const makeRequestStub = stub(request, 'makeRequest').resolves(
114+
mockResponse as Response
115+
);
116+
await genModel.generateContent({
117+
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
118+
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
119+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
120+
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
121+
});
122+
expect(makeRequestStub).to.be.calledWith(
123+
'publishers/google/models/my-model',
124+
Task.GENERATE_CONTENT,
125+
match.any,
126+
false,
127+
match((value: string) => {
128+
return (
129+
value.includes('otherfunc') &&
130+
value.includes(FunctionCallingMode.AUTO) &&
131+
value.includes('be formal')
132+
);
133+
}),
134+
{}
135+
);
136+
restore();
137+
});
138+
it('passes params through to chat.sendMessage', async () => {
139+
const genModel = new GenerativeModel(fakeVertexAI, {
140+
model: 'my-model',
141+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
142+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
143+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
144+
});
145+
expect(genModel.tools?.length).to.equal(1);
146+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
147+
FunctionCallingMode.NONE
148+
);
149+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
150+
const mockResponse = getMockResponse(
151+
'unary-success-basic-reply-short.json'
152+
);
153+
const makeRequestStub = stub(request, 'makeRequest').resolves(
154+
mockResponse as Response
155+
);
156+
await genModel.startChat().sendMessage('hello');
157+
expect(makeRequestStub).to.be.calledWith(
158+
'publishers/google/models/my-model',
159+
Task.GENERATE_CONTENT,
160+
match.any,
161+
false,
162+
match((value: string) => {
163+
return (
164+
value.includes('myfunc') &&
165+
value.includes(FunctionCallingMode.NONE) &&
166+
value.includes('be friendly')
167+
);
168+
}),
169+
{}
170+
);
171+
restore();
172+
});
173+
it('startChat overrides model values', async () => {
174+
const genModel = new GenerativeModel(fakeVertexAI, {
175+
model: 'my-model',
176+
tools: [{ functionDeclarations: [{ name: 'myfunc' }] }],
177+
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
178+
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
179+
});
180+
expect(genModel.tools?.length).to.equal(1);
181+
expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
182+
FunctionCallingMode.NONE
183+
);
184+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
185+
const mockResponse = getMockResponse(
186+
'unary-success-basic-reply-short.json'
187+
);
188+
const makeRequestStub = stub(request, 'makeRequest').resolves(
189+
mockResponse as Response
190+
);
191+
await genModel
192+
.startChat({
193+
tools: [{ functionDeclarations: [{ name: 'otherfunc' }] }],
194+
toolConfig: {
195+
functionCallingConfig: { mode: FunctionCallingMode.AUTO }
196+
},
197+
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
198+
})
199+
.sendMessage('hello');
200+
expect(makeRequestStub).to.be.calledWith(
201+
'publishers/google/models/my-model',
202+
Task.GENERATE_CONTENT,
203+
match.any,
204+
false,
205+
match((value: string) => {
206+
return (
207+
value.includes('otherfunc') &&
208+
value.includes(FunctionCallingMode.AUTO) &&
209+
value.includes('be formal')
210+
);
211+
}),
212+
{}
213+
);
214+
restore();
215+
});
56216
});

packages/vertexai/src/models/generative-model.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
generateContentStream
2121
} from '../methods/generate-content';
2222
import {
23+
Content,
2324
CountTokensRequest,
2425
CountTokensResponse,
2526
GenerateContentRequest,
@@ -31,7 +32,8 @@ import {
3132
RequestOptions,
3233
SafetySetting,
3334
StartChatParams,
34-
Tool
35+
Tool,
36+
ToolConfig
3537
} from '../types';
3638
import { ChatSession } from '../methods/chat-session';
3739
import { countTokens } from '../methods/count-tokens';
@@ -52,6 +54,8 @@ export class GenerativeModel {
5254
safetySettings: SafetySetting[];
5355
requestOptions?: RequestOptions;
5456
tools?: Tool[];
57+
toolConfig?: ToolConfig;
58+
systemInstruction?: Content;
5559

5660
constructor(
5761
vertexAI: VertexAI,
@@ -88,6 +92,8 @@ export class GenerativeModel {
8892
this.generationConfig = modelParams.generationConfig || {};
8993
this.safetySettings = modelParams.safetySettings || [];
9094
this.tools = modelParams.tools;
95+
this.toolConfig = modelParams.toolConfig;
96+
this.systemInstruction = modelParams.systemInstruction;
9197
this.requestOptions = requestOptions || {};
9298
}
9399

@@ -106,6 +112,8 @@ export class GenerativeModel {
106112
generationConfig: this.generationConfig,
107113
safetySettings: this.safetySettings,
108114
tools: this.tools,
115+
toolConfig: this.toolConfig,
116+
systemInstruction: this.systemInstruction,
109117
...formattedParams
110118
},
111119
this.requestOptions
@@ -129,6 +137,8 @@ export class GenerativeModel {
129137
generationConfig: this.generationConfig,
130138
safetySettings: this.safetySettings,
131139
tools: this.tools,
140+
toolConfig: this.toolConfig,
141+
systemInstruction: this.systemInstruction,
132142
...formattedParams
133143
},
134144
this.requestOptions
@@ -145,6 +155,8 @@ export class GenerativeModel {
145155
this.model,
146156
{
147157
tools: this.tools,
158+
toolConfig: this.toolConfig,
159+
systemInstruction: this.systemInstruction,
148160
...startChatParams
149161
},
150162
this.requestOptions

packages/vertexai/src/types/enums.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ export type Role = (typeof POSSIBLE_ROLES)[number];
2525
* Possible roles.
2626
* @public
2727
*/
28-
export const POSSIBLE_ROLES = ['user', 'model', 'function'] as const;
28+
export const POSSIBLE_ROLES = ['user', 'model', 'function', 'system'] as const;
2929

3030
/**
3131
* Harm categories that would cause prompts or candidates to be blocked.
@@ -133,3 +133,22 @@ export enum FinishReason {
133133
// Unknown reason.
134134
OTHER = 'OTHER'
135135
}
136+
137+
/**
138+
* @public
139+
*/
140+
export enum FunctionCallingMode {
141+
// Unspecified function calling mode. This value should not be used.
142+
MODE_UNSPECIFIED = 'MODE_UNSPECIFIED',
143+
// Default model behavior, model decides to predict either a function call
144+
// or a natural language repspose.
145+
AUTO = 'AUTO',
146+
// Model is constrained to always predicting a function call only.
147+
// If "allowed_function_names" are set, the predicted function call will be
148+
// limited to any one of "allowed_function_names", else the predicted
149+
// function call will be any one of the provided "function_declarations".
150+
ANY = 'ANY',
151+
// Model will not predict any function call. Model behavior is same as when
152+
// not passing any function declarations.
153+
NONE = 'NONE'
154+
}

packages/vertexai/src/types/requests.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
*/
1717

1818
import { Content } from './content';
19-
import { HarmBlockMethod, HarmBlockThreshold, HarmCategory } from './enums';
19+
import {
20+
FunctionCallingMode,
21+
HarmBlockMethod,
22+
HarmBlockThreshold,
23+
HarmCategory
24+
} from './enums';
2025

2126
/**
2227
* Base parameters for a number of methods.
@@ -34,6 +39,8 @@ export interface BaseParams {
3439
export interface ModelParams extends BaseParams {
3540
model: string;
3641
tools?: Tool[];
42+
toolConfig?: ToolConfig;
43+
systemInstruction?: Content;
3744
}
3845

3946
/**
@@ -43,6 +50,8 @@ export interface ModelParams extends BaseParams {
4350
export interface GenerateContentRequest extends BaseParams {
4451
contents: Content[];
4552
tools?: Tool[];
53+
toolConfig?: ToolConfig;
54+
systemInstruction?: Content;
4655
}
4756

4857
/**
@@ -77,6 +86,8 @@ export interface GenerationConfig {
7786
export interface StartChatParams extends BaseParams {
7887
history?: Content[];
7988
tools?: Tool[];
89+
toolConfig?: ToolConfig;
90+
systemInstruction?: Content;
8091
}
8192

8293
/**
@@ -220,3 +231,19 @@ export interface FunctionDeclarationSchemaProperty {
220231
/** Optional. The example of the property. */
221232
example?: unknown;
222233
}
234+
235+
/**
236+
* Tool config. This config is shared for all tools provided in the request.
237+
* @public
238+
*/
239+
export interface ToolConfig {
240+
functionCallingConfig: FunctionCallingConfig;
241+
}
242+
243+
/**
244+
* @public
245+
*/
246+
export interface FunctionCallingConfig {
247+
mode?: FunctionCallingMode;
248+
allowedFunctionNames?: string[];
249+
}

0 commit comments

Comments
 (0)