Skip to content

Add new APIs to clean up resources from predictor and transformer. #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ CHANGELOG

* doc-fix: update information about saving models in the MXNet README
* doc-fix: change ReadTheDocs links from latest to stable
* feature: Support for predictor class to delete endpoint configuration by default when calling ``delete_endpoint()``
* feature: Support for model class to delete SageMaker model
* feature: Support for transformer to delete SageMaker model

1.18.2
======
Expand Down
23 changes: 16 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,16 @@ Here is an end to end example of how to use a SageMaker Estimator:
# Serializes data and makes a prediction request to the SageMaker endpoint
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
# Tears down the SageMaker endpoint and endpoint configuration
mxnet_predictor.delete_endpoint()


The example above deletes both the SageMaker endpoint and its endpoint configuration through ``delete_endpoint()``. If you want to keep your SageMaker endpoint configuration, pass ``False`` for the ``delete_endpoint_config`` parameter, as shown below.

.. code:: python

    # Only delete the SageMaker endpoint, while keeping the corresponding endpoint configuration.
    mxnet_predictor.delete_endpoint(delete_endpoint_config=False)

Additionally, it is possible to deploy a different endpoint configuration, which links to your model, to an already existing SageMaker endpoint.
This can be done by specifying the existing endpoint name for the ``endpoint_name`` parameter along with the ``update_endpoint`` parameter as ``True`` within your ``deploy()`` call.
For more information, see the `update_endpoint API reference <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.update_endpoint>`__.
Expand Down Expand Up @@ -220,8 +226,8 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
# Serializes data and makes a prediction request to the SageMaker endpoint
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
# Tears down the SageMaker endpoint and endpoint configuration
mxnet_predictor.delete_endpoint()

Training Metrics
~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -274,8 +280,8 @@ We can take the example in `Using Estimators <#using-estimators>`__ , and use e
# Serializes data and makes a prediction request to the local endpoint
response = mxnet_predictor.predict(data)

# Tears down the endpoint container
mxnet_estimator.delete_endpoint()
# Tears down the endpoint container and deletes the corresponding endpoint configuration
mxnet_predictor.delete_endpoint()


If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor.
Expand All @@ -297,7 +303,7 @@ Here is an end-to-end example:
data = numpy.zeros(shape=(1, 1, 28, 28))
predictor.predict(data)

# Tear down the endpoint container
# Tear down the endpoint container and delete the corresponding endpoint configuration
predictor.delete_endpoint()


Expand All @@ -322,6 +328,9 @@ Here is an end-to-end example:
transformer.transform('s3://my/transform/data', content_type='text/csv', split_type='Line')
transformer.wait()

# Deletes the SageMaker model
transformer.delete_model()


For detailed examples of running Docker in local mode, see:

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def delete_endpoint(self, EndpointName):
if EndpointName in LocalSagemakerClient._endpoints:
LocalSagemakerClient._endpoints[EndpointName].stop()

def delete_endpoint_config(self, EndpointConfigName):
    """Drop an endpoint configuration from the local client's registry.

    Args:
        EndpointConfigName (str): Key of the entry to remove from
            ``LocalSagemakerClient._endpoint_configs``. A name that is not
            registered is ignored.
    """
    registry = LocalSagemakerClient._endpoint_configs
    if EndpointConfigName not in registry:
        # Nothing registered under this name: deleting is a no-op.
        return
    del registry[EndpointConfigName]

def delete_model(self, ModelName):
    """Drop a model from the local client's registry.

    Args:
        ModelName (str): Key of the entry to remove from
            ``LocalSagemakerClient._models``. A name that is not registered
            is ignored.
    """
    registry = LocalSagemakerClient._models
    if ModelName not in registry:
        # Nothing registered under this name: deleting is a no-op.
        return
    del registry[ModelName]


class LocalSagemakerRuntimeClient(object):
"""A SageMaker Runtime client that calls a local endpoint only.
Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
env=env, tags=tags, base_transform_job_name=self.name,
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)

def delete_model(self):
    """Delete the Amazon SageMaker ``Model`` identified by ``self.name``.

    Raises:
        ValueError: If ``self.name`` is ``None`` (the model has not been deployed yet).
    """
    model_name = self.name
    if model_name is not None:
        self.sagemaker_session.delete_model(model_name)
        return
    raise ValueError('The SageMaker model must be deployed first before attempting to delete.')


SCRIPT_PARAM_NAME = 'sagemaker_program'
DIR_PARAM_NAME = 'sagemaker_submit_directory'
Expand Down
19 changes: 17 additions & 2 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,24 @@ def _create_request_args(self, data, initial_args=None):
args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
def _delete_endpoint_config(self):
    """Delete the endpoint configuration attached to this predictor's endpoint.

    The configuration name is read from the ``EndpointConfigName`` field of the
    ``describe_endpoint`` response for ``self.endpoint``.
    """
    client = self.sagemaker_session.sagemaker_client
    config_name = client.describe_endpoint(EndpointName=self.endpoint)['EndpointConfigName']
    self.sagemaker_session.delete_endpoint_config(config_name)

def delete_endpoint(self, delete_endpoint_config=True):
    """Delete the Amazon SageMaker endpoint behind this predictor.

    Args:
        delete_endpoint_config (bool): When true, the endpoint configuration tied to the
            endpoint is deleted first, and then the endpoint itself. When false, only the
            endpoint is deleted. (default: True)

    """
    keep_config = not delete_endpoint_config
    if not keep_config:
        # The configuration name is resolved through the endpoint, so this runs
        # while the endpoint still exists.
        self._delete_endpoint_config()

    endpoint_name = self.endpoint
    self.sagemaker_session.delete_endpoint(endpoint_name)


Expand Down
19 changes: 19 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,25 @@ def delete_endpoint(self, endpoint_name):
LOGGER.info('Deleting endpoint with name: {}'.format(endpoint_name))
self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)

def delete_endpoint_config(self, endpoint_config_name):
    """Remove an endpoint configuration from Amazon SageMaker.

    Args:
        endpoint_config_name (str): Name of the endpoint configuration to remove.
    """
    message = 'Deleting endpoint configuration with name: {}'.format(endpoint_config_name)
    LOGGER.info(message)
    client = self.sagemaker_client
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

def delete_model(self, model_name):
    """Remove a ``Model`` from Amazon SageMaker.

    Args:
        model_name (str): Name of the model to remove.

    """
    message = 'Deleting model with name: {}'.format(model_name)
    LOGGER.info(message)
    client = self.sagemaker_client
    client.delete_model(ModelName=model_name)

def wait_for_job(self, job, poll=5):
"""Wait for an Amazon SageMaker training job to complete.

Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
split_type)

def delete_model(self):
    """Delete the SageMaker model named by ``self.model_name`` through the session."""
    session, name = self.sagemaker_session, self.model_name
    session.delete_model(name)

def _retrieve_image_name(self):
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
return model_desc['PrimaryContainer']['Image']
Expand Down
8 changes: 5 additions & 3 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sagemaker.transformer import Transformer
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.timeout import timeout
from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
from tests.integ.vpc_test_utils import get_or_create_vpc_resources


Expand Down Expand Up @@ -56,7 +56,8 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
kms_key_arn = get_or_create_kms_key(kms_client, account_id)

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()

job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
Expand Down Expand Up @@ -100,7 +101,8 @@ def test_attach_transform_kmeans(sagemaker_session):

attached_transformer = Transformer.attach(transformer.latest_transform_job.name,
sagemaker_session=sagemaker_session)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
attached_transformer.wait()


Expand Down
38 changes: 32 additions & 6 deletions tests/integ/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
sagemaker_session.delete_endpoint(endpoint_name)
LOGGER.info('deleted endpoint {}'.format(endpoint_name))

_show_endpoint_logs(endpoint_name, sagemaker_session)
_show_logs(endpoint_name, 'Endpoints', sagemaker_session)
if no_errors:
_cleanup_endpoint_logs(endpoint_name, sagemaker_session)
_cleanup_logs(endpoint_name, 'Endpoints', sagemaker_session)
return
except ClientError as ce:
if ce.response['Error']['Code'] == 'ValidationException':
Expand All @@ -87,8 +87,34 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
sleep(10)


def _show_endpoint_logs(endpoint_name, sagemaker_session):
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
@contextmanager
def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, seconds=0, minutes=0, hours=0):
    """Run the ``with`` body under a timeout, then delete the transformer's SageMaker model.

    Cleanup is best-effort: deletion is attempted up to three times, waiting 10
    seconds between attempts, and a ``ClientError`` during cleanup is swallowed.

    Args:
        transformer: Object exposing ``delete_model()`` and ``model_name``.
        sagemaker_session: Session passed through to the log helpers.
        seconds (int): Forwarded to ``timeout``.
        minutes (int): Forwarded to ``timeout``.
        hours (int): Forwarded to ``timeout``.

    Yields:
        list: Single-element list holding the value bound by the ``timeout`` context manager.
    """
    with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
        no_errors = False
        try:
            yield [t]
            no_errors = True
        finally:
            attempts = 3

            while attempts > 0:
                attempts -= 1
                try:
                    transformer.delete_model()
                    LOGGER.info('deleted SageMaker model {}'.format(transformer.model_name))

                    _show_logs(transformer.model_name, 'Models', sagemaker_session)
                    if no_errors:
                        _cleanup_logs(transformer.model_name, 'Models', sagemaker_session)
                    # Leave the retry loop with ``break``, not ``return``: a ``return``
                    # inside ``finally`` discards the in-flight exception, which would
                    # make a failing ``with`` body look like a success.
                    break
                except ClientError:
                    # Swallowed on purpose so a cleanup failure does not replace an
                    # exception raised by the ``with`` body.
                    if attempts > 0:
                        # Only wait when another attempt will follow.
                        sleep(10)


def _show_logs(resource_name, resource_type, sagemaker_session):
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
try:
# print out logs before deletion for debuggability
LOGGER.info('cloudwatch logs for log group {}:'.format(log_group))
Expand All @@ -100,8 +126,8 @@ def _show_endpoint_logs(endpoint_name, sagemaker_session):
'stacktrace for debugging.', log_group)


def _cleanup_endpoint_logs(endpoint_name, sagemaker_session):
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
def _cleanup_logs(resource_name, resource_type, sagemaker_session):
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
try:
# print out logs before deletion for debuggability
LOGGER.info('deleting cloudwatch log group {}:'.format(log_group))
Expand Down
47 changes: 35 additions & 12 deletions tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
BAD_RESPONSE = urllib3.HTTPResponse()
BAD_RESPONSE.status = 502

ENDPOINT_CONFIG_NAME = 'test-endpoint-config'
PRODUCTION_VARIANTS = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]

MODEL_NAME = 'test-model'
PRIMARY_CONTAINER = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}


@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
@patch('sagemaker.local.local_session.LocalSession')
Expand Down Expand Up @@ -148,25 +154,32 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
@patch('sagemaker.local.local_session.LocalSession')
def test_create_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
model_name = 'my-model'
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}

local_sagemaker_client.create_model(model_name, primary_container)
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)

assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()

local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models

assert 'my-model' in sagemaker.local.local_session.LocalSagemakerClient._models
local_sagemaker_client.delete_model(MODEL_NAME)
assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_describe_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
model_name = 'test-model'
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}

with pytest.raises(ClientError):
local_sagemaker_client.describe_model('model-does-not-exist')

local_sagemaker_client.create_model(model_name, primary_container)
response = local_sagemaker_client.describe_model('test-model')
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
response = local_sagemaker_client.describe_model(MODEL_NAME)

assert response['ModelName'] == 'test-model'
assert response['PrimaryContainer']['ModelDataUrl'] == '/some/model/path'
Expand Down Expand Up @@ -212,10 +225,20 @@ def test_describe_endpoint_config(LocalSession):
@patch('sagemaker.local.local_session.LocalSession')
def test_create_endpoint_config(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
local_sagemaker_client.create_endpoint_config('my-endpoint-config', production_variants)
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)

assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_endpoint_config(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()

local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs

assert 'my-endpoint-config' in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME)
assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.image._SageMakerContainer.serve')
Expand Down Expand Up @@ -316,7 +339,7 @@ def test_update_endpoint(LocalSession):
endpoint_name = 'my-endpoint'
endpoint_config = 'my-endpoint-config'
expected_error_message = 'Update endpoint name is not supported in local session.'
with pytest.raises(NotImplementedError, message=expected_error_message):
with pytest.raises(NotImplementedError, match=expected_error_message):
local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)


Expand Down
20 changes: 19 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, sagemaker_session, **kwargs):
sagemaker_session=sagemaker_session, **kwargs)

def create_predictor(self, endpoint_name):
return RealTimePredictor(endpoint_name, self.sagemaker_session)
return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session)


@pytest.fixture()
Expand Down Expand Up @@ -335,3 +335,21 @@ def test_model_package_create_transformer_with_product_id(sagemaker_session):
assert transformer.model_name == 'auto-generated-model'
assert transformer.instance_type == 'ml.m4.xlarge'
assert transformer.env is None


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@patch('time.strftime', MagicMock(return_value=TIMESTAMP))
def test_model_delete_model(sagemaker_session, tmpdir):
    """Deleting a deployed model forwards the model's name to the session."""
    source_dir = str(tmpdir)
    deployed = DummyFrameworkModel(sagemaker_session, source_dir=source_dir)
    deployed.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)

    deployed.delete_model()

    sagemaker_session.delete_model.assert_called_with(deployed.name)


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@patch('time.strftime', MagicMock(return_value=TIMESTAMP))
def test_delete_non_deployed_model(sagemaker_session, tmpdir):
    """``delete_model`` raises ``ValueError`` for a model that was never deployed."""
    expected_message = 'The SageMaker model must be deployed first before attempting to delete.'
    undeployed = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
    with pytest.raises(ValueError, match=expected_message):
        undeployed.delete_model()
19 changes: 19 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,22 @@ def test_predict_call_with_headers_and_csv():
assert kwargs == expected_request_args

assert result == CSV_RETURN_VALUE


def test_delete_endpoint_with_config():
    """By default the endpoint and the config reported by ``describe_endpoint`` are both deleted."""
    session = empty_sagemaker_session()
    description = {'EndpointConfigName': 'endpoint-config'}
    session.sagemaker_client.describe_endpoint = Mock(return_value=description)

    endpoint_predictor = RealTimePredictor(ENDPOINT, sagemaker_session=session)
    endpoint_predictor.delete_endpoint()

    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_called_with('endpoint-config')


def test_delete_endpoint_only():
    """With ``delete_endpoint_config=False`` only the endpoint is deleted."""
    session = empty_sagemaker_session()

    endpoint_predictor = RealTimePredictor(ENDPOINT, sagemaker_session=session)
    endpoint_predictor.delete_endpoint(delete_endpoint_config=False)

    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_not_called()
Loading