diff --git a/common/api-review/vertexai.api.md b/common/api-review/vertexai.api.md
index fc7d9182586..5c8ef330cbe 100644
--- a/common/api-review/vertexai.api.md
+++ b/common/api-review/vertexai.api.md
@@ -344,7 +344,7 @@ export class GenerativeModel extends VertexAIModel {
}
// @public
-export function getGenerativeModel(vertexAI: VertexAI, onCloudOrHybridParams: ModelParams | HybridParams, requestOptions?: RequestOptions): GenerativeModel;
+export function getGenerativeModel(vertexAI: VertexAI, modelParams: ModelParams | HybridParams, requestOptions?: RequestOptions): GenerativeModel;
// @beta
export function getImagenModel(vertexAI: VertexAI, modelParams: ImagenModelParams, requestOptions?: RequestOptions): ImagenModel;
diff --git a/docs-devsite/vertexai.md b/docs-devsite/vertexai.md
index a3ba28ad609..305d0f09b61 100644
--- a/docs-devsite/vertexai.md
+++ b/docs-devsite/vertexai.md
@@ -19,7 +19,7 @@ The Vertex AI in Firebase Web SDK.
| function(app, ...) |
| [getVertexAI(app, options)](./vertexai.md#getvertexai_04094cf) | Returns a [VertexAI](./vertexai.vertexai.md#vertexai_interface) instance for the given app. |
| function(vertexAI, ...) |
-| [getGenerativeModel(vertexAI, onCloudOrHybridParams, requestOptions)](./vertexai.md#getgenerativemodel_202434f) | Returns a [GenerativeModel](./vertexai.generativemodel.md#generativemodel_class) class with methods for inference and other functionality. |
+| [getGenerativeModel(vertexAI, modelParams, requestOptions)](./vertexai.md#getgenerativemodel_8dbc150) | Returns a [GenerativeModel](./vertexai.generativemodel.md#generativemodel_class) class with methods for inference and other functionality. |
| [getImagenModel(vertexAI, modelParams, requestOptions)](./vertexai.md#getimagenmodel_812c375) | (Public Preview) Returns an [ImagenModel](./vertexai.imagenmodel.md#imagenmodel_class) class with methods for using Imagen.Only Imagen 3 models (named imagen-3.0-*
) are supported. |
## Classes
@@ -101,10 +101,10 @@ The Vertex AI in Firebase Web SDK.
| [ImagenSafetySettings](./vertexai.imagensafetysettings.md#imagensafetysettings_interface) | (Public Preview) Settings for controlling the aggressiveness of filtering out sensitive content.See the [documentation](http://firebase.google.com/docs/vertex-ai/generate-images) for more details. |
| [InlineDataPart](./vertexai.inlinedatapart.md#inlinedatapart_interface) | Content part interface if the part represents an image. |
| [ModalityTokenCount](./vertexai.modalitytokencount.md#modalitytokencount_interface) | Represents token counting info for a single modality. |
-| [ModelParams](./vertexai.modelparams.md#modelparams_interface) | Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_202434f). |
+| [ModelParams](./vertexai.modelparams.md#modelparams_interface) | Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_8dbc150). |
| [ObjectSchemaInterface](./vertexai.objectschemainterface.md#objectschemainterface_interface) | Interface for [ObjectSchema](./vertexai.objectschema.md#objectschema_class) class. |
| [PromptFeedback](./vertexai.promptfeedback.md#promptfeedback_interface) | If the prompt was blocked, this will be populated with blockReason
and the relevant safetyRatings
. |
-| [RequestOptions](./vertexai.requestoptions.md#requestoptions_interface) | Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_202434f). |
+| [RequestOptions](./vertexai.requestoptions.md#requestoptions_interface) | Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_8dbc150). |
| [RetrievedContextAttribution](./vertexai.retrievedcontextattribution.md#retrievedcontextattribution_interface) | |
| [SafetyRating](./vertexai.safetyrating.md#safetyrating_interface) | A safety rating associated with a [GenerateContentCandidate](./vertexai.generatecontentcandidate.md#generatecontentcandidate_interface) |
| [SafetySetting](./vertexai.safetysetting.md#safetysetting_interface) | Safety setting that can be sent as part of request parameters. |
@@ -162,14 +162,14 @@ export declare function getVertexAI(app?: FirebaseApp, options?: VertexAIOptions
## function(vertexAI, ...)
-### getGenerativeModel(vertexAI, onCloudOrHybridParams, requestOptions) {:#getgenerativemodel_202434f}
+### getGenerativeModel(vertexAI, modelParams, requestOptions) {:#getgenerativemodel_8dbc150}
Returns a [GenerativeModel](./vertexai.generativemodel.md#generativemodel_class) class with methods for inference and other functionality.
Signature:
```typescript
-export declare function getGenerativeModel(vertexAI: VertexAI, onCloudOrHybridParams: ModelParams | HybridParams, requestOptions?: RequestOptions): GenerativeModel;
+export declare function getGenerativeModel(vertexAI: VertexAI, modelParams: ModelParams | HybridParams, requestOptions?: RequestOptions): GenerativeModel;
```
#### Parameters
@@ -177,7 +177,7 @@ export declare function getGenerativeModel(vertexAI: VertexAI, onCloudOrHybridPa
| Parameter | Type | Description |
| --- | --- | --- |
| vertexAI | [VertexAI](./vertexai.vertexai.md#vertexai_interface) | |
-| onCloudOrHybridParams | [ModelParams](./vertexai.modelparams.md#modelparams_interface) \| [HybridParams](./vertexai.hybridparams.md#hybridparams_interface) | |
+| modelParams | [ModelParams](./vertexai.modelparams.md#modelparams_interface) \| [HybridParams](./vertexai.hybridparams.md#hybridparams_interface) | |
| requestOptions | [RequestOptions](./vertexai.requestoptions.md#requestoptions_interface) | |
Returns:
diff --git a/docs-devsite/vertexai.modelparams.md b/docs-devsite/vertexai.modelparams.md
index 6645d498d8e..0776b198cf1 100644
--- a/docs-devsite/vertexai.modelparams.md
+++ b/docs-devsite/vertexai.modelparams.md
@@ -10,7 +10,7 @@ https://github.com/firebase/firebase-js-sdk
{% endcomment %}
# ModelParams interface
-Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_202434f).
+Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_8dbc150).
Signature:
diff --git a/docs-devsite/vertexai.requestoptions.md b/docs-devsite/vertexai.requestoptions.md
index 334ce7956d6..4e1ce2b86e3 100644
--- a/docs-devsite/vertexai.requestoptions.md
+++ b/docs-devsite/vertexai.requestoptions.md
@@ -10,7 +10,7 @@ https://github.com/firebase/firebase-js-sdk
{% endcomment %}
# RequestOptions interface
-Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_202434f).
+Params passed to [getGenerativeModel()](./vertexai.md#getgenerativemodel_8dbc150).
Signature:
diff --git a/packages/vertexai/src/api.test.ts b/packages/vertexai/src/api.test.ts
index 4a0b978d858..a38358f806f 100644
--- a/packages/vertexai/src/api.test.ts
+++ b/packages/vertexai/src/api.test.ts
@@ -14,7 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-import { ImagenModelParams, ModelParams, VertexAIErrorCode } from './types';
+import {
+ ImagenModelParams,
+ InferenceMode,
+ ModelParams,
+ VertexAIErrorCode
+} from './types';
import { VertexAIError } from './errors';
import { ImagenModel, getGenerativeModel, getImagenModel } from './api';
import { expect } from 'chai';
@@ -112,6 +117,13 @@ describe('Top level API', () => {
);
}
});
+ it('getGenerativeModel with HybridParams sets the model', () => {
+ const genModel = getGenerativeModel(fakeVertexAI, {
+ mode: InferenceMode.ONLY_ON_CLOUD,
+ onCloudParams: { model: 'my-model' }
+ });
+ expect(genModel.model).to.equal('publishers/google/models/my-model');
+ });
it('getImagenModel throws if no apiKey is provided', () => {
const fakeVertexNoApiKey = {
...fakeVertexAI,
diff --git a/packages/vertexai/src/api.ts b/packages/vertexai/src/api.ts
index 323cfd10e80..7f11dd80844 100644
--- a/packages/vertexai/src/api.ts
+++ b/packages/vertexai/src/api.ts
@@ -71,18 +71,18 @@ export function getVertexAI(
*/
export function getGenerativeModel(
vertexAI: VertexAI,
- onCloudOrHybridParams: ModelParams | HybridParams,
+ modelParams: ModelParams | HybridParams,
requestOptions?: RequestOptions
): GenerativeModel {
- // Disambiguates onCloudOrHybridParams input.
- const hybridParams = onCloudOrHybridParams as HybridParams;
+ // Uses the existence of HybridParams.mode to clarify the type of the modelParams input.
+ const hybridParams = modelParams as HybridParams;
let onCloudParams: ModelParams;
if (hybridParams.mode) {
onCloudParams = hybridParams.onCloudParams || {
model: 'gemini-2.0-flash-lite'
};
} else {
- onCloudParams = onCloudOrHybridParams as ModelParams;
+ onCloudParams = modelParams as ModelParams;
}
if (!onCloudParams.model) {