Skip to content

feature: add Predictor.update_endpoint() #1656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.session import Session
from sagemaker.session import production_variant, Session
from sagemaker.utils import name_from_base

from sagemaker.model_monitor.model_monitoring import (
Expand Down Expand Up @@ -157,6 +157,106 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
args["Body"] = data
return args

def update_endpoint(
self,
initial_instance_count=None,
instance_type=None,
accelerator_type=None,
model_name=None,
tags=None,
kms_key=None,
data_capture_config_dict=None,
wait=True,
):
"""Update the existing endpoint with the provided attributes.

This creates a new EndpointConfig in the process. If ``initial_instance_count``,
``instance_type``, ``accelerator_type``, or ``model_name`` is specified, then a new
ProductionVariant configuration is created; values from the existing configuration
are not preserved if any of those parameters are specified.

Args:
initial_instance_count (int): The initial number of instances to run in the endpoint.
This is required if ``instance_type``, ``accelerator_type``, or ``model_name`` is
specified. Otherwise, the values from the existing endpoint configuration's
ProductionVariants are used.
instance_type (str): The EC2 instance type to deploy the endpoint to.
This is required if ``initial_instance_count`` or ``accelerator_type`` is specified.
Otherwise, the values from the existing endpoint configuration's
``ProductionVariants`` are used.
accelerator_type (str): The type of Elastic Inference accelerator to attach to
the endpoint, e.g. "ml.eia1.medium". If not specified, and
``initial_instance_count``, ``instance_type``, and ``model_name`` are also ``None``,
the values from the existing endpoint configuration's ``ProductionVariants`` are
used. Otherwise, no Elastic Inference accelerator is attached to the endpoint.
model_name (str): The name of the model to be associated with the endpoint.
This is required if ``initial_instance_count``, ``instance_type``, or
``accelerator_type`` is specified and if there is more than one model associated
with the endpoint. Otherwise, the existing model for the endpoint is used.
tags (list[dict[str, str]]): The list of tags to add to the endpoint
config. If not specified, the tags of the existing endpoint configuration are used.
If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
they are not carried over to the new endpoint configuration.
kms_key (str): The KMS key that is used to encrypt the data on the storage volume
attached to the instance hosting the endpoint If not specified,
the KMS key of the existing endpoint configuration is used.
data_capture_config_dict (dict): The endpoint data capture configuration
for use with Amazon SageMaker Model Monitoring. If not specified,
the data capture configuration of the existing endpoint configuration is used.

Raises:
ValueError: If there is not enough information to create a new ``ProductionVariant``:

- If ``initial_instance_count``, ``accelerator_type``, or ``model_name`` is
specified, but ``instance_type`` is ``None``.
- If ``initial_instance_count``, ``instance_type``, or ``accelerator_type`` is
specified and either ``model_name`` is ``None`` or there are multiple models
associated with the endpoint.
"""
production_variants = None

if initial_instance_count or instance_type or accelerator_type or model_name:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m just curious about the None check here. Why are we checking the args and raise ValueError instead of making them required variables?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my thinking was that if someone wanted to update only the tags, KMS key, or data capture config, they shouldn't need to provide all the other info too.

if instance_type is None or initial_instance_count is None:
raise ValueError(
"Missing initial_instance_count and/or instance_type. Provided values: "
"initial_instance_count={}, instance_type={}, accelerator_type={}, "
"model_name={}.".format(
initial_instance_count, instance_type, accelerator_type, model_name
)
)

if model_name is None:
if len(self._model_names) > 1:
raise ValueError(
"Unable to choose a default model for a new EndpointConfig because "
"the endpoint has multiple models: {}".format(", ".join(self._model_names))
)
model_name = self._model_names[0]
else:
self._model_names = [model_name]

production_variant_config = production_variant(
model_name,
instance_type,
initial_instance_count=initial_instance_count,
accelerator_type=accelerator_type,
)
production_variants = [production_variant_config]

new_endpoint_config_name = name_from_base(self._endpoint_config_name)
self.sagemaker_session.create_endpoint_config_from_existing(
self._endpoint_config_name,
new_endpoint_config_name,
new_tags=tags,
new_kms_key=kms_key,
new_data_capture_config_dict=data_capture_config_dict,
new_production_variants=production_variants,
)
self.sagemaker_session.update_endpoint(
self.endpoint_name, new_endpoint_config_name, wait=wait
)
self._endpoint_config_name = new_endpoint_config_name

def _delete_endpoint_config(self):
"""Delete the Amazon SageMaker endpoint configuration"""
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
Expand Down
11 changes: 9 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,7 @@ def create_endpoint_config_from_existing(
new_tags=None,
new_kms_key=None,
new_data_capture_config_dict=None,
new_production_variants=None,
):
"""Create an Amazon SageMaker endpoint configuration from an existing one. Updating any
values that were passed in.
Expand All @@ -2346,7 +2347,7 @@ def create_endpoint_config_from_existing(
new_config_name (str): Name of the Amazon SageMaker endpoint configuration to create.
existing_config_name (str): Name of the existing Amazon SageMaker endpoint
configuration.
new_tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint
new_tags (list[dict[str, str]]): Optional. The list of tags to add to the endpoint
config. If not specified, the tags of the existing endpoint configuration are used.
If any of the existing tags are reserved AWS ones (i.e. begin with "aws"),
they are not carried over to the new endpoint configuration.
Expand All @@ -2357,6 +2358,9 @@ def create_endpoint_config_from_existing(
capture for use with Amazon SageMaker Model Monitoring (default: None).
If not specified, the data capture configuration of the existing
endpoint configuration is used.
new_production_variants (list[dict]): The configuration for which model(s) to host and
the resources to deploy for hosting the model(s). If not specified,
the ``ProductionVariants`` of the existing endpoint configuration is used.

Returns:
str: Name of the endpoint point configuration created.
Expand All @@ -2370,9 +2374,12 @@ def create_endpoint_config_from_existing(

request = {
"EndpointConfigName": new_config_name,
"ProductionVariants": existing_endpoint_config_desc["ProductionVariants"],
}

request["ProductionVariants"] = (
new_production_variants or existing_endpoint_config_desc["ProductionVariants"]
)

request_tags = new_tags or self.list_tags(
existing_endpoint_config_desc["EndpointConfigArn"]
)
Expand Down
26 changes: 10 additions & 16 deletions tests/integ/test_inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from sagemaker.predictor import Predictor, json_serializer
from sagemaker.sparkml.model import SparkMLModel
from sagemaker.utils import sagemaker_timestamp
from tests.integ.retry import retries

SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
Expand Down Expand Up @@ -151,7 +150,7 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
assert "Could not find model" in str(exception.value)


def test_inference_pipeline_model_deploy_with_update_endpoint(
def test_inference_pipeline_model_deploy_and_update_endpoint(
sagemaker_session, cpu_instance_type, alternative_cpu_instance_type
):
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
Expand Down Expand Up @@ -179,26 +178,21 @@ def test_inference_pipeline_model_deploy_with_update_endpoint(
model = PipelineModel(
models=[sparkml_model, xgb_model],
role="SageMakerRole",
predictor_cls=Predictor,
sagemaker_session=sagemaker_session,
)
model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
predictor = model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
old_config_name = old_endpoint["EndpointConfigName"]
old_config_name = endpoint_desc["EndpointConfigName"]

model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
predictor.update_endpoint(initial_instance_count=1, instance_type=cpu_instance_type)

# Wait for endpoint to finish updating
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
if new_endpoint["EndpointStatus"] == "InService":
break

new_config_name = new_endpoint["EndpointConfigName"]
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
new_config_name = endpoint_desc["EndpointConfigName"]
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
EndpointConfigName=new_config_name
)
Expand Down
22 changes: 9 additions & 13 deletions tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,25 +496,20 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_2)
assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_2)

old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
old_config_name = old_endpoint["EndpointConfigName"]
old_config_name = endpoint_desc["EndpointConfigName"]

# Update endpoint
multi_data_model.deploy(
1, alternative_cpu_instance_type, endpoint_name=endpoint_name, update_endpoint=True
predictor.update_endpoint(
initial_instance_count=1, instance_type=alternative_cpu_instance_type
)

# Wait for endpoint to finish updating
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
if new_endpoint["EndpointStatus"] == "InService":
break

new_config_name = new_endpoint["EndpointConfigName"]
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
new_config_name = endpoint_desc["EndpointConfigName"]

new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
EndpointConfigName=new_config_name
Expand All @@ -531,6 +526,7 @@ def test_multi_data_model_deploy_pretrained_models_update_endpoint(
EndpointConfigName=new_config_name
)
multi_data_model.delete_model()

with pytest.raises(Exception) as exception:
sagemaker_session.sagemaker_client.describe_model(ModelName=model_name)
assert "Could not find model" in str(exception.value)
Expand Down
25 changes: 9 additions & 16 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from sagemaker.utils import sagemaker_timestamp
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.retry import retries
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name


Expand Down Expand Up @@ -203,7 +202,7 @@ def test_deploy_model_with_tags_and_kms(
assert endpoint_config["KmsKeyId"] == kms_key_arn


def test_deploy_model_with_update_endpoint(
def test_deploy_model_and_update_endpoint(
mxnet_training_job,
sagemaker_session,
mxnet_full_version,
Expand All @@ -227,24 +226,18 @@ def test_deploy_model_with_update_endpoint(
sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version,
)
model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
predictor = model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
old_config_name = old_endpoint["EndpointConfigName"]

model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)
old_config_name = endpoint_desc["EndpointConfigName"]

# Wait for endpoint to finish updating
# Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
if new_endpoint["EndpointStatus"] == "InService":
break
predictor.update_endpoint(initial_instance_count=1, instance_type=cpu_instance_type)

new_config_name = new_endpoint["EndpointConfigName"]
endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
EndpointName=endpoint_name
)
new_config_name = endpoint_desc["EndpointConfigName"]
new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
EndpointConfigName=new_config_name
)
Expand Down
Loading