Skip to content

Commit 7fa4dce

Browse files
feat: Add target_model to support multi-model endpoints
1 parent 46c68a9 commit 7fa4dce

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/sagemaker/clarify.py

+4
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def __init__(
185185
custom_attributes=None,
186186
accelerator_type=None,
187187
endpoint_name_prefix=None,
188+
target_model=None,
188189
):
189190
r"""Initializes a configuration of a model and the endpoint to be created for it.
190191
@@ -218,6 +219,8 @@ def __init__(
218219
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
219220
endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow
220221
pattern "^[a-zA-Z0-9](-\*[a-zA-Z0-9]".
222+
target_model (str): Sets the target model name when using a multi-model endpoint. For more information
223+
about multi-model endpoints
221224
"""
222225
self.predictor_config = {
223226
"model_name": model_name,
@@ -261,6 +264,7 @@ def __init__(
261264
self.predictor_config["content_template"] = content_template
262265
_set(custom_attributes, "custom_attributes", self.predictor_config)
263266
_set(accelerator_type, "accelerator_type", self.predictor_config)
267+
_set(target_model, "target_model", self.predictor_config)
264268

265269
def get_predictor_config(self):
266270
"""Returns part of the predictor dictionary of the analysis config."""

tests/unit/test_clarify.py

+2
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def test_model_config():
240240
accept_type = "text/csv"
241241
content_type = "application/jsonlines"
242242
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
243+
target_model = "target_model_name"
243244
accelerator_type = "ml.eia1.medium"
244245
model_config = ModelConfig(
245246
model_name=model_name,
@@ -258,6 +259,7 @@ def test_model_config():
258259
"content_type": content_type,
259260
"custom_attributes": custom_attributes,
260261
"accelerator_type": accelerator_type,
262+
"target_model_name": target_model,
261263
}
262264
assert expected_config == model_config.get_predictor_config()
263265

0 commit comments

Comments
 (0)