Skip to content

Commit 85ee540

Browse files
authored
Merge a49dc31 into b981a45
2 parents b981a45 + a49dc31 commit 85ee540

File tree

6 files changed

+72
-17
lines changed

6 files changed

+72
-17
lines changed

packages/vertexai/src/api.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import { Provider } from '@firebase/component';
2020
import { getModularInstance } from '@firebase/util';
2121
import { DEFAULT_LOCATION, VERTEX_TYPE } from './constants';
2222
import { VertexAIService } from './service';
23-
import { VertexAI } from './public-types';
23+
import { VertexAI, VertexAIOptions } from './public-types';
2424
import { ERROR_FACTORY, VertexError } from './errors';
2525
import { ModelParams, RequestOptions } from './types';
2626
import { GenerativeModel } from './models/generative-model';
@@ -42,13 +42,16 @@ declare module '@firebase/component' {
4242
*
4343
* @param app - The {@link @firebase/app#FirebaseApp} to use.
4444
*/
45-
export function getVertexAI(app: FirebaseApp = getApp()): VertexAI {
45+
export function getVertexAI(
46+
app: FirebaseApp = getApp(),
47+
options?: VertexAIOptions
48+
): VertexAI {
4649
app = getModularInstance(app);
4750
// Dependencies
4851
const vertexProvider: Provider<'vertex'> = _getProvider(app, VERTEX_TYPE);
4952

5053
return vertexProvider.getImmediate({
51-
identifier: DEFAULT_LOCATION
54+
identifier: options?.location || DEFAULT_LOCATION
5255
});
5356
}
5457

packages/vertexai/src/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ function registerVertex(): void {
3737
_registerComponent(
3838
new Component(
3939
VERTEX_TYPE,
40-
container => {
40+
(container, { instanceIdentifier: location }) => {
4141
// getImmediate for FirebaseApp will always succeed
4242
const app = container.getProvider('app').getImmediate();
4343
const appCheckProvider = container.getProvider('app-check-internal');
44-
return new VertexAIService(app, appCheckProvider);
44+
return new VertexAIService(app, appCheckProvider, { location });
4545
},
4646
ComponentType.PUBLIC
4747
).setMultipleInstances(true)

packages/vertexai/src/public-types.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ export interface VertexAI {
3030
app: FirebaseApp;
3131
location: string;
3232
}
33+
34+
/**
35+
* Options when initializing the Firebase Vertex AI SDK.
36+
* @public
37+
*/
38+
export interface VertexAIOptions {
39+
location?: string;
40+
}

packages/vertexai/src/requests/stream-reader.test.ts

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ import {
3232
FinishReason,
3333
GenerateContentResponse,
3434
HarmCategory,
35-
HarmProbability
35+
HarmProbability,
36+
SafetyRating
3637
} from '../types';
3738

3839
use(sinonChai);
@@ -229,7 +230,7 @@ describe('aggregateResponses', () => {
229230
{
230231
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
231232
probability: HarmProbability.LOW
232-
}
233+
} as SafetyRating
233234
]
234235
}
235236
}
@@ -256,7 +257,7 @@ describe('aggregateResponses', () => {
256257
{
257258
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
258259
probability: HarmProbability.NEGLIGIBLE
259-
}
260+
} as SafetyRating
260261
]
261262
}
262263
],
@@ -266,7 +267,7 @@ describe('aggregateResponses', () => {
266267
{
267268
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
268269
probability: HarmProbability.LOW
269-
}
270+
} as SafetyRating
270271
]
271272
}
272273
},
@@ -284,7 +285,7 @@ describe('aggregateResponses', () => {
284285
{
285286
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
286287
probability: HarmProbability.NEGLIGIBLE
287-
}
288+
} as SafetyRating
288289
],
289290
citationMetadata: {
290291
citations: [
@@ -304,7 +305,7 @@ describe('aggregateResponses', () => {
304305
{
305306
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
306307
probability: HarmProbability.HIGH
307-
}
308+
} as SafetyRating
308309
]
309310
}
310311
},
@@ -322,7 +323,7 @@ describe('aggregateResponses', () => {
322323
{
323324
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
324325
probability: HarmProbability.MEDIUM
325-
}
326+
} as SafetyRating
326327
],
327328
citationMetadata: {
328329
citations: [
@@ -348,7 +349,7 @@ describe('aggregateResponses', () => {
348349
{
349350
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
350351
probability: HarmProbability.HIGH
351-
}
352+
} as SafetyRating
352353
]
353354
}
354355
}

packages/vertexai/src/service.test.ts

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
import { DEFAULT_LOCATION } from './constants';
18+
import { VertexAIService } from './service';
19+
import { expect } from 'chai';
20+
21+
const fakeApp = {
22+
name: 'DEFAULT',
23+
automaticDataCollectionEnabled: true,
24+
options: {
25+
apiKey: 'key',
26+
projectId: 'my-project'
27+
}
28+
};
29+
30+
describe('VertexAIService', () => {
31+
it('uses default location if not specified', () => {
32+
const vertexAI = new VertexAIService(fakeApp);
33+
expect(vertexAI.location).to.equal(DEFAULT_LOCATION);
34+
});
35+
it('uses custom location if specified', () => {
36+
const vertexAI = new VertexAIService(
37+
fakeApp,
38+
/* appCheckProvider */ undefined,
39+
{ location: 'somewhere' }
40+
);
41+
expect(vertexAI.location).to.equal('somewhere');
42+
});
43+
});

packages/vertexai/src/service.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import { FirebaseApp, _FirebaseService } from '@firebase/app';
19-
import { VertexAI } from './public-types';
19+
import { VertexAI, VertexAIOptions } from './public-types';
2020
import {
2121
AppCheckInternalComponentName,
2222
FirebaseAppCheckInternal
@@ -30,12 +30,12 @@ export class VertexAIService implements VertexAI, _FirebaseService {
3030

3131
constructor(
3232
public app: FirebaseApp,
33-
appCheckProvider?: Provider<AppCheckInternalComponentName>
33+
appCheckProvider?: Provider<AppCheckInternalComponentName>,
34+
public options?: VertexAIOptions
3435
) {
3536
const appCheck = appCheckProvider?.getImmediate({ optional: true });
3637
this.appCheck = appCheck || null;
37-
// TODO: add in user-set location option when that feature is available
38-
this.location = DEFAULT_LOCATION;
38+
this.location = this.options?.location || DEFAULT_LOCATION;
3939
}
4040

4141
_delete(): Promise<void> {

0 commit comments

Comments
 (0)