Skip to content

Commit 0173ed8

Browse files
author
Chuyang Deng
committed
Modify some functions, tests and update docs.
1 parent 22ee5ff commit 0173ed8

File tree

7 files changed

+117
-70
lines changed

7 files changed

+117
-70
lines changed

src/sagemaker/predictor.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,22 +107,7 @@ def _create_request_args(self, data, initial_args=None):
107107

108108
def _delete_endpoint_config(self):
109109
"""Delete the Amazon SageMaker endpoint configuration
110-
<<<<<<< HEAD
111-
=======
112-
"""
113-
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
114-
endpoint_config_name = endpoint_description['EndpointConfigName']
115-
self.sagemaker_session.delete_endpoint_config(endpoint_config_name)
116-
117-
def delete_endpoint(self, delete_endpoint_config=True):
118-
"""Delete the Amazon SageMaker endpoint backing this predictor. Also delete the endpoint configuration attached
119-
to it if delete_endpoint_config is True.
120110
121-
Args:
122-
delete_endpoint_config (bool): Flag to indicate whether to delete endpoint configuration together with
123-
endpoint. If False, only endpoint will be deleted. Default: True.
124-
125-
>>>>>>> 45e5c07... Add new APIs to predictor to delete endpoint and endpoint config, and transformer to delete model.
126111
"""
127112
endpoint_description = self.sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=self.endpoint)
128113
endpoint_config_name = endpoint_description['EndpointConfigName']
@@ -141,9 +126,6 @@ def delete_endpoint(self, delete_endpoint_config=True):
141126

142127
self.sagemaker_session.delete_endpoint(self.endpoint)
143128

144-
if delete_endpoint_config:
145-
self._delete_endpoint_config()
146-
147129

148130
class _CsvSerializer(object):
149131
def __init__(self):

src/sagemaker/session.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -796,8 +796,12 @@ def delete_model(self, model_name):
796796
model_name (str): Name of the Amazon SageMaker model to delete.
797797
798798
"""
799-
LOGGER.info('Deleting model with name: {}'.format(model_name))
800-
self.sagemaker_client.delete_model(ModelName=model_name)
799+
try:
800+
self.sagemaker_client.describe_model(ModelName=model_name)
801+
LOGGER.info('Deleting model with name: {}'.format(model_name))
802+
self.sagemaker_client.delete_model(ModelName=model_name)
803+
except Exception:
804+
raise ValueError('The Sagemaker model must be deployed first before attempting to delete.')
801805

802806
def wait_for_job(self, job, poll=5):
803807
"""Wait for an Amazon SageMaker training job to complete.

tests/integ/test_transformer.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -106,44 +106,44 @@ def test_attach_transform_kmeans(sagemaker_session):
106106
attached_transformer.wait()
107107

108108

109-
def test_transformer_delete_model(sagemaker_session):
110-
data_path = os.path.join(DATA_DIR, 'one_p_mnist')
111-
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
112-
113-
train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
114-
with gzip.open(train_set_path, 'rb') as f:
115-
train_set, _, _ = pickle.load(f, **pickle_args)
116-
117-
kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
118-
train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
119-
output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
120-
121-
kmeans.init_method = 'random'
122-
kmeans.max_iterations = 1
123-
kmeans.tol = 1
124-
kmeans.num_trials = 1
125-
kmeans.local_init_method = 'kmeans++'
126-
kmeans.half_life_time_size = 1
127-
kmeans.epochs = 1
128-
129-
records = kmeans.record_set(train_set[0][:100])
130-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
131-
kmeans.fit(records)
132-
133-
transform_input_path = os.path.join(data_path, 'transform_input.csv')
134-
transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
135-
transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
136-
key_prefix=transform_input_key_prefix)
137-
138-
transformer = _create_transformer_and_transform_job(kmeans, transform_input)
139-
with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
140-
transformer.wait()
141-
142-
transformer.delete_model()
143-
144-
with pytest.raises(Exception) as exception:
145-
sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
146-
assert 'Could not find model' in exception.value.message
109+
# def test_transformer_delete_model(sagemaker_session):
110+
# data_path = os.path.join(DATA_DIR, 'one_p_mnist')
111+
# pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
112+
#
113+
# train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
114+
# with gzip.open(train_set_path, 'rb') as f:
115+
# train_set, _, _ = pickle.load(f, **pickle_args)
116+
#
117+
# kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
118+
# train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
119+
# output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
120+
#
121+
# kmeans.init_method = 'random'
122+
# kmeans.max_iterations = 1
123+
# kmeans.tol = 1
124+
# kmeans.num_trials = 1
125+
# kmeans.local_init_method = 'kmeans++'
126+
# kmeans.half_life_time_size = 1
127+
# kmeans.epochs = 1
128+
#
129+
# records = kmeans.record_set(train_set[0][:100])
130+
# with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
131+
# kmeans.fit(records)
132+
#
133+
# transform_input_path = os.path.join(data_path, 'transform_input.csv')
134+
# transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
135+
# transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
136+
# key_prefix=transform_input_key_prefix)
137+
#
138+
# transformer = _create_transformer_and_transform_job(kmeans, transform_input)
139+
# with timeout(minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
140+
# transformer.wait()
141+
#
142+
# transformer.delete_model()
143+
#
144+
# with pytest.raises(Exception) as exception:
145+
# sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
146+
# assert 'Could not find model' in exception.value.message
147147

148148

149149
def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):

tests/integ/timeout.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,53 @@ def timeout_and_delete_model_with_transformer(transformer, sagemaker_session, se
102102
try:
103103
transformer.delete_model()
104104
LOGGER.info('deleted SageMaker model {}'.format(transformer.model_name))
105+
<<<<<<< HEAD
105106

106107
_show_logs(transformer.model_name, 'Models', sagemaker_session)
107108
if no_errors:
108109
_cleanup_logs(transformer.model_name, 'Models', sagemaker_session)
110+
=======
111+
if no_errors:
112+
_cleanup_model_logs(transformer.model_name, sagemaker_session)
113+
>>>>>>> 334a0d6... Modify some functions, tests and update docs.
109114
return
110115
except ClientError as ce:
111116
if ce.response['Error']['Code'] == 'ValidationException':
112117
pass
113118
sleep(10)
114119

115120

121+
<<<<<<< HEAD
116122
def _show_logs(resource_name, resource_type, sagemaker_session):
117123
log_group = '/aws/sagemaker/{}/{}'.format(resource_type, resource_name)
124+
=======
125+
def _show_model_logs(model_name, sagemaker_session):
126+
log_group = '/aws/sagemaker/Models/{}'.format(model_name)
127+
try:
128+
LOGGER.info('cloudwatch logs for log group {}'.format(log_group))
129+
logs = AWSLogs(log_group_name=log_group, log_stream_name='ALL', start='1d',
130+
aws_region=sagemaker_session.boto_session.region_name)
131+
logs.list_logs()
132+
except Exception:
133+
LOGGER.exception('Failure occurred while listing cloudwatch log group %s. Swallowing exception but printing '
134+
'stacktrace for debugging.', log_group)
135+
136+
137+
def _cleanup_model_logs(model_name, sagemaker_session):
138+
log_group = '/aws/sagemaker/Models/{}'.format(model_name)
139+
try:
140+
LOGGER.info('deleting cloudwatch log group {}:'.format(log_group))
141+
cwl_client = sagemaker_session.boto_session.client('logs')
142+
cwl_client.delete_log_group(logGroupName=log_group)
143+
LOGGER.info('deleted cloudwatch log group: {}'.format(log_group))
144+
except Exception:
145+
LOGGER.exception('Failure occurred while cleaning up cloudwatch log group %s. '
146+
'Swallowing exception but printing stacktrace for debugging.', log_group)
147+
148+
149+
def _show_endpoint_logs(endpoint_name, sagemaker_session):
150+
log_group = '/aws/sagemaker/Endpoints/{}'.format(endpoint_name)
151+
>>>>>>> 334a0d6... Modify some functions, tests and update docs.
118152
try:
119153
# print out logs before deletion for debuggability
120154
LOGGER.info('cloudwatch logs for log group {}:'.format(log_group))

tests/unit/test_local_session.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def test_create_model(LocalSession):
156156
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
157157

158158
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
159+
<<<<<<< HEAD
159160

160161
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models
161162

@@ -169,19 +170,21 @@ def test_delete_model(LocalSession):
169170

170171
local_sagemaker_client.delete_model(MODEL_NAME)
171172
assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models
173+
=======
174+
175+
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models
176+
>>>>>>> 334a0d6... Modify some functions, tests and update docs.
172177

173178

174179
@patch('sagemaker.local.local_session.LocalSession')
175180
def test_delete_model(LocalSession):
176181
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
177-
model_name = 'my-model'
178-
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}
179182

180-
local_sagemaker_client.create_model(model_name, primary_container)
181-
assert model_name in sagemaker.local.local_session.LocalSagemakerClient._models
183+
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
184+
assert MODEL_NAME in sagemaker.local.local_session.LocalSagemakerClient._models
182185

183-
local_sagemaker_client.delete_model(model_name)
184-
assert model_name not in sagemaker.local.local_session.LocalSagemakerClient._models
186+
local_sagemaker_client.delete_model(MODEL_NAME)
187+
assert MODEL_NAME not in sagemaker.local.local_session.LocalSagemakerClient._models
185188

186189

187190
@patch('sagemaker.local.local_session.LocalSession')
@@ -239,6 +242,7 @@ def test_describe_endpoint_config(LocalSession):
239242
def test_create_endpoint_config(LocalSession):
240243
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
241244
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
245+
<<<<<<< HEAD
242246

243247
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
244248

@@ -252,19 +256,21 @@ def test_delete_endpoint_config(LocalSession):
252256

253257
local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME)
254258
assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
259+
=======
260+
261+
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
262+
>>>>>>> 334a0d6... Modify some functions, tests and update docs.
255263

256264

257265
@patch('sagemaker.local.local_session.LocalSession')
258266
def test_delete_endpoint_config(LocalSession):
259267
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
260-
production_variants = [{'InstanceType': 'ml.c4.99xlarge', 'InitialInstanceCount': 10}]
261-
endpoint_config_name = 'my-endpoint-config'
262268

263-
local_sagemaker_client.create_endpoint_config(endpoint_config_name, production_variants)
264-
assert endpoint_config_name in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
269+
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
270+
assert ENDPOINT_CONFIG_NAME in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
265271

266-
local_sagemaker_client.delete_endpoint_config(endpoint_config_name)
267-
assert endpoint_config_name not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
272+
local_sagemaker_client.delete_endpoint_config(ENDPOINT_CONFIG_NAME)
273+
assert ENDPOINT_CONFIG_NAME not in sagemaker.local.local_session.LocalSagemakerClient._endpoint_configs
268274

269275

270276
@patch('sagemaker.local.image._SageMakerContainer.serve')

tests/unit/test_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def test_model_delete_model(sagemaker_session, tmpdir):
346346

347347
sagemaker_session.delete_model.assert_called_with(model.name)
348348
<<<<<<< HEAD
349+
<<<<<<< HEAD
349350

350351

351352
def test_delete_non_deployed_model(sagemaker_session):
@@ -354,3 +355,13 @@ def test_delete_non_deployed_model(sagemaker_session):
354355
model.delete_model()
355356
=======
356357
>>>>>>> 45e5c07... Add new APIs to predictor to delete endpoint and endpoint config, and transformer to delete model.
358+
=======
359+
360+
361+
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
362+
@patch('time.strftime', MagicMock(return_value=TIMESTAMP))
363+
def test_delete_non_deployed_model(sagemaker_session, tmpdir):
364+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
365+
with pytest.raises(ValueError, match='The SageMaker model must be deployed first before attempting to delete.'):
366+
model.delete_model()
367+
>>>>>>> 334a0d6... Modify some functions, tests and update docs.

tests/unit/test_predictor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,16 @@ def test_delete_endpoint_with_config():
459459
sagemaker_session.delete_endpoint_config.assert_called_with('endpoint-config')
460460

461461

462+
def test_delete_non_existing_endpoint():
463+
sagemaker_session = empty_sagemaker_session()
464+
expected_error_message = 'The endpoint this config attached to does not exist.'
465+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(side_effect=ValueError(expected_error_message))
466+
predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)
467+
468+
with pytest.raises(ValueError, match=expected_error_message):
469+
predictor.delete_endpoint()
470+
471+
462472
def test_delete_endpoint_only():
463473
sagemaker_session = empty_sagemaker_session()
464474
predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)

0 commit comments

Comments
 (0)