Skip to content

Commit 0338896

Browse files
authored
Fix types to match Vertex API and convert functionCall() to functionCalls() (#284)
1 parent 57018ca commit 0338896

20 files changed

+197
-230
lines changed

packages/vertexai/package.json

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,10 @@
3636
"build": "rollup -c",
3737
"build:deps": "lerna run --scope @firebase/vertexai --include-dependencies build",
3838
"dev": "rollup -c -w",
39-
"pretest": "yarn ts-node ./test-utils/convert-mocks.ts",
40-
"test": "run-p --npm-path npm lint test:all",
41-
"test:ci": "node ../../scripts/run_tests_in_ci.js -s test:all",
42-
"test:all": "run-p --npm-path npm test:browser test:node",
43-
"test:browser": "karma start --single-run",
44-
"test:node": "TS_NODE_COMPILER_OPTIONS='{\"module\":\"commonjs\"}' nyc --reporter lcovonly -- mocha src/**/*.test.* --config ../../config/mocharc.node.js"
39+
"testsetup": "yarn ts-node ./test-utils/convert-mocks.ts",
40+
"test": "run-p --npm-path npm lint test:browser",
41+
"test:ci": "yarn testsetup && node ../../scripts/run_tests_in_ci.js -s test",
42+
"test:browser": "yarn testsetup && karma start --single-run"
4543
},
4644
"peerDependencies": {
4745
"@firebase/app": "0.x",

packages/vertexai/rollup.config.js

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ const browserBuilds = [
5252
output: [{ file: pkg.esm5, format: 'es', sourcemap: true }],
5353
plugins: [
5454
...es5BuildPlugins,
55-
replace(generateBuildTargetReplaceConfig('esm', 5))
55+
replace({
56+
...generateBuildTargetReplaceConfig('esm', 5),
57+
__PACKAGE_VERSION__: pkg.version
58+
})
5659
],
5760
external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`))
5861
},
@@ -65,7 +68,10 @@ const browserBuilds = [
6568
},
6669
plugins: [
6770
...es2017BuildPlugins,
68-
replace(generateBuildTargetReplaceConfig('esm', 2017)),
71+
replace({
72+
...generateBuildTargetReplaceConfig('esm', 2017),
73+
__PACKAGE_VERSION__: pkg.version
74+
}),
6975
emitModulePackageFile()
7076
],
7177
external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`))
@@ -79,7 +85,10 @@ const browserBuilds = [
7985
},
8086
plugins: [
8187
...es2017BuildPlugins,
82-
replace(generateBuildTargetReplaceConfig('cjs', 2017))
88+
replace({
89+
...generateBuildTargetReplaceConfig('cjs', 2017),
90+
__PACKAGE_VERSION__: pkg.version
91+
})
8392
],
8493
external: id => deps.some(dep => id === dep || id.startsWith(`${dep}/`))
8594
}

packages/vertexai/src/methods/count-tokens.ts

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,16 @@ export async function countTokens(
2929
params: CountTokensRequest,
3030
requestOptions?: RequestOptions
3131
): Promise<CountTokensResponse> {
32-
const url = new RequestUrl(model, Task.COUNT_TOKENS, apiSettings, false, {});
32+
const url = new RequestUrl(
33+
model,
34+
Task.COUNT_TOKENS,
35+
apiSettings,
36+
false,
37+
requestOptions
38+
);
3339
const response = await makeRequest(
3440
url,
35-
JSON.stringify({ ...params, model }),
41+
JSON.stringify(params),
3642
requestOptions
3743
);
3844
return response.json();

packages/vertexai/src/methods/embed-content.ts

Lines changed: 0 additions & 67 deletions
This file was deleted.

packages/vertexai/src/methods/generate-content.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ describe('generateContent()', () => {
104104
);
105105
expect(result.response.text()).to.include('Quantum mechanics is');
106106
expect(
107-
result.response.candidates?.[0].citationMetadata?.citationSources.length
107+
result.response.candidates?.[0].citationMetadata?.citations.length
108108
).to.equal(1);
109109
expect(makeRequestStub).to.be.calledWith(
110110
match.instanceOf(request.RequestUrl),

packages/vertexai/src/models/generative-model.ts

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@ import {
2020
generateContentStream
2121
} from '../methods/generate-content';
2222
import {
23-
BatchEmbedContentsRequest,
24-
BatchEmbedContentsResponse,
2523
CountTokensRequest,
2624
CountTokensResponse,
27-
EmbedContentRequest,
28-
EmbedContentResponse,
2925
GenerateContentRequest,
3026
GenerateContentResult,
3127
GenerateContentStreamResult,
@@ -39,11 +35,7 @@ import {
3935
} from '../types';
4036
import { ChatSession } from '../methods/chat-session';
4137
import { countTokens } from '../methods/count-tokens';
42-
import { batchEmbedContents, embedContent } from '../methods/embed-content';
43-
import {
44-
formatEmbedContentInput,
45-
formatGenerateContentInput
46-
} from '../requests/request-helpers';
38+
import { formatGenerateContentInput } from '../requests/request-helpers';
4739
import { Vertex } from '../public-types';
4840
import { ERROR_FACTORY, VertexError } from '../errors';
4941
import { ApiSettings } from '../types/internal';
@@ -163,28 +155,4 @@ export class GenerativeModel {
163155
const formattedParams = formatGenerateContentInput(request);
164156
return countTokens(this._apiSettings, this.model, formattedParams);
165157
}
166-
167-
/**
168-
* Embeds the provided content.
169-
*/
170-
async embedContent(
171-
request: EmbedContentRequest | string | Array<string | Part>
172-
): Promise<EmbedContentResponse> {
173-
const formattedParams = formatEmbedContentInput(request);
174-
return embedContent(this._apiSettings, this.model, formattedParams);
175-
}
176-
177-
/**
178-
* Embeds an array of {@link EmbedContentRequest}s.
179-
*/
180-
async batchEmbedContents(
181-
batchEmbedContentRequest: BatchEmbedContentsRequest
182-
): Promise<BatchEmbedContentsResponse> {
183-
return batchEmbedContents(
184-
this._apiSettings,
185-
this.model,
186-
batchEmbedContentRequest,
187-
this.requestOptions
188-
);
189-
}
190158
}

packages/vertexai/src/requests/request-helpers.ts

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
import {
19-
Content,
20-
EmbedContentRequest,
21-
GenerateContentRequest,
22-
Part
23-
} from '../types';
18+
import { Content, GenerateContentRequest, Part } from '../types';
2419
import { ERROR_FACTORY, VertexError } from '../errors';
2520

2621
export function formatNewContent(
@@ -96,13 +91,3 @@ export function formatGenerateContentInput(
9691
return { contents: [content] };
9792
}
9893
}
99-
100-
export function formatEmbedContentInput(
101-
params: EmbedContentRequest | string | Array<string | Part>
102-
): EmbedContentRequest {
103-
if (typeof params === 'string' || Array.isArray(params)) {
104-
const content = formatNewContent(params);
105-
return { content };
106-
}
107-
return params;
108-
}

packages/vertexai/src/requests/request.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ describe('request methods', () => {
9090
'/v100omega/projects/my-project/locations/us-central1/models/model-name'
9191
);
9292
});
93+
it('custom baseUrl', async () => {
94+
const url = new RequestUrl(
95+
'models/model-name',
96+
Task.GENERATE_CONTENT,
97+
fakeApiSettings,
98+
false,
99+
{ baseUrl: 'https://my.special.endpoint' }
100+
);
101+
expect(url.toString()).to.include('https://my.special.endpoint');
102+
});
93103
it('non-stream - tunedModels/', async () => {
94104
const url = new RequestUrl(
95105
'tunedModels/model-name',

packages/vertexai/src/requests/request.ts

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,19 @@
1818
import { RequestOptions } from '../types';
1919
import { ERROR_FACTORY, VertexError } from '../errors';
2020
import { ApiSettings } from '../types/internal';
21+
import { version } from '../../package.json';
2122

22-
const BASE_URL = 'https://staging-firebaseml.sandbox.googleapis.com';
23+
const DEFAULT_BASE_URL = 'https://firebaseml.googleapis.com';
2324

2425
export const DEFAULT_API_VERSION = 'v2beta';
2526

26-
/**
27-
* We can't `require` package.json if this runs on web. We will use rollup to
28-
* swap in the version number here at build time.
29-
*/
30-
const PACKAGE_VERSION = '__PACKAGE_VERSION__';
27+
const PACKAGE_VERSION = version;
3128
const PACKAGE_LOG_HEADER = 'firebase-vertexai-js';
3229

3330
export enum Task {
3431
GENERATE_CONTENT = 'generateContent',
3532
STREAM_GENERATE_CONTENT = 'streamGenerateContent',
36-
COUNT_TOKENS = 'countTokens',
37-
EMBED_CONTENT = 'embedContent',
38-
BATCH_EMBED_CONTENTS = 'batchEmbedContents'
33+
COUNT_TOKENS = 'countTokens'
3934
}
4035

4136
export class RequestUrl {
@@ -48,7 +43,8 @@ export class RequestUrl {
4843
) {}
4944
toString(): string {
5045
const apiVersion = this.requestOptions?.apiVersion || DEFAULT_API_VERSION;
51-
let url = `${BASE_URL}/${apiVersion}`;
46+
const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
47+
let url = `${baseUrl}/${apiVersion}`;
5248
url += `/projects/${this.apiSettings.project}`;
5349
url += `/locations/${this.apiSettings.location}`;
5450
url += `/${this.model}`;
@@ -58,6 +54,17 @@ export class RequestUrl {
5854
}
5955
return url;
6056
}
57+
58+
/**
59+
* If the model needs to be passed to the backend, it needs to
60+
* include project and location path.
61+
*/
62+
get fullModelString(): string {
63+
let modelString = `projects/${this.apiSettings.project}`;
64+
modelString += `/locations/${this.apiSettings.location}`;
65+
modelString += `/${this.model}`;
66+
return modelString;
67+
}
6168
}
6269

6370
/**

packages/vertexai/src/requests/response-helpers.test.ts

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ const fakeResponseText: GenerateContentResponse = {
3939
}
4040
]
4141
};
42+
4243
const fakeResponseFunctionCall: GenerateContentResponse = {
4344
candidates: [
4445
{
@@ -61,6 +62,38 @@ const fakeResponseFunctionCall: GenerateContentResponse = {
6162
]
6263
};
6364

65+
const fakeResponseFunctionCalls: GenerateContentResponse = {
66+
candidates: [
67+
{
68+
index: 0,
69+
content: {
70+
role: 'model',
71+
parts: [
72+
{
73+
functionCall: {
74+
name: 'find_theaters',
75+
args: {
76+
location: 'Mountain View, CA',
77+
movie: 'Barbie'
78+
}
79+
}
80+
},
81+
{
82+
functionCall: {
83+
name: 'find_times',
84+
args: {
85+
location: 'Mountain View, CA',
86+
movie: 'Barbie',
87+
time: '20:00'
88+
}
89+
}
90+
}
91+
]
92+
}
93+
}
94+
]
95+
};
96+
6497
const badFakeResponse: GenerateContentResponse = {
6598
promptFeedback: {
6699
blockReason: BlockReason.SAFETY,
@@ -79,9 +112,16 @@ describe('response-helpers methods', () => {
79112
});
80113
it('good response functionCall', async () => {
81114
const enhancedResponse = addHelpers(fakeResponseFunctionCall);
82-
expect(enhancedResponse.functionCall()).to.deep.equal(
115+
expect(enhancedResponse.functionCalls()).to.deep.equal([
83116
fakeResponseFunctionCall.candidates?.[0].content.parts[0].functionCall
84-
);
117+
]);
118+
});
119+
it('good response functionCalls', async () => {
120+
const enhancedResponse = addHelpers(fakeResponseFunctionCalls);
121+
expect(enhancedResponse.functionCalls()).to.deep.equal([
122+
fakeResponseFunctionCalls.candidates?.[0].content.parts[0].functionCall,
123+
fakeResponseFunctionCalls.candidates?.[0].content.parts[1].functionCall
124+
]);
85125
});
86126
it('bad response safety', async () => {
87127
const enhancedResponse = addHelpers(badFakeResponse);

0 commit comments

Comments
 (0)