Skip to content

Add new APIs to clean up resources from predictor and transformer. #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ CHANGELOG

* doc-fix: update information about saving models in the MXNet README
* doc-fix: change ReadTheDocs links from latest to stable
* feature: Support for predictor to delete endpoint and endpoint configuration
* feature: Support for transformer to delete model

1.18.2
======
Expand Down
12 changes: 9 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,14 @@ Here is an end to end example of how to use a SageMaker Estimator:
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
mxnet_predictor.delete_endpoint()


The above example will delete endpoint and endpoint configuration at the same time. If you want to keep endpoint configuration, you can do the following:

.. code:: python
mxnet_predictor.delete_endpoint(delete_endpoint_config=False)

Additionally, it is possible to deploy a different endpoint configuration, which links to your model, to an already existing SageMaker endpoint.
This can be done by specifying the existing endpoint name for the ``endpoint_name`` parameter along with the ``update_endpoint`` parameter as ``True`` within your ``deploy()`` call.
For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.update_endpoint>`__.
Expand Down Expand Up @@ -221,7 +226,7 @@ For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/r
response = mxnet_predictor.predict(data)

# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()
mxnet_predictor.delete_endpoint()

Training Metrics
~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -275,7 +280,7 @@ We can take the example in `Using Estimators <#using-estimators>`__ , and use e
response = mxnet_predictor.predict(data)

# Tears down the endpoint container
mxnet_estimator.delete_endpoint()
mxnet_predictor.delete_endpoint()


If you have an existing model and want to deploy it locally, don't specify a sagemaker_session argument to the ``MXNetModel`` constructor.
Expand Down Expand Up @@ -321,6 +326,7 @@ Here is an end-to-end example:
transformer = mxnet_estimator.transformer(1, 'local', assemble_with='Line', max_payload=1)
transformer.transform('s3://my/transform/data, content_type='text/csv', split_type='Line')
transformer.wait()
transformer.delete_model()


For detailed examples of running Docker in local mode, see:
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def delete_endpoint(self, EndpointName):
if EndpointName in LocalSagemakerClient._endpoints:
LocalSagemakerClient._endpoints[EndpointName].stop()

def delete_endpoint_config(self, EndpointConfigName):
if EndpointConfigName in LocalSagemakerClient._endpoint_configs:
del LocalSagemakerClient._endpoint_configs[EndpointConfigName]

def delete_model(self, ModelName):
if ModelName in LocalSagemakerClient._models:
del LocalSagemakerClient._models[ModelName]


class LocalSagemakerRuntimeClient(object):
"""A SageMaker Runtime client that calls a local endpoint only.
Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
env=env, tags=tags, base_transform_job_name=self.name,
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)

def delete_model(self):
"""Delete an Amazon SageMaker ``Model``.

Raises: ValueError if model is not deployed yet.

"""
if self.name is None:
raise ValueError('The SageMaker model is not deployed yet.')
self.sagemaker_session.delete_model(self.name)


SCRIPT_PARAM_NAME = 'sagemaker_program'
DIR_PARAM_NAME = 'sagemaker_submit_directory'
Expand Down
20 changes: 18 additions & 2 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,27 @@ def _create_request_args(self, data, initial_args=None):
args['Body'] = data
return args

def delete_endpoint(self):
"""Delete the Amazon SageMaker endpoint backing this predictor.
def _delete_endpoint_config(self):
"""Delete the Amazon SageMaker endpoint configuration
"""
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
endpoint_config_name = endpoint_description['EndpointConfigName']
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)

def delete_endpoint(self, delete_endpoint_config=True):
"""Delete the Amazon SageMaker endpoint backing this predictor. Also delete the endpoint configuration attached
to it if delete_endpoint_config is True.

Args:
delete_endpoint_config (bool): Flag to indicate whether to delete endpoint configuration together with
endpoint. If False, only endpoint will be deleted. Default: True.

"""
self.sagemaker_session.delete_endpoint(self.endpoint)

if delete_endpoint_config:
self._delete_endpoint_config()


class _CsvSerializer(object):
def __init__(self):
Expand Down
18 changes: 18 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,24 @@ def delete_endpoint(self, endpoint_name):
LOGGER.info('Deleting endpoint with name: {}'.format(endpoint_name))
self.sagemaker_client.delete_endpoint(EndpointName=endpoint_name)

def delete_endpoint_config(self, endpoint_config_name):
"""Delete an Amazon SageMaker endpoint configuration.

Args:
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to delete.
"""
LOGGER.info('Deleting endpoint configuration with name: {}'.format(endpoint_config_name))
self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

def delete_model(self, model_name):
"""Delete an Amazon SageMaker ``Model``.

Args:
model_name (str): Name of the Amazon SageMaker model to delete.
"""
LOGGER.info('Deleting model with name: {}'.format(model_name))
self.sagemaker_client.delete_model(ModelName=model_name)

def wait_for_job(self, job, poll=5):
"""Wait for an Amazon SageMaker training job to complete.

Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
using the default AWS configuration chain.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
model (sagemaker.model.Model): A SageMaker Model object, used for SageMaker Model interactions
(default: None). If not specified, model object related activities will fail.
"""
self.model_name = model_name
self.strategy = strategy
Expand Down Expand Up @@ -112,6 +114,12 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
split_type)

def delete_model(self):
"""Delete a SageMaker Model.

"""
self.sagemaker_session.delete_model(self.model_name)

def _retrieve_image_name(self):
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
return model_desc['PrimaryContainer']['Image']
Expand Down
40 changes: 40 additions & 0 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,46 @@ def test_attach_transform_kmeans(sagemaker_session):
attached_transformer.wait()


def test_transformer_delete_model(sagemaker_session):
data_path = os.path.join(DATA_DIR, 'one_p_mnist')
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}

train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
with gzip.open(train_set_path, 'rb') as f:
train_set, _, _ = pickle.load(f, **pickle_args)

kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
output_path='s3://{}/'.format(sagemaker_session.default_bucket()))

kmeans.init_method = 'random'
kmeans.max_iterations = 1
kmeans.tol = 1
kmeans.num_trials = 1
kmeans.local_init_method = 'kmeans++'
kmeans.half_life_time_size = 1
kmeans.epochs = 1

records = kmeans.record_set(train_set[0][:100])
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
kmeans.fit(records)

transform_input_path = os.path.join(data_path, 'transform_input.csv')
transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
key_prefix=transform_input_key_prefix)

transformer = _create_transformer_and_transform_job(kmeans, transform_input)
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()

transformer.delete_model()

with pytest.raises(Exception) as exception:
sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
assert 'Could not find model' in exception.value.message


def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
script_path = os.path.join(data_path, 'mnist.py')
Expand Down
28 changes: 27 additions & 1 deletion tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def test_create_model(LocalSession):
assert 'my-model' in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
model_name = 'my-model'
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}

local_sagemaker_client.create_model(model_name, primary_container)
assert model_name in sagemaker.local.local_session.LocalSagemakerClient._models

local_sagemaker_client.delete_model(model_name)
assert model_name not in sagemaker.local.local_session.LocalSagemakerClient._models


@patch('sagemaker.local.local_session.LocalSession')
def test_describe_model(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
Expand Down Expand Up @@ -218,6 +231,19 @@ def test_create_endpoint_config(LocalSession):
assert 'my-endpoint-config' in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.local_session.LocalSession')
def test_delete_endpoint_config(LocalSession):
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
endpoint_config_name = 'my-endpoint-config'

local_sagemaker_client.create_endpoint_config(endpoint_config_name, production_variants)
assert endpoint_config_name in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs

local_sagemaker_client.delete_endpoint_config(endpoint_config_name)
assert endpoint_config_name not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs


@patch('sagemaker.local.image._SageMakerContainer.serve')
@patch('sagemaker.local.local_session.LocalSession')
@patch('urllib3.PoolManager.request')
Expand Down Expand Up @@ -316,7 +342,7 @@ def test_update_endpoint(LocalSession):
endpoint_name = 'my-endpoint'
endpoint_config = 'my-endpoint-config'
expected_error_message = 'Update endpoint name is not supported in local session.'
with pytest.raises(NotImplementedError, message=expected_error_message):
with pytest.raises(NotImplementedError, match=expected_error_message):
local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)


Expand Down
12 changes: 11 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, sagemaker_session, **kwargs):
sagemaker_session=sagemaker_session, **kwargs)

def create_predictor(self, endpoint_name):
return RealTimePredictor(endpoint_name, self.sagemaker_session)
return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session)


@pytest.fixture()
Expand Down Expand Up @@ -335,3 +335,13 @@ def test_model_package_create_transformer_with_product_id(sagemaker_session):
assert transformer.model_name == 'auto-generated-model'
assert transformer.instance_type == 'ml.m4.xlarge'
assert transformer.env is None


@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@patch('time.strftime', MagicMock(return_value=TIMESTAMP))
def test_model_delete_model(sagemaker_session, tmpdir):
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)
model.delete_model()

sagemaker_session.delete_model.assert_called_with(model.name)
19 changes: 19 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,22 @@ def test_predict_call_with_headers_and_csv():
assert kwargs == expected_request_args

assert result == CSV_RETURN_VALUE


def test_delete_endpoint_with_config():
sagemaker_session = empty_sagemaker_session()
sagemaker_session.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointConfigName': 'endpoint-config'})
predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)
predictor.delete_endpoint()

sagemaker_session.delete_endpoint.assert_called_with(ENDPOINT)
sagemaker_session.delete_endpoint_config.assert_called_with('endpoint-config')


def test_delete_endpoint_only():
sagemaker_session = empty_sagemaker_session()
predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)
predictor.delete_endpoint(delete_endpoint_config=False)

sagemaker_session.delete_endpoint.assert_called_with(ENDPOINT)
sagemaker_session.delete_endpoint_config.assert_not_called()
16 changes: 15 additions & 1 deletion tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ def test_delete_endpoint(boto_session):
boto_session.client().delete_endpoint.assert_called_with(EndpointName='my_endpoint')


def test_delete_endpoint_config(boto_session):
sess = Session(boto_session)
sess.delete_endpoint_config('my_endpoint_config')

boto_session.client().delete_endpoint_config.assert_called_with(EndpointConfigName='my_endpoint_config')


def test_delete_model(boto_session):
sess = Session(boto_session)
sess.delete_model('my_model')

boto_session.client().delete_model.assert_called_with(ModelName='my_model')


def test_user_agent_injected(boto_session):
assert 'AWS-SageMaker-Python-SDK' not in boto_session.client('sagemaker')._client_config.user_agent

Expand Down Expand Up @@ -933,7 +947,7 @@ def test_update_endpoint_non_existing_endpoint(sagemaker_session):
expected_error_message = 'Endpoint with name "non-existing-endpoint" does not exist; ' \
'please use an existing endpoint name'
sagemaker_session.sagemaker_client.describe_endpoint = Mock(side_effect=error)
with pytest.raises(ValueError, message=expected_error_message):
with pytest.raises(ValueError, match=expected_error_message):
sagemaker_session.update_endpoint("non-existing-endpoint", "non-existing-config")


Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def transformer(sagemaker_session):
volume_kms_key=KMS_KEY_ID)


def test_delete_model(sagemaker_session):
transformer = Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
sagemaker_session=sagemaker_session, volume_kms_key=KMS_KEY_ID)
transformer.delete_model()
sagemaker_session.delete_model.assert_called_with(MODEL_NAME)


@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_all_params(start_new_job, transformer):
content_type = 'text/csv'
Expand Down