Skip to content

Commit 7d2e06d

Browse files
chuyang-dengChoiByungWook
authored andcommitted
Add new APIs to clean up resources from predictor and transformer. (#630)
1 parent c012a0f commit 7d2e06d

14 files changed

+217
-41
lines changed

CHANGELOG.rst

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ CHANGELOG
88
* doc-fix: update information about saving models in the MXNet README
99
* doc-fix: change ReadTheDocs links from latest to stable
1010
* doc-fix: add ``transform_fn`` information and fix ``input_fn`` signature in the MXNet README
11+
* feature: Support for ``Predictor`` to delete endpoint configuration by default when calling ``delete_endpoint()``
12+
* feature: Support for ``Model`` to delete SageMaker model
13+
* feature: Support for ``Transformer`` to delete SageMaker model
1114

1215
1.18.2
1316
======

README.rst

+16-7
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,16 @@ Here is an end to end example of how to use a SageMaker Estimator:
189189
# Serializes data and makes a prediction request to the SageMaker endpoint
190190
response = mxnet_predictor.predict(data)
191191
192-
# Tears down the SageMaker endpoint
193-
mxnet_estimator.delete_endpoint()
192+
# Tears down the SageMaker endpoint and endpoint configuration
193+
mxnet_predictor.delete_endpoint()
194194
195195
196+
The example above deletes both the SageMaker endpoint and its endpoint configuration through ``delete_endpoint()``. If you want to keep your SageMaker endpoint configuration, pass ``False`` for the ``delete_endpoint_config`` parameter, as shown below.
197+
198+
.. code:: python
199+
# Only delete the SageMaker endpoint, while keeping the corresponding endpoint configuration.
200+
mxnet_predictor.delete_endpoint(delete_endpoint_config=False)
201+
196202
Additionally, it is possible to deploy a different endpoint configuration, which links to your model, to an already existing SageMaker endpoint.
197203
This can be done by specifying the existing endpoint name for the ``endpoint_name`` parameter along with the ``update_endpoint`` parameter as ``True`` within your ``deploy()`` call.
198204
For more information, see the `boto3 update_endpoint documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.update_endpoint>`__.
@@ -220,8 +226,8 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
220226
# Serializes data and makes a prediction request to the SageMaker endpoint
221227
response = mxnet_predictor.predict(data)
222228
223-
# Tears down the SageMaker endpoint
224-
mxnet_estimator.delete_endpoint()
229+
# Tears down the SageMaker endpoint and endpoint configuration
230+
mxnet_predictor.delete_endpoint()
225231
226232
Training Metrics
227233
~~~~~~~~~~~~~~~~
@@ -274,8 +280,8 @@ We can take the example in `Using Estimators <#using-estimators>`__ , and use e
274280
# Serializes data and makes a prediction request to the local endpoint
275281
response = mxnet_predictor.predict(data)
276282
277-
# Tears down the endpoint container
278-
mxnet_estimator.delete_endpoint()
283+
# Tears down the endpoint container and deletes the corresponding endpoint configuration
284+
mxnet_predictor.delete_endpoint()
279285
280286
281287
If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor.
@@ -297,7 +303,7 @@ Here is an end-to-end example:
297303
data = numpy.zeros(shape=(1, 1, 28, 28))
298304
predictor.predict(data)
299305
300-
# Tear down the endpoint container
306+
# Tear down the endpoint container and delete the corresponding endpoint configuration
301307
predictor.delete_endpoint()
302308
303309
@@ -322,6 +328,9 @@ Here is an end-to-end example:
322328
transformer.transform('s3://my/transform/data', content_type='text/csv', split_type='Line')
323329
transformer.wait()
324330
331+
# Deletes the SageMaker model
332+
transformer.delete_model()
333+
325334
326335
For detailed examples of running Docker in local mode, see:
327336

src/sagemaker/local/local_session.py

+8
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ def delete_endpoint(self, EndpointName):
150150
if EndpointName in LocalSagemakerClient._endpoints:
151151
LocalSagemakerClient._endpoints[EndpointName].stop()
152152

153+
def delete_endpoint_config(self, EndpointConfigName):
154+
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
155+
del LocalSagemakerClient._endpoint_configs[EndpointConfigName]
156+
157+
def delete_model(self, ModelName):
158+
if ModelName in LocalSagemakerClient._models:
159+
del LocalSagemakerClient._models[ModelName]
160+
153161

154162
class LocalSagemakerRuntimeClient(object):
155163
"""A SageMaker Runtime client that calls a local endpoint only.

src/sagemaker/model.py

+11
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,17 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
302302
env=env, tags=tags, base_transform_job_name=self.name,
303303
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
304304

305+
def delete_model(self):
306+
"""Delete an Amazon SageMaker Model.
307+
308+
Raises:
309+
ValueError: if the model is not created yet.
310+
311+
"""
312+
if self.name is None:
313+
raise ValueError('The SageMaker model must be created first before attempting to delete.')
314+
self.sagemaker_session.delete_model(self.name)
315+
305316

306317
SCRIPT_PARAM_NAME = 'sagemaker_program'
307318
DIR_PARAM_NAME = 'sagemaker_submit_directory'

src/sagemaker/predictor.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,25 @@ def _create_request_args(self, data, initial_args=None):
105105
args['Body'] = data
106106
return args
107107

108-
def delete_endpoint(self):
109-
"""Delete the Amazon SageMaker endpoint backing this predictor.
108+
def _delete_endpoint_config(self):
109+
"""Delete the Amazon SageMaker endpoint configuration
110+
110111
"""
112+
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
113+
endpoint_config_name = endpoint_description['EndpointConfigName']
114+
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)
115+
116+
def delete_endpoint(self, delete_endpoint_config=True):
117+
"""Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor.
118+
119+
Args:
120+
delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint
121+
configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True)
122+
123+
"""
124+
if delete_endpoint_config:
125+
self._delete_endpoint_config()
126+
111127
self.sagemaker_session.delete_endpoint(self.endpoint)
112128

113129

src/sagemaker/session.py

+19
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,25 @@ def delete_endpoint(self, endpoint_name):
780780
LOGGER.info('Deleting endpoint with name: {}'.format(endpoint_name))
781781
self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
782782

783+
def delete_endpoint_config(self, endpoint_config_name):
784+
"""Delete an Amazon SageMaker endpoint configuration.
785+
786+
Args:
787+
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
788+
"""
789+
LOGGER.info('Deleting endpoint configuration with name: {}'.format(endpoint_config_name))
790+
self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
791+
792+
def delete_model(self, model_name):
793+
"""Delete an Amazon SageMaker Model.
794+
795+
Args:
796+
model_name (str): Name of the Amazon SageMaker model to delete.
797+
798+
"""
799+
LOGGER.info('Deleting model with name: {}'.format(model_name))
800+
self.sagemaker_client.delete_model(ModelName=model_name)
801+
783802
def wait_for_job(self, job, poll=5):
784803
"""Wait for an Amazon SageMaker training job to complete.
785804

src/sagemaker/transformer.py

+6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
112112
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
113113
split_type)
114114

115+
def delete_model(self):
116+
"""Delete the corresponding SageMaker model for this Transformer.
117+
118+
"""
119+
self.sagemaker_session.delete_model(self.model_name)
120+
115121
def _retrieve_image_name(self):
116122
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
117123
return model_desc['PrimaryContainer']['Image']

tests/integ/test_transformer.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.transformer import Transformer
2525
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
2626
from tests.integ.kms_utils import get_or_create_kms_key
27-
from tests.integ.timeout import timeout
27+
from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
2828
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
2929

3030

@@ -56,7 +56,8 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
5656
kms_key_arn = get_or_create_kms_key(kms_client, account_id)
5757

5858
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
59-
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
59+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
60+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
6061
transformer.wait()
6162

6263
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
@@ -100,7 +101,8 @@ def test_attach_transform_kmeans(sagemaker_session):
100101

101102
attached_transformer = Transformer.attach(transformer.latest_transform_job.name,
102103
sagemaker_session=sagemaker_session)
103-
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
104+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
105+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
104106
attached_transformer.wait()
105107

106108

@@ -135,12 +137,12 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
135137
key_prefix=transform_input_key_prefix)
136138

137139
transformer = _create_transformer_and_transform_job(mx, transform_input)
138-
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
140+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
141+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
139142
transformer.wait()
140-
141-
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
142-
assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets'])
143-
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
143+
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
144+
assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets'])
145+
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
144146

145147

146148
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):

tests/integ/timeout.py

+32-10
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,9 @@ def timeout(seconds=0, minutes=0, hours=0):
3232
"""
3333
Add a signal-based timeout to any block of code.
3434
If multiple time units are specified, they will be added together to determine time limit.
35-
3635
Usage:
37-
3836
with timeout(seconds=5):
3937
my_slow_function(...)
40-
41-
4238
Args:
4339
- seconds: The time limit, in seconds.
4440
- minutes: The time limit, in minutes.
@@ -75,9 +71,9 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
7571
sagemaker_session.delete_endpoint(endpoint_name)
7672
LOGGER.info('deleted endpoint {}'.format(endpoint_name))
7773

78-
_show_endpoint_logs(endpoint_name, sagemaker_session)
74+
_show_logs(endpoint_name, 'Endpoints', sagemaker_session)
7975
if no_errors:
80-
_cleanup_endpoint_logs(endpoint_name, sagemaker_session)
76+
_cleanup_logs(endpoint_name, 'Endpoints', sagemaker_session)
8177
return
8278
except ClientError as ce:
8379
if ce.response['Error']['Code'] == 'ValidationException':
@@ -87,8 +83,34 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
8783
sleep(10)
8884

8985

90-
def _show_endpoint_logs(endpoint_name, sagemaker_session):
91-
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
86+
@contextmanager
87+
def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, seconds=0, minutes=0, hours=0):
88+
with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
89+
no_errors = False
90+
try:
91+
yield [t]
92+
no_errors = True
93+
finally:
94+
attempts = 3
95+
96+
while attempts > 0:
97+
attempts -= 1
98+
try:
99+
transformer.delete_model()
100+
LOGGER.info('deleted SageMaker model {}'.format(transformer.model_name))
101+
102+
_show_logs(transformer.model_name, 'Models', sagemaker_session)
103+
if no_errors:
104+
_cleanup_logs(transformer.model_name, 'Models', sagemaker_session)
105+
return
106+
except ClientError as ce:
107+
if ce.response['Error']['Code'] == 'ValidationException':
108+
pass
109+
sleep(10)
110+
111+
112+
def _show_logs(resource_name, resource_type, sagemaker_session):
113+
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
92114
try:
93115
# print out logs before deletion for debuggability
94116
LOGGER.info('cloudwatch logs for log group {}:'.format(log_group))
@@ -100,8 +122,8 @@ def _show_endpoint_logs(endpoint_name, sagemaker_session):
100122
'stacktrace for debugging.', log_group)
101123

102124

103-
def _cleanup_endpoint_logs(endpoint_name, sagemaker_session):
104-
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
125+
def _cleanup_logs(resource_name, resource_type, sagemaker_session):
126+
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
105127
try:
106128
# print out logs before deletion for debuggability
107129
LOGGER.info('deleting cloudwatch log group {}:'.format(log_group))

tests/unit/test_local_session.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
BAD_RESPONSE = urllib3.HTTPResponse()
2828
BAD_RESPONSE.status = 502
2929

30+
ENDPOINT_CONFIG_NAME = 'test-endpoint-config'
31+
PRODUCTION_VARIANTS = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
32+
33+
MODEL_NAME = 'test-model'
34+
PRIMARY_CONTAINER = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}
35+
3036

3137
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
3238
@patch('sagemaker.local.local_session.LocalSession')
@@ -148,25 +154,32 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
148154
@patch('sagemaker.local.local_session.LocalSession')
149155
def test_create_model(LocalSession):
150156
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
151-
model_name = 'my-model'
152-
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}
153157

154-
local_sagemaker_client.create_model(model_name, primary_container)
158+
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
159+
160+
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models
161+
162+
163+
@patch('sagemaker.local.local_session.LocalSession')
164+
def test_delete_model(LocalSession):
165+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
166+
167+
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
168+
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models
155169

156-
assert 'my-model' in sagemaker.local.local_session.LocalSagemakerClient._models
170+
local_sagemaker_client.delete_model(MODEL_NAME)
171+
assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models
157172

158173

159174
@patch('sagemaker.local.local_session.LocalSession')
160175
def test_describe_model(LocalSession):
161176
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
162-
model_name = 'test-model'
163-
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}
164177

165178
with pytest.raises(ClientError):
166179
local_sagemaker_client.describe_model('model-does-not-exist')
167180

168-
local_sagemaker_client.create_model(model_name, primary_container)
169-
response = local_sagemaker_client.describe_model('test-model')
181+
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
182+
response = local_sagemaker_client.describe_model(MODEL_NAME)
170183

171184
assert response['ModelName'] == 'test-model'
172185
assert response['PrimaryContainer']['ModelDataUrl'] == '/some/model/path'
@@ -212,10 +225,20 @@ def test_describe_endpoint_config(LocalSession):
212225
@patch('sagemaker.local.local_session.LocalSession')
213226
def test_create_endpoint_config(LocalSession):
214227
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
215-
production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
216-
local_sagemaker_client.create_endpoint_config('my-endpoint-config', production_variants)
228+
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
229+
230+
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
231+
232+
233+
@patch('sagemaker.local.local_session.LocalSession')
234+
def test_delete_endpoint_config(LocalSession):
235+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
236+
237+
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
238+
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
217239

218-
assert 'my-endpoint-config' in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
240+
local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME)
241+
assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
219242

220243

221244
@patch('sagemaker.local.image._SageMakerContainer.serve')
@@ -316,7 +339,7 @@ def test_update_endpoint(LocalSession):
316339
endpoint_name = 'my-endpoint'
317340
endpoint_config = 'my-endpoint-config'
318341
expected_error_message = 'Update endpoint name is not supported in local session.'
319-
with pytest.raises(NotImplementedError, message=expected_error_message):
342+
with pytest.raises(NotImplementedError, match=expected_error_message):
320343
local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)
321344

322345

0 commit comments

Comments
 (0)