Skip to content

Commit b7d6a59

Browse files
committed
feature: add Predictor.update_endpoint()
1 parent d388519 commit b7d6a59

File tree

6 files changed

+288
-50
lines changed

6 files changed

+288
-50
lines changed

src/sagemaker/predictor.py

+101-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
2424
from sagemaker.model_monitor import DataCaptureConfig
25-
from sagemaker.session import Session
25+
from sagemaker.session import production_variant, Session
2626
from sagemaker.utils import name_from_base
2727

2828
from sagemaker.model_monitor.model_monitoring import (
@@ -157,6 +157,106 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
157157
args["Body"] = data
158158
return args
159159

160+
def update_endpoint(
    self,
    initial_instance_count=None,
    instance_type=None,
    accelerator_type=None,
    model_name=None,
    tags=None,
    kms_key=None,
    data_capture_config_dict=None,
    wait=True,
):
    """Update the existing endpoint with the provided attributes.

    This creates a new EndpointConfig in the process. If ``initial_instance_count``,
    ``instance_type``, ``accelerator_type``, or ``model_name`` is specified, then a new
    ``ProductionVariant`` configuration is created; values from the existing configuration
    are not preserved if any of those parameters are specified.

    Args:
        initial_instance_count (int): The initial number of instances to run in the endpoint.
            This is required if ``instance_type``, ``accelerator_type``, or ``model_name`` is
            specified. Otherwise, the values from the existing endpoint configuration's
            ``ProductionVariant``s are used.
        instance_type (str): The EC2 instance type to deploy the endpoint to.
            This is required if ``initial_instance_count``, ``accelerator_type``, or
            ``model_name`` is specified. Otherwise, the values from the existing endpoint
            configuration's ``ProductionVariant``s are used.
        accelerator_type (str): The type of Elastic Inference accelerator to attach to
            the endpoint, e.g. 'ml.eia1.medium'. If not specified, and
            ``initial_instance_count``, ``instance_type``, and ``model_name`` are also ``None``,
            the values from the existing endpoint configuration's ``ProductionVariant``s are
            used. Otherwise, no Elastic Inference accelerator is attached to the endpoint.
        model_name (str): The name of the model to be associated with the endpoint.
            This is required if ``initial_instance_count``, ``instance_type``, or
            ``accelerator_type`` is specified and if there is more than one model associated
            with the endpoint. Otherwise, the existing model for the endpoint is used.
        tags (list[dict[str, str]]): The list of tags to add to the endpoint
            config. If not specified, the tags of the existing endpoint configuration are used.
            If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
            they are not carried over to the new endpoint configuration.
        kms_key (str): The KMS key that is used to encrypt the data on the storage volume
            attached to the instance hosting the endpoint. If not specified,
            the KMS key of the existing endpoint configuration is used.
        data_capture_config_dict (dict): The endpoint data capture configuration
            for use with Amazon SageMaker Model Monitoring. If not specified,
            the data capture configuration of the existing endpoint configuration is used.
        wait (bool): Forwarded unchanged to ``sagemaker_session.update_endpoint``
            (default: True). Presumably controls whether the call blocks until the
            endpoint update finishes -- confirm against ``Session.update_endpoint``.

    Raises:
        ValueError: If there is not enough information to create a new ``ProductionVariant``:

            - If ``initial_instance_count``, ``instance_type``, ``accelerator_type``, or
              ``model_name`` is specified, but ``instance_type`` or
              ``initial_instance_count`` is ``None``.
            - If a new ``ProductionVariant`` is requested, ``model_name`` is ``None``, and
              the endpoint does not have exactly one model to fall back on.
    """
    production_variants = None
    # Only committed to ``self._model_names`` once the endpoint update call has returned,
    # so that a failed API call leaves this predictor's cached state untouched.
    new_model_names = None

    if initial_instance_count or instance_type or accelerator_type or model_name:
        if instance_type is None or initial_instance_count is None:
            raise ValueError(
                "Missing initial_instance_count and/or instance_type. Provided values: "
                "initial_instance_count={}, instance_type={}, accelerator_type={}, "
                "model_name={}.".format(
                    initial_instance_count, instance_type, accelerator_type, model_name
                )
            )

        if model_name is None:
            if not self._model_names:
                # Without this guard the fallback below fails with a bare IndexError.
                raise ValueError(
                    "Unable to choose a default model for a new EndpointConfig because "
                    "the endpoint has no known models."
                )
            if len(self._model_names) > 1:
                raise ValueError(
                    "Unable to choose a default model for a new EndpointConfig because "
                    "the endpoint has multiple models: {}".format(", ".join(self._model_names))
                )
            model_name = self._model_names[0]
        else:
            new_model_names = [model_name]

        production_variant_config = production_variant(
            model_name,
            instance_type,
            initial_instance_count=initial_instance_count,
            accelerator_type=accelerator_type,
        )
        production_variants = [production_variant_config]

    new_endpoint_config_name = name_from_base(self._endpoint_config_name)
    self.sagemaker_session.create_endpoint_config_from_existing(
        self._endpoint_config_name,
        new_endpoint_config_name,
        new_tags=tags,
        new_kms_key=kms_key,
        new_data_capture_config_dict=data_capture_config_dict,
        new_production_variants=production_variants,
    )
    self.sagemaker_session.update_endpoint(
        self.endpoint_name, new_endpoint_config_name, wait=wait
    )

    # Both calls returned: record the new configuration on the predictor.
    self._endpoint_config_name = new_endpoint_config_name
    if new_model_names is not None:
        self._model_names = new_model_names
259+
160260
def _delete_endpoint_config(self):
161261
"""Delete the Amazon SageMaker endpoint configuration"""
162262
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)

src/sagemaker/session.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -2333,6 +2333,7 @@ def create_endpoint_config_from_existing(
23332333
new_tags=None,
23342334
new_kms_key=None,
23352335
new_data_capture_config_dict=None,
2336+
new_production_variants=None,
23362337
):
23372338
"""Create an Amazon SageMaker endpoint configuration from an existing one. Updating any
23382339
values that were passed in.
@@ -2346,7 +2347,7 @@ def create_endpoint_config_from_existing(
23462347
new_config_name (str): Name of the Amazon SageMaker endpoint configuration to create.
23472348
existing_config_name (str): Name of the existing Amazon SageMaker endpoint
23482349
configuration.
2349-
new_tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint
2350+
new_tags (list[dict[str, str]]): Optional. The list of tags to add to the endpoint
23502351
config. If not specified, the tags of the existing endpoint configuration are used.
23512352
If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
23522353
they are not carried over to the new endpoint configuration.
@@ -2357,6 +2358,9 @@ def create_endpoint_config_from_existing(
23572358
capture for use with Amazon SageMaker Model Monitoring (default: None).
23582359
If not specified, the data capture configuration of the existing
23592360
endpoint configuration is used.
2361+
new_production_variants (list[dict]): The configuration for which model(s) to host and
2362+
the resources to deploy for hosting the model(s). If not specified,
2363+
the ``ProductionVariants`` of the existing endpoint configuration are used.
23602364
23612365
Returns:
23622366
str: Name of the endpoint configuration created.
@@ -2370,9 +2374,12 @@ def create_endpoint_config_from_existing(
23702374

23712375
request = {
23722376
"EndpointConfigName": new_config_name,
2373-
"ProductionVariants": existing_endpoint_config_desc["ProductionVariants"],
23742377
}
23752378

2379+
request["ProductionVariants"] = (
2380+
new_production_variants or existing_endpoint_config_desc["ProductionVariants"]
2381+
)
2382+
23762383
request_tags = new_tags or self.list_tags(
23772384
existing_endpoint_config_desc["EndpointConfigArn"]
23782385
)

tests/integ/test_inference_pipeline.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from sagemaker.predictor import Predictor, json_serializer
3030
from sagemaker.sparkml.model import SparkMLModel
3131
from sagemaker.utils import sagemaker_timestamp
32-
from tests.integ.retry import retries
3332

3433
SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
3534
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
@@ -151,7 +150,7 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
151150
assert "Could not find model" in str(exception.value)
152151

153152

154-
def test_inference_pipeline_model_deploy_with_update_endpoint(
153+
def test_inference_pipeline_model_deploy_and_update_endpoint(
155154
sagemaker_session, cpu_instance_type, alternative_cpu_instance_type
156155
):
157156
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
@@ -181,24 +180,18 @@ def test_inference_pipeline_model_deploy_with_update_endpoint(
181180
role="SageMakerRole",
182181
sagemaker_session=sagemaker_session,
183182
)
184-
model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
185-
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
183+
predictor = model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
184+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
186185
EndpointName=endpoint_name
187186
)
188-
old_config_name = old_endpoint["EndpointConfigName"]
187+
old_config_name = endpoint_desc["EndpointConfigName"]
189188

190-
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
189+
predictor.update_endpoint(initial_instance_count=1, instance_type=cpu_instance_type)
191190

192-
# Wait for endpoint to finish updating
193-
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
194-
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
195-
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
196-
EndpointName=endpoint_name
197-
)
198-
if new_endpoint["EndpointStatus"] == "InService":
199-
break
200-
201-
new_config_name = new_endpoint["EndpointConfigName"]
191+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
192+
EndpointName=endpoint_name
193+
)
194+
new_config_name = endpoint_desc["EndpointConfigName"]
202195
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
203196
EndpointConfigName=new_config_name
204197
)

tests/integ/test_multidatamodel.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -496,25 +496,20 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
496496
result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_2)
497497
assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_2)
498498

499-
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
499+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
500500
EndpointName=endpoint_name
501501
)
502-
old_config_name = old_endpoint["EndpointConfigName"]
502+
old_config_name = endpoint_desc["EndpointConfigName"]
503503

504504
# Update endpoint
505-
multi_data_model.deploy(
506-
1, alternative_cpu_instance_type, endpoint_name=endpoint_name, update_endpoint=True
505+
predictor.update_endpoint(
506+
initial_instance_count=1, instance_type=alternative_cpu_instance_type
507507
)
508508

509-
# Wait for endpoint to finish updating
510-
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
511-
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
512-
EndpointName=endpoint_name
513-
)
514-
if new_endpoint["EndpointStatus"] == "InService":
515-
break
516-
517-
new_config_name = new_endpoint["EndpointConfigName"]
509+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
510+
EndpointName=endpoint_name
511+
)
512+
new_config_name = endpoint_desc["EndpointConfigName"]
518513

519514
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
520515
EndpointConfigName=new_config_name
@@ -531,6 +526,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
531526
EndpointConfigName=new_config_name
532527
)
533528
multi_data_model.delete_model()
529+
534530
with pytest.raises(Exception) as exception:
535531
sagemaker_session.sagemaker_client.describe_model(ModelName=model_name)
536532
assert "Could not find model" in str(exception.value)

tests/integ/test_mxnet_train.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from sagemaker.utils import sagemaker_timestamp
2525
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2626
from tests.integ.kms_utils import get_or_create_kms_key
27-
from tests.integ.retry import retries
2827
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2928

3029

@@ -203,7 +202,7 @@ def test_deploy_model_with_tags_and_kms(
203202
assert endpoint_config["KmsKeyId"] == kms_key_arn
204203

205204

206-
def test_deploy_model_with_update_endpoint(
205+
def test_deploy_model_and_update_endpoint(
207206
mxnet_training_job,
208207
sagemaker_session,
209208
mxnet_full_version,
@@ -227,24 +226,18 @@ def test_deploy_model_with_update_endpoint(
227226
sagemaker_session=sagemaker_session,
228227
framework_version=mxnet_full_version,
229228
)
230-
model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
231-
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
229+
predictor = model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
230+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
232231
EndpointName=endpoint_name
233232
)
234-
old_config_name = old_endpoint["EndpointConfigName"]
235-
236-
model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
233+
old_config_name = endpoint_desc["EndpointConfigName"]
237234

238-
# Wait for endpoint to finish updating
239-
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
240-
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
241-
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
242-
EndpointName=endpoint_name
243-
)
244-
if new_endpoint["EndpointStatus"] == "InService":
245-
break
235+
predictor.update_endpoint(initial_instance_count=1, instance_type=cpu_instance_type)
246236

247-
new_config_name = new_endpoint["EndpointConfigName"]
237+
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
238+
EndpointName=endpoint_name
239+
)
240+
new_config_name = endpoint_desc["EndpointConfigName"]
248241
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
249242
EndpointConfigName=new_config_name
250243
)

0 commit comments

Comments
 (0)