From 4dfd31c276af462a03cc6db0ad1a8dbb5fddb899 Mon Sep 17 00:00:00 2001 From: Keerthan Vasist Date: Wed, 6 Jul 2022 14:10:20 -0700 Subject: [PATCH] feat: Add target_model to support multi-model endpoints --- src/sagemaker/clarify.py | 5 +++++ tests/unit/test_clarify.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index eaf78069c3..24fe1f0a48 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -200,6 +200,7 @@ def __init__( custom_attributes=None, accelerator_type=None, endpoint_name_prefix=None, + target_model=None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. @@ -234,6 +235,9 @@ def __init__( for making inferences to the model. endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow pattern ``^[a-zA-Z0-9](-\*[a-zA-Z0-9]``. + target_model (str): Sets the target model name when using a multi-model endpoint. For + more information about multi-model endpoints, see + https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html Raises: ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid, @@ -281,6 +285,7 @@ def __init__( self.predictor_config["content_template"] = content_template _set(custom_attributes, "custom_attributes", self.predictor_config) _set(accelerator_type, "accelerator_type", self.predictor_config) + _set(target_model, "target_model", self.predictor_config) def get_predictor_config(self): """Returns part of the predictor dictionary of the analysis config.""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 0a1d90d74c..1e3ae47f63 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -240,6 +240,7 @@ def test_model_config(): accept_type = "text/csv" content_type = "application/jsonlines" custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4" + target_model = "target_model_name" accelerator_type = "ml.eia1.medium" model_config = ModelConfig( model_name=model_name, @@ -249,6 +250,7 @@ def test_model_config(): content_type=content_type, custom_attributes=custom_attributes, accelerator_type=accelerator_type, + target_model=target_model, ) expected_config = { "model_name": model_name, @@ -258,6 +260,7 @@ def test_model_config(): "content_type": content_type, "custom_attributes": custom_attributes, "accelerator_type": accelerator_type, + "target_model": target_model, } assert expected_config == model_config.get_predictor_config()