Skip to content

Add new APIs to clean up resources from predictor and transformer. #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ CHANGELOG
* doc-fix: update information about saving models in the MXNet README
* doc-fix: change ReadTheDocs links from latest to stable
* doc-fix: add ``transform_fn`` information and fix ``input_fn`` signature in the MXNet README
* feature: Support for ``Predictor`` to delete endpoint configuration by default when calling ``delete_endpoint()``
* feature: Support for ``model`` to delete SageMaker model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the model should be uppercase, but we can fix in the other PR.

* feature: Support for ``Transformer`` to delete SageMaker model

1.18.2
======
Expand Down
23 changes: 16 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,16 @@ Here is an end to end example of how to use a SageMaker Estimator:
# Serializes data and makes a prediction request to the SageMaker endpoint
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
# Tears down the SageMaker endpoint and endpoint configuration
mxnet_predictor.delete_endpoint()


The example above will eventually delete both the SageMaker endpoint and endpoint configuration through `delete_endpoint()`. If you want to keep your SageMaker endpoint configuration, use the value False for the `delete_endpoint_config` parameter, as shown below.

.. code:: python
# Only delete the SageMaker endpoint, while keeping the corresponding endpoint configuration.
mxnet_predictor.delete_endpoint(delete_endpoint_config=False)

Additionally, it is possible to deploy a different endpoint configuration, which links to your model, to an already existing SageMaker endpoint.
This can be done by specifying the existing endpoint name for the ``endpoint_name`` parameter along with the ``update_endpoint`` parameter as ``True`` within your ``deploy()`` call.
For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.update_endpoint>`__.
Expand Down Expand Up @@ -220,8 +226,8 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
# Serializes data and makes a prediction request to the SageMaker endpoint
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
# Tears down the SageMaker endpoint and endpoint configuration
mxnet_predictor.delete_endpoint()

Training Metrics
~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -274,8 +280,8 @@ We can take the example in `Using Estimators <#using-estimators>`__ , and use e
# Serializes data and makes a prediction request to the local endpoint
response = mxnet_predictor.predict(data)

# Tears down the endpoint container
mxnet_estimator.delete_endpoint()
# Tears down the endpoint container and deletes the corresponding endpoint configuration
mxnet_predictor.delete_endpoint()


If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor.
Expand All @@ -297,7 +303,7 @@ Here is an end-to-end example:
data = numpy.zeros(shape=(1, 1, 28, 28))
predictor.predict(data)

# Tear down the endpoint container
# Tear down the endpoint container and delete the corresponding endpoint configuration
predictor.delete_endpoint()


Expand All @@ -322,6 +328,9 @@ Here is an end-to-end example:
transformer.transform('s3://my/transform/data, content_type='text/csv', split_type='Line')
transformer.wait()

# Deletes the SageMaker model
transformer.delete_model()


For detailed examples of running Docker in local mode, see:

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def delete_endpoint(self, EndpointName):
if EndpointName in LocalSagemakerClient._endpoints:
LocalSagemakerClient._endpoints[EndpointName].stop()

def delete_endpoint_config(self, EndpointConfigName):
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
del LocalSagemakerClient._endpoint_configs[EndpointConfigName]

def delete_model(self, ModelName):
if ModelName in LocalSagemakerClient._models:
del LocalSagemakerClient._models[ModelName]


class LocalSagemakerRuntimeClient(object):
"""A SageMaker Runtime client that calls a local endpoint only.
Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
env=env, tags=tags, base_transform_job_name=self.name,
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)

def delete_model(self):
"""Delete an Amazon SageMaker Model.

Raises:
ValueError: if the model is not created yet.

"""
if self.name is None:
raise ValueError('The SageMaker model must be created first before attempting to delete.')
self.sagemaker_session.delete_model(self.name)


SCRIPT_PARAM_NAME = 'sagemaker_program'
DIR_PARAM_NAME = 'sagemaker_submit_directory'
Expand Down
20 changes: 18 additions & 2 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,25 @@ def _create_request_args(self, data, initial_args=None):
args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
def _delete_endpoint_config(self):
"""Delete the Amazon SageMaker endpoint configuration

"""
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
endpoint_config_name = endpoint_description['EndpointConfigName']
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)

def delete_endpoint(self, delete_endpoint_config=True):
"""Delete the Amazon SageMaker endpoint and endpoint configuration backing this predictor.

Args:
delete_endpoint_config (bool): Flag to indicate whether to delete the corresponding SageMaker endpoint
configuration tied to the endpoint. If False, only the endpoint will be deleted. (default: True)

"""
if delete_endpoint_config:
self._delete_endpoint_config()

self.sagemaker_session.delete_endpoint(self.endpoint)


Expand Down
19 changes: 19 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,25 @@ def delete_endpoint(self, endpoint_name):
LOGGER.info('Deleting endpoint with name: {}'.format(endpoint_name))
self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)

def delete_endpoint_config(self, endpoint_config_name):
"""Delete an Amazon SageMaker endpoint configuration.

Args:
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
"""
LOGGER.info('Deleting endpoint configuration with name: {}'.format(endpoint_config_name))
self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

def delete_model(self, model_name):
"""Delete an Amazon SageMaker Model.

Args:
model_name (str): Name of the Amazon SageMaker model to delete.

"""
LOGGER.info('Deleting model with name: {}'.format(model_name))
self.sagemaker_client.delete_model(ModelName=model_name)

def wait_for_job(self, job, poll=5):
"""Wait for an Amazon SageMaker training job to complete.

Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
split_type)

def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer.

"""
self.sagemaker_session.delete_model(self.model_name)

def _retrieve_image_name(self):
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
return model_desc['PrimaryContainer']['Image']
Expand Down
18 changes: 10 additions & 8 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sagemaker.transformer import Transformer
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.timeout import timeout
from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
from tests.integ.vpc_test_utils import get_or_create_vpc_resources


Expand Down Expand Up @@ -56,7 +56,8 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
kms_key_arn = get_or_create_kms_key(kms_client, account_id)

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()

job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
Expand Down Expand Up @@ -100,7 +101,8 @@ def test_attach_transform_kmeans(sagemaker_session):

attached_transformer = Transformer.attach(transformer.latest_transform_job.name,
sagemaker_session=sagemaker_session)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
attached_transformer.wait()


Expand Down Expand Up @@ -135,12 +137,12 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
key_prefix=transform_input_key_prefix)

transformer = _create_transformer_and_transform_job(mx, transform_input)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()

model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets'])
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
assert set(subnet_ids) == set(model_desc['VpcConfig']['Subnets'])
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']


def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
Expand Down
42 changes: 32 additions & 10 deletions tests/integ/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ def timeout(seconds=0, minutes=0, hours=0):
"""
Add a signal-based timeout to any block of code.
If multiple time units are specified, they will be added together to determine time limit.

Usage:

with timeout(seconds=5):
my_slow_function(...)


Args:
- seconds: The time limit, in seconds.
- minutes: The time limit, in minutes.
Expand Down Expand Up @@ -75,9 +71,9 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
sagemaker_session.delete_endpoint(endpoint_name)
LOGGER.info('deleted endpoint {}'.format(endpoint_name))

_show_endpoint_logs(endpoint_name, sagemaker_session)
_show_logs(endpoint_name, 'Endpoints', sagemaker_session)
if no_errors:
_cleanup_endpoint_logs(endpoint_name, sagemaker_session)
_cleanup_logs(endpoint_name, 'Endpoints', sagemaker_session)
return
except ClientError as ce:
if ce.response['Error']['Code'] == 'ValidationException':
Expand All @@ -87,8 +83,34 @@ def timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, second
sleep(10)


def _show_endpoint_logs(endpoint_name, sagemaker_session):
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
@contextmanager
def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, seconds=0, minutes=0, hours=0):
with timeout(seconds=seconds, minutes=minutes, hours=hours) as t:
no_errors = False
try:
yield [t]
no_errors = True
finally:
attempts = 3

while attempts > 0:
attempts -= 1
try:
transformer.delete_model()
LOGGER.info('deleted SageMaker model {}'.format(transformer.model_name))

_show_logs(transformer.model_name, 'Models', sagemaker_session)
if no_errors:
_cleanup_logs(transformer.model_name, 'Models', sagemaker_session)
return
except ClientError as ce:
if ce.response['Error']['Code'] == 'ValidationException':
pass
sleep(10)


def _show_logs(resource_name, resource_type, sagemaker_session):
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
try:
# print out logs before deletion for debuggability
LOGGER.info('cloudwatch logs for log group {}:'.format(log_group))
Expand All @@ -100,8 +122,8 @@ def _show_endpoint_logs(endpoint_name, sagemaker_session):
'stacktrace for debugging.', log_group)


def _cleanup_endpoint_logs(endpoint_name, sagemaker_session):
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
def _cleanup_logs(resource_name, resource_type, sagemaker_session):
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
try:
# print out logs before deletion for debuggability
LOGGER.info('deleting cloudwatch log group {}:'.format(log_group))
Expand Down
47 changes: 35 additions & 12 deletions tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
BAD_RESPONSE = urllib3.HTTPResponse()
BAD_RESPONSE.status = 502

ENDPOINT_CONFIG_NAME = 'test-endpoint-config'
PRODUCTION_VARIANTS = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]

MODEL_NAME = 'test-model'
PRIMARY_CONTAINER = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}


@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
@patch('sagemaker.local.local_session.LocalSession')
Expand Down Expand Up @@ -148,25 +154,32 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
@patch('sagemaker.local.local_session.LocalSession')
def test_create_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
model_name = 'my-model'
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}

local_sagemaker_client.create_model(model_name, primary_container)
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)

assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()

local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models

assert 'my-model' in sagemaker.local.local_session.LocalSagemakerClient._models
local_sagemaker_client.delete_model(MODEL_NAME)
assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_describe_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
model_name = 'test-model'
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}

with pytest.raises(ClientError):
local_sagemaker_client.describe_model('model-does-not-exist')

local_sagemaker_client.create_model(model_name, primary_container)
response = local_sagemaker_client.describe_model('test-model')
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
response = local_sagemaker_client.describe_model(MODEL_NAME)

assert response['ModelName'] == 'test-model'
assert response['PrimaryContainer']['ModelDataUrl'] == '/some/model/path'
Expand Down Expand Up @@ -212,10 +225,20 @@ def test_describe_endpoint_config(LocalSession):
@patch('sagemaker.local.local_session.LocalSession')
def test_create_endpoint_config(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
local_sagemaker_client.create_endpoint_config('my-endpoint-config', production_variants)
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)

assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_endpoint_config(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()

local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs

assert 'my-endpoint-config' in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME)
assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.image._SageMakerContainer.serve')
Expand Down Expand Up @@ -316,7 +339,7 @@ def test_update_endpoint(LocalSession):
endpoint_name = 'my-endpoint'
endpoint_config = 'my-endpoint-config'
expected_error_message = 'Update endpoint name is not supported in local session.'
with pytest.raises(NotImplementedError, message=expected_error_message):
with pytest.raises(NotImplementedError, match=expected_error_message):
local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)


Expand Down
Loading