Skip to content

Commit 55193a2

Browse files
author
Chuyang Deng
committed
Add delete_model to Predictor and Pipeline.
1 parent f381475 commit 55193a2

26 files changed

+301
-23
lines changed

src/sagemaker/pipeline.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,14 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
103103
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
104104
if self.predictor_cls:
105105
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
106+
107+
def delete_model(self):
108+
"""Delete the SageMaker model backing this pipeline model. This does not delete the list of SageMaker models used
109+
in multiple containers to build the inference pipeline.
110+
111+
"""
112+
113+
if self.name is None:
114+
raise ValueError('The SageMaker model must be created before attempting to delete.')
115+
116+
self.sagemaker_session.delete_model(self.name)

src/sagemaker/predictor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, endpoint, sagemaker_session=None, serializer=None, deserializ
5656
self.deserializer = deserializer
5757
self.content_type = content_type or getattr(serializer, 'content_type', None)
5858
self.accept = accept or getattr(deserializer, 'accept', None)
59+
self._model_names = self._get_model_names()
5960

6061
def predict(self, data, initial_args=None):
6162
"""Return the inference from the specified endpoint.
@@ -109,16 +110,15 @@ def _delete_endpoint_config(self):
109110
"""Delete the Amazon SageMaker endpoint configuration
110111
111112
"""
112-
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
113-
endpoint_config_name = endpoint_description['EndpointConfigName']
114-
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)
113+
self.sagemaker_session.delete_endpoint_config(self._endpoint_config_name)
115114

116115
def delete_endpoint(self, delete_endpoint_config=True):
117116
"""Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor.
118117
119118
Args:
120-
delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint
121-
configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True)
119+
delete_endpoint_config (bool, optional): Flag to indicate whether to delete endpoint configuration together
120+
with endpoint. Defaults to True. If True, both endpoint and endpoint configuration will be deleted. If
121+
False, only endpoint will be deleted.
122122
123123
"""
124124
if delete_endpoint_config:
@@ -130,15 +130,14 @@ def delete_model(self):
130130
"""Deletes the Amazon SageMaker models backing this predictor.
131131
132132
"""
133-
model_names = self._get_model_names()
134-
for model_name in model_names:
133+
for model_name in self._model_names:
135134
self.sagemaker_session.delete_model(model_name)
136135

137136
def _get_model_names(self):
138137
endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
139-
endpoint_config_name = endpoint_desc['EndpointConfigName']
138+
self._endpoint_config_name = endpoint_desc['EndpointConfigName']
140139
endpoint_config = self.sagemaker_session.sagemaker_client.describe_endpoint_config(
141-
EndpointConfigName=endpoint_config_name)
140+
EndpointConfigName=self._endpoint_config_name)
142141
production_variants = endpoint_config['ProductionVariants']
143142
return map(lambda d: d['ModelName'], production_variants)
144143

tests/integ/test_inference_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,8 @@ def test_inference_pipeline_model_deploy(sagemaker_session):
9292

9393
invalid_data = "1.0,28.0,C,38.0,71.5,1.0"
9494
assert (predictor.predict(invalid_data) is None)
95+
96+
model.delete_model()
97+
with pytest.raises(Exception) as exception:
98+
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
99+
assert 'Could not find model' in str(exception.value)

tests/unit/test_chainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
GPU = 'ml.p2.xlarge'
4646
CPU = 'ml.c4.xlarge'
4747

48+
ENDPOINT_DESC = {
49+
'EndpointConfigName': 'test-endpoint'
50+
}
51+
52+
ENDPOINT_CONFIG_DESC = {
53+
'ProductionVariants': [{'ModelName': 'model-1'},
54+
{'ModelName': 'model-2'}]
55+
}
56+
4857

4958
@pytest.fixture()
5059
def sagemaker_session():
@@ -54,6 +63,8 @@ def sagemaker_session():
5463

5564
describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
5665
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
66+
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
67+
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5768
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5869
session.expand_role = Mock(name="expand_role", return_value=ROLE)
5970
return session

tests/unit/test_estimator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,15 @@
102102
'ModelDataUrl': MODEL_DATA,
103103
}
104104

105+
ENDPOINT_DESC = {
106+
'EndpointConfigName': 'test-endpoint'
107+
}
108+
109+
ENDPOINT_CONFIG_DESC = {
110+
'ProductionVariants': [{'ModelName': 'model-1'},
111+
{'ModelName': 'model-2'}]
112+
}
113+
105114

106115
class DummyFramework(Framework):
107116
__framework_name__ = 'dummy'
@@ -146,6 +155,8 @@ def sagemaker_session():
146155
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
147156
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
148157
return_value=DESCRIBE_TRAINING_JOB_RESULT)
158+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
159+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
149160
return sms
150161

151162

tests/unit/test_fm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
}
3838
}
3939

40+
ENDPOINT_DESC = {
41+
'EndpointConfigName': 'test-endpoint'
42+
}
43+
44+
ENDPOINT_CONFIG_DESC = {
45+
'ProductionVariants': [{'ModelName': 'model-1'},
46+
{'ModelName': 'model-2'}]
47+
}
48+
4049

4150
@pytest.fixture()
4251
def sagemaker_session():
@@ -47,6 +56,8 @@ def sagemaker_session():
4756
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4857
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4958
return_value=DESCRIBE_TRAINING_JOB_RESULT)
59+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
60+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5061
return sms
5162

5263

tests/unit/test_ipinsights.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@
3939
}
4040
}
4141

42+
ENDPOINT_DESC = {
43+
'EndpointConfigName': 'test-endpoint'
44+
}
45+
46+
ENDPOINT_CONFIG_DESC = {
47+
'ProductionVariants': [{'ModelName': 'model-1'},
48+
{'ModelName': 'model-2'}]
49+
}
50+
4251

4352
@pytest.fixture()
4453
def sagemaker_session():
@@ -49,6 +58,8 @@ def sagemaker_session():
4958
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5059
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
5160
return_value=DESCRIBE_TRAINING_JOB_RESULT)
61+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
62+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5263

5364
return sms
5465

tests/unit/test_kmeans.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
}
3737
}
3838

39+
ENDPOINT_DESC = {
40+
'EndpointConfigName': 'test-endpoint'
41+
}
42+
43+
ENDPOINT_CONFIG_DESC = {
44+
'ProductionVariants': [{'ModelName': 'model-1'},
45+
{'ModelName': 'model-2'}]
46+
}
47+
3948

4049
@pytest.fixture()
4150
def sagemaker_session():
@@ -46,6 +55,8 @@ def sagemaker_session():
4655
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4756
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4857
return_value=DESCRIBE_TRAINING_JOB_RESULT)
58+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
59+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4960

5061
return sms
5162

tests/unit/test_knn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@
4040
}
4141
}
4242

43+
ENDPOINT_DESC = {
44+
'EndpointConfigName': 'test-endpoint'
45+
}
46+
47+
ENDPOINT_CONFIG_DESC = {
48+
'ProductionVariants': [{'ModelName': 'model-1'},
49+
{'ModelName': 'model-2'}]
50+
}
51+
4352

4453
@pytest.fixture()
4554
def sagemaker_session():
@@ -50,6 +59,8 @@ def sagemaker_session():
5059
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5160
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
5261
return_value=DESCRIBE_TRAINING_JOB_RESULT)
62+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
63+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5364

5465
return sms
5566

tests/unit/test_lda.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@
3535
}
3636
}
3737

38+
ENDPOINT_DESC = {
39+
'EndpointConfigName': 'test-endpoint'
40+
}
41+
42+
ENDPOINT_CONFIG_DESC = {
43+
'ProductionVariants': [{'ModelName': 'model-1'},
44+
{'ModelName': 'model-2'}]
45+
}
46+
3847

3948
@pytest.fixture()
4049
def sagemaker_session():
@@ -44,6 +53,8 @@ def sagemaker_session():
4453
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4554
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4655
return_value=DESCRIBE_TRAINING_JOB_RESULT)
56+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
57+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4758

4859
return sms
4960

tests/unit/test_linear_learner.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@
3737
}
3838
}
3939

40+
ENDPOINT_DESC = {
41+
'EndpointConfigName': 'test-endpoint'
42+
}
43+
44+
ENDPOINT_CONFIG_DESC = {
45+
'ProductionVariants': [{'ModelName': 'model-1'},
46+
{'ModelName': 'model-2'}]
47+
}
48+
4049

4150
@pytest.fixture()
4251
def sagemaker_session():
@@ -47,6 +56,8 @@ def sagemaker_session():
4756
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4857
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4958
return_value=DESCRIBE_TRAINING_JOB_RESULT)
59+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
60+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5061

5162
return sms
5263

tests/unit/test_mxnet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
CPU_C5 = 'ml.c5.xlarge'
4646
LAUNCH_PS_DISTRIBUTIONS_DICT = {'parameter_server': {'enabled': True}}
4747

48+
ENDPOINT_DESC = {
49+
'EndpointConfigName': 'test-endpoint'
50+
}
51+
52+
ENDPOINT_CONFIG_DESC = {
53+
'ProductionVariants': [{'ModelName': 'model-1'},
54+
{'ModelName': 'model-2'}]
55+
}
56+
4857

4958
@pytest.fixture()
5059
def sagemaker_session():
@@ -55,6 +64,8 @@ def sagemaker_session():
5564
describe = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/m.tar.gz'}}
5665
describe_compilation = {'ModelArtifacts': {'S3ModelArtifacts': 's3://m/model_c5.tar.gz'}}
5766
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
67+
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
68+
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5869
session.wait_for_compilation_job = Mock(return_value=describe_compilation)
5970
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6071
session.expand_role = Mock(name="expand_role", return_value=ROLE)

tests/unit/test_ntm.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
}
3737
}
3838

39+
ENDPOINT_DESC = {
40+
'EndpointConfigName': 'test-endpoint'
41+
}
42+
43+
ENDPOINT_CONFIG_DESC = {
44+
'ProductionVariants': [{'ModelName': 'model-1'},
45+
{'ModelName': 'model-2'}]
46+
}
47+
3948

4049
@pytest.fixture()
4150
def sagemaker_session():
@@ -46,6 +55,8 @@ def sagemaker_session():
4655
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4756
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4857
return_value=DESCRIBE_TRAINING_JOB_RESULT)
58+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
59+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4960

5061
return sms
5162

tests/unit/test_object2vec.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@
4545
}
4646
}
4747

48+
ENDPOINT_DESC = {
49+
'EndpointConfigName': 'test-endpoint'
50+
}
51+
52+
ENDPOINT_CONFIG_DESC = {
53+
'ProductionVariants': [{'ModelName': 'model-1'},
54+
{'ModelName': 'model-2'}]
55+
}
56+
4857

4958
@pytest.fixture()
5059
def sagemaker_session():
@@ -55,6 +64,8 @@ def sagemaker_session():
5564
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
5665
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
5766
return_value=DESCRIBE_TRAINING_JOB_RESULT)
67+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
68+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
5869

5970
return sms
6071

tests/unit/test_pca.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
}
3737
}
3838

39+
ENDPOINT_DESC = {
40+
'EndpointConfigName': 'test-endpoint'
41+
}
42+
43+
ENDPOINT_CONFIG_DESC = {
44+
'ProductionVariants': [{'ModelName': 'model-1'},
45+
{'ModelName': 'model-2'}]
46+
}
47+
3948

4049
@pytest.fixture()
4150
def sagemaker_session():
@@ -46,6 +55,8 @@ def sagemaker_session():
4655
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4756
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
4857
return_value=DESCRIBE_TRAINING_JOB_RESULT)
58+
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
59+
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
4960

5061
return sms
5162

tests/unit/test_pipeline_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,22 @@ def test_deploy_tags(tfo, time, sagemaker_session):
138138
'InitialInstanceCount': 1,
139139
'VariantName': 'AllTraffic'}],
140140
tags)
141+
142+
143+
def test_delete_model_without_deploy(sagemaker_session):
144+
pipeline_model = PipelineModel([], role=ROLE, sagemaker_session=sagemaker_session)
145+
146+
expected_error_message = 'The SageMaker model must be created before attempting to delete.'
147+
with pytest.raises(ValueError, match=expected_error_message):
148+
pipeline_model.delete_model()
149+
150+
151+
@patch('tarfile.open')
152+
@patch('time.strftime', return_value=TIMESTAMP)
153+
def test_delete_model(tfo, time, sagemaker_session):
154+
framework_model = DummyFrameworkModel(sagemaker_session)
155+
pipeline_model = PipelineModel([framework_model], role=ROLE, sagemaker_session=sagemaker_session)
156+
pipeline_model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)
157+
158+
pipeline_model.delete_model()
159+
sagemaker_session.delete_model.assert_called_with(pipeline_model.name)

0 commit comments

Comments
 (0)