Skip to content

Commit a5464a2

Browse files
feat: Add target_model to support multi-model endpoints (#3215)
1 parent 8aa8819 commit a5464a2

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/sagemaker/clarify.py

+5
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def __init__(
200200
custom_attributes=None,
201201
accelerator_type=None,
202202
endpoint_name_prefix=None,
203+
target_model=None,
203204
):
204205
r"""Initializes a configuration of a model and the endpoint to be created for it.
205206
@@ -234,6 +235,9 @@ def __init__(
234235
for making inferences to the model.
235236
endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow
236237
pattern ``^[a-zA-Z0-9](-*[a-zA-Z0-9])*``.
238+
target_model (str): Sets the target model name when using a multi-model endpoint. For
239+
more information about multi-model endpoints, see
240+
https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html
237241
238242
Raises:
239243
ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid,
@@ -281,6 +285,7 @@ def __init__(
281285
self.predictor_config["content_template"] = content_template
282286
_set(custom_attributes, "custom_attributes", self.predictor_config)
283287
_set(accelerator_type, "accelerator_type", self.predictor_config)
288+
_set(target_model, "target_model", self.predictor_config)
284289

285290
def get_predictor_config(self):
286291
"""Returns part of the predictor dictionary of the analysis config."""

tests/unit/test_clarify.py

+3
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def test_model_config():
240240
accept_type = "text/csv"
241241
content_type = "application/jsonlines"
242242
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
243+
target_model = "target_model_name"
243244
accelerator_type = "ml.eia1.medium"
244245
model_config = ModelConfig(
245246
model_name=model_name,
@@ -249,6 +250,7 @@ def test_model_config():
249250
content_type=content_type,
250251
custom_attributes=custom_attributes,
251252
accelerator_type=accelerator_type,
253+
target_model=target_model,
252254
)
253255
expected_config = {
254256
"model_name": model_name,
@@ -258,6 +260,7 @@ def test_model_config():
258260
"content_type": content_type,
259261
"custom_attributes": custom_attributes,
260262
"accelerator_type": accelerator_type,
263+
"target_model": target_model,
261264
}
262265
assert expected_config == model_config.get_predictor_config()
263266

0 commit comments

Comments
 (0)