diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 9188acd437..040c6dd71f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1377,6 +1377,7 @@ def deploy( managed_instance_scaling: Optional[str] = None, inference_component_name=None, routing_config: Optional[Dict[str, Any]] = None, + model_reference_arn: Optional[str] = None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1483,6 +1484,8 @@ def deploy( { "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM } + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + content (default: None). Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1697,7 +1700,8 @@ def deploy( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, - **kwargs, + accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) serverless_inference_config_dict = ( serverless_inference_config._to_request_dict() if is_serverless else None diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index baf9d19a54..5b0c4cf7f9 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -771,7 +771,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["model_reference_arn"]) - deploy_args_to_skip: Set[str] = set(["kwargs"]) + deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"]) parent_class_init = Model.__init__ parent_class_init_args = set(signature(parent_class_init).parameters.keys()) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 50f6c370d5..6bfb28f684 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -179,6 +179,8 @@ def test_deploy_accelerator_type( accelerator_type=ACCELERATOR_TYPE, tags=None, serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -299,6 +301,8 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, accelerator_type=None, tags=tags, serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.endpoint_from_production_variants.assert_called_with( name=ENDPOINT_NAME, @@ -502,6 +506,8 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, accelerator_type=None, tags=None, serverless_inference_config=serverless_inference_config, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME,