Skip to content

Commit 75f3554

Browse files
imujjwal96 authored and knakad committed
fix: add update_endpoint for PipelineModel (#972)
Create the update_endpoint argument and added a check for its value to create a new endpoint config for an existing endpoint.
1 parent 3895569 commit 75f3554

File tree

3 files changed

+128
-4
lines changed

3 files changed

+128
-4
lines changed

src/sagemaker/pipeline.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ def pipeline_container_def(self, instance_type):
8383
return sagemaker.pipeline_container_def(self.models, instance_type)
8484

8585
def deploy(
86-
self, initial_instance_count, instance_type, endpoint_name=None, tags=None, wait=True
86+
self,
87+
initial_instance_count,
88+
instance_type,
89+
endpoint_name=None,
90+
tags=None,
91+
wait=True,
92+
update_endpoint=False,
8793
):
8894
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a
8995
``Predictor``.
@@ -110,6 +116,11 @@ def deploy(
110116
specific endpoint.
111117
wait (bool): Whether the call should wait until the deployment of
112118
model completes (default: True).
119+
update_endpoint (bool): Flag to update the model in an existing
120+
Amazon SageMaker endpoint. If True, this will deploy a new
121+
EndpointConfig to an already existing endpoint and delete
122+
resources corresponding to the previous EndpointConfig. If
123+
False, a new endpoint will be created. Default: False
113124
114125
Returns:
115126
callable[string, sagemaker.session.Session] or None: Invocation of
@@ -130,9 +141,21 @@ def deploy(
130141
self.name, instance_type, initial_instance_count
131142
)
132143
self.endpoint_name = endpoint_name or self.name
133-
self.sagemaker_session.endpoint_from_production_variants(
134-
self.endpoint_name, [production_variant], tags, wait=wait
135-
)
144+
145+
if update_endpoint:
146+
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
147+
name=self.name,
148+
model_name=self.name,
149+
initial_instance_count=initial_instance_count,
150+
instance_type=instance_type,
151+
tags=tags,
152+
)
153+
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
154+
else:
155+
self.sagemaker_session.endpoint_from_production_variants(
156+
self.endpoint_name, [production_variant], tags, wait=wait
157+
)
158+
136159
if self.predictor_cls:
137160
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
138161
return None

tests/integ/test_inference_pipeline.py

+66
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import json
1616
import os
17+
import time
1718

1819
import pytest
1920
from tests.integ import DATA_DIR, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
@@ -148,3 +149,68 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
148149
with pytest.raises(Exception) as exception:
149150
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
150151
assert "Could not find model" in str(exception.value)
152+
153+
154+
def test_inference_pipeline_model_deploy_with_update_endpoint(sagemaker_session):
    """Deploy a SparkML+XGBoost pipeline model, then redeploy onto the same
    endpoint with ``update_endpoint=True`` and verify that the endpoint picked
    up a brand-new endpoint config with the new instance type.
    """
    sparkml_dir = os.path.join(DATA_DIR, "sparkml_model")
    xgboost_dir = os.path.join(DATA_DIR, "xgboost_model")
    endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp())

    # Stage both model artifacts in S3 for the pipeline containers.
    sparkml_model_data = sagemaker_session.upload_data(
        path=os.path.join(sparkml_dir, "mleap_model.tar.gz"),
        key_prefix="integ-test-data/sparkml/model",
    )
    xgb_model_data = sagemaker_session.upload_data(
        path=os.path.join(xgboost_dir, "xgb_model.tar.gz"),
        key_prefix="integ-test-data/xgboost/model",
    )

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        sparkml_model = SparkMLModel(
            model_data=sparkml_model_data,
            env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
            sagemaker_session=sagemaker_session,
        )
        xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
        xgb_model = Model(
            model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
        )
        model = PipelineModel(
            models=[sparkml_model, xgb_model],
            role="SageMakerRole",
            sagemaker_session=sagemaker_session,
        )

        # Initial deployment; remember which endpoint config it produced.
        model.deploy(1, "ml.t2.medium", endpoint_name=endpoint_name)
        initial_config_name = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name
        )["EndpointConfigName"]

        model.deploy(1, "ml.m4.xlarge", update_endpoint=True, endpoint_name=endpoint_name)

        # Poll until the in-place update finishes. An endpoint update takes
        # ~7min; 40 polls * 30s sleeps = 20min timeout.
        max_retry_count = 40
        for _ in range(max_retry_count):
            time.sleep(30)
            endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=endpoint_name
            )
            if endpoint_desc["EndpointStatus"] == "InService":
                break
        else:
            raise Exception("Endpoint status not 'InService' within expected timeout.")

        updated_config_name = endpoint_desc["EndpointConfigName"]
        updated_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=updated_config_name
        )

        # The update must have swapped in a new config with the new hardware.
        assert initial_config_name != updated_config_name
        assert updated_config["ProductionVariants"][0]["InstanceType"] == "ml.m4.xlarge"
        assert updated_config["ProductionVariants"][0]["InitialInstanceCount"] == 1

        model.delete_model()
        with pytest.raises(Exception) as exception:
            sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
        assert "Could not find model" in str(exception.value)

tests/unit/test_pipeline_model.py

+35
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,41 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
159159
)
160160

161161

162+
@patch("tarfile.open")
163+
@patch("time.strftime", return_value=TIMESTAMP)
164+
def test_deploy_update_endpoint(tfo, time, sagemaker_session):
165+
framework_model = DummyFrameworkModel(sagemaker_session)
166+
endpoint_name = "endpoint-name"
167+
sparkml_model = SparkMLModel(
168+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
169+
)
170+
model = PipelineModel(
171+
models=[framework_model, sparkml_model], role=ROLE, sagemaker_session=sagemaker_session
172+
)
173+
model.deploy(
174+
instance_type=INSTANCE_TYPE,
175+
initial_instance_count=1,
176+
endpoint_name=endpoint_name,
177+
update_endpoint=True,
178+
)
179+
180+
sagemaker_session.create_endpoint_config.assert_called_with(
181+
name=model.name,
182+
model_name=model.name,
183+
initial_instance_count=INSTANCE_COUNT,
184+
instance_type=INSTANCE_TYPE,
185+
tags=None,
186+
)
187+
config_name = sagemaker_session.create_endpoint_config(
188+
name=model.name,
189+
model_name=model.name,
190+
initial_instance_count=INSTANCE_COUNT,
191+
instance_type=INSTANCE_TYPE,
192+
)
193+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name)
194+
sagemaker_session.create_endpoint.assert_not_called()
195+
196+
162197
@patch("tarfile.open")
163198
@patch("time.strftime", return_value=TIMESTAMP)
164199
def test_transformer(tfo, time, sagemaker_session):

0 commit comments

Comments (0)