Skip to content

Commit 627b561

Browse files
authored
Allow text-only systemInstruction (#8208)
1 parent 1aadc47 commit 627b561

File tree

5 files changed

+265
-8
lines changed

5 files changed

+265
-8
lines changed

packages/vertexai/src/models/generative-model.test.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,31 @@ describe('GenerativeModel', () => {
9494
);
9595
restore();
9696
});
97+
it('passes text-only systemInstruction through to generateContent', async () => {
98+
const genModel = new GenerativeModel(fakeVertexAI, {
99+
model: 'my-model',
100+
systemInstruction: 'be friendly'
101+
});
102+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
103+
const mockResponse = getMockResponse(
104+
'unary-success-basic-reply-short.json'
105+
);
106+
const makeRequestStub = stub(request, 'makeRequest').resolves(
107+
mockResponse as Response
108+
);
109+
await genModel.generateContent('hello');
110+
expect(makeRequestStub).to.be.calledWith(
111+
'publishers/google/models/my-model',
112+
request.Task.GENERATE_CONTENT,
113+
match.any,
114+
false,
115+
match((value: string) => {
116+
return value.includes('be friendly');
117+
}),
118+
{}
119+
);
120+
restore();
121+
});
97122
it('generateContent overrides model values', async () => {
98123
const genModel = new GenerativeModel(fakeVertexAI, {
99124
model: 'my-model',
@@ -169,6 +194,31 @@ describe('GenerativeModel', () => {
169194
);
170195
restore();
171196
});
197+
it('passes text-only systemInstruction through to chat.sendMessage', async () => {
198+
const genModel = new GenerativeModel(fakeVertexAI, {
199+
model: 'my-model',
200+
systemInstruction: 'be friendly'
201+
});
202+
expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly');
203+
const mockResponse = getMockResponse(
204+
'unary-success-basic-reply-short.json'
205+
);
206+
const makeRequestStub = stub(request, 'makeRequest').resolves(
207+
mockResponse as Response
208+
);
209+
await genModel.startChat().sendMessage('hello');
210+
expect(makeRequestStub).to.be.calledWith(
211+
'publishers/google/models/my-model',
212+
request.Task.GENERATE_CONTENT,
213+
match.any,
214+
false,
215+
match((value: string) => {
216+
return value.includes('be friendly');
217+
}),
218+
{}
219+
);
220+
restore();
221+
});
172222
it('startChat overrides model values', async () => {
173223
const genModel = new GenerativeModel(fakeVertexAI, {
174224
model: 'my-model',

packages/vertexai/src/models/generative-model.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ import {
3737
} from '../types';
3838
import { ChatSession } from '../methods/chat-session';
3939
import { countTokens } from '../methods/count-tokens';
40-
import { formatGenerateContentInput } from '../requests/request-helpers';
40+
import {
41+
formatGenerateContentInput,
42+
formatSystemInstruction
43+
} from '../requests/request-helpers';
4144
import { VertexAI } from '../public-types';
4245
import { ERROR_FACTORY, VertexError } from '../errors';
4346
import { ApiSettings } from '../types/internal';
@@ -93,7 +96,9 @@ export class GenerativeModel {
9396
this.safetySettings = modelParams.safetySettings || [];
9497
this.tools = modelParams.tools;
9598
this.toolConfig = modelParams.toolConfig;
96-
this.systemInstruction = modelParams.systemInstruction;
99+
this.systemInstruction = formatSystemInstruction(
100+
modelParams.systemInstruction
101+
);
97102
this.requestOptions = requestOptions || {};
98103
}
99104

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
import { expect, use } from 'chai';
19+
import sinonChai from 'sinon-chai';
20+
import { Content } from '../types';
21+
import { formatGenerateContentInput } from './request-helpers';
22+
23+
use(sinonChai);
24+
25+
describe('request formatting methods', () => {
26+
describe('formatGenerateContentInput', () => {
27+
it('formats a text string into a request', () => {
28+
const result = formatGenerateContentInput('some text content');
29+
expect(result).to.deep.equal({
30+
contents: [
31+
{
32+
role: 'user',
33+
parts: [{ text: 'some text content' }]
34+
}
35+
]
36+
});
37+
});
38+
it('formats an array of strings into a request', () => {
39+
const result = formatGenerateContentInput(['txt1', 'txt2']);
40+
expect(result).to.deep.equal({
41+
contents: [
42+
{
43+
role: 'user',
44+
parts: [{ text: 'txt1' }, { text: 'txt2' }]
45+
}
46+
]
47+
});
48+
});
49+
it('formats an array of Parts into a request', () => {
50+
const result = formatGenerateContentInput([
51+
{ text: 'txt1' },
52+
{ text: 'txtB' }
53+
]);
54+
expect(result).to.deep.equal({
55+
contents: [
56+
{
57+
role: 'user',
58+
parts: [{ text: 'txt1' }, { text: 'txtB' }]
59+
}
60+
]
61+
});
62+
});
63+
it('formats a mixed array into a request', () => {
64+
const result = formatGenerateContentInput(['txtA', { text: 'txtB' }]);
65+
expect(result).to.deep.equal({
66+
contents: [
67+
{
68+
role: 'user',
69+
parts: [{ text: 'txtA' }, { text: 'txtB' }]
70+
}
71+
]
72+
});
73+
});
74+
it('preserves other properties of request', () => {
75+
const result = formatGenerateContentInput({
76+
contents: [
77+
{
78+
role: 'user',
79+
parts: [{ text: 'txtA' }]
80+
}
81+
],
82+
generationConfig: { topK: 100 }
83+
});
84+
expect(result).to.deep.equal({
85+
contents: [
86+
{
87+
role: 'user',
88+
parts: [{ text: 'txtA' }]
89+
}
90+
],
91+
generationConfig: { topK: 100 }
92+
});
93+
});
94+
it('formats systemInstructions if provided as text', () => {
95+
const result = formatGenerateContentInput({
96+
contents: [
97+
{
98+
role: 'user',
99+
parts: [{ text: 'txtA' }]
100+
}
101+
],
102+
systemInstruction: 'be excited'
103+
});
104+
expect(result).to.deep.equal({
105+
contents: [
106+
{
107+
role: 'user',
108+
parts: [{ text: 'txtA' }]
109+
}
110+
],
111+
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
112+
});
113+
});
114+
it('formats systemInstructions if provided as Part', () => {
115+
const result = formatGenerateContentInput({
116+
contents: [
117+
{
118+
role: 'user',
119+
parts: [{ text: 'txtA' }]
120+
}
121+
],
122+
systemInstruction: { text: 'be excited' }
123+
});
124+
expect(result).to.deep.equal({
125+
contents: [
126+
{
127+
role: 'user',
128+
parts: [{ text: 'txtA' }]
129+
}
130+
],
131+
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
132+
});
133+
});
134+
it('formats systemInstructions if provided as Content (no role)', () => {
135+
const result = formatGenerateContentInput({
136+
contents: [
137+
{
138+
role: 'user',
139+
parts: [{ text: 'txtA' }]
140+
}
141+
],
142+
systemInstruction: { parts: [{ text: 'be excited' }] } as Content
143+
});
144+
expect(result).to.deep.equal({
145+
contents: [
146+
{
147+
role: 'user',
148+
parts: [{ text: 'txtA' }]
149+
}
150+
],
151+
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
152+
});
153+
});
154+
it('passes thru systemInstructions if provided as Content', () => {
155+
const result = formatGenerateContentInput({
156+
contents: [
157+
{
158+
role: 'user',
159+
parts: [{ text: 'txtA' }]
160+
}
161+
],
162+
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
163+
});
164+
expect(result).to.deep.equal({
165+
contents: [
166+
{
167+
role: 'user',
168+
parts: [{ text: 'txtA' }]
169+
}
170+
],
171+
systemInstruction: { role: 'system', parts: [{ text: 'be excited' }] }
172+
});
173+
});
174+
});
175+
});

packages/vertexai/src/requests/request-helpers.ts

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@
1818
import { Content, GenerateContentRequest, Part } from '../types';
1919
import { ERROR_FACTORY, VertexError } from '../errors';
2020

21+
export function formatSystemInstruction(
22+
input?: string | Part | Content
23+
): Content | undefined {
24+
// null or undefined
25+
if (input == null) {
26+
return undefined;
27+
} else if (typeof input === 'string') {
28+
return { role: 'system', parts: [{ text: input }] } as Content;
29+
} else if ((input as Part).text) {
30+
return { role: 'system', parts: [input as Part] };
31+
} else if ((input as Content).parts) {
32+
if (!(input as Content).role) {
33+
return { role: 'system', parts: (input as Content).parts };
34+
} else {
35+
return input as Content;
36+
}
37+
}
38+
}
39+
2140
export function formatNewContent(
2241
request: string | Array<string | Part>
2342
): Content {
@@ -84,10 +103,18 @@ function assignRoleToPartsAndValidateSendMessageRequest(
84103
export function formatGenerateContentInput(
85104
params: GenerateContentRequest | string | Array<string | Part>
86105
): GenerateContentRequest {
106+
let formattedRequest: GenerateContentRequest;
87107
if ((params as GenerateContentRequest).contents) {
88-
return params as GenerateContentRequest;
108+
formattedRequest = params as GenerateContentRequest;
89109
} else {
110+
// Array or string
90111
const content = formatNewContent(params as string | Array<string | Part>);
91-
return { contents: [content] };
112+
formattedRequest = { contents: [content] };
113+
}
114+
if ((params as GenerateContentRequest).systemInstruction) {
115+
formattedRequest.systemInstruction = formatSystemInstruction(
116+
(params as GenerateContentRequest).systemInstruction
117+
);
92118
}
119+
return formattedRequest;
93120
}

packages/vertexai/src/types/requests.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
import { Content } from './content';
18+
import { Content, Part } from './content';
1919
import {
2020
FunctionCallingMode,
2121
HarmBlockMethod,
@@ -40,7 +40,7 @@ export interface ModelParams extends BaseParams {
4040
model: string;
4141
tools?: Tool[];
4242
toolConfig?: ToolConfig;
43-
systemInstruction?: Content;
43+
systemInstruction?: string | Part | Content;
4444
}
4545

4646
/**
@@ -51,7 +51,7 @@ export interface GenerateContentRequest extends BaseParams {
5151
contents: Content[];
5252
tools?: Tool[];
5353
toolConfig?: ToolConfig;
54-
systemInstruction?: Content;
54+
systemInstruction?: string | Part | Content;
5555
}
5656

5757
/**
@@ -87,7 +87,7 @@ export interface StartChatParams extends BaseParams {
8787
history?: Content[];
8888
tools?: Tool[];
8989
toolConfig?: ToolConfig;
90-
systemInstruction?: Content;
90+
systemInstruction?: string | Part | Content;
9191
}
9292

9393
/**

0 commit comments

Comments
 (0)