Skip to content

Commit 5e39d3f

Browse files
committed
feature: handler for stopping transform job
1 parent 0346abd commit 5e39d3f

File tree

4 files changed

+85
-0
lines changed

4 files changed

+85
-0
lines changed

src/sagemaker/session.py

+21
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,27 @@ def wait_for_transform_job(self, job, poll=5):
906906
self._check_job_status(job, desc, 'TransformJobStatus')
907907
return desc
908908

909+
def stop_transform_job(self, name):
910+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
911+
912+
Args:
913+
name (str): Name of the Amazon SageMaker batch transform job.
914+
915+
Raises:
916+
ClientError: If an error occurs while trying to stop the batch transform job.
917+
"""
918+
try:
919+
LOGGER.info('Stopping transform job: {}'.format(name))
920+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
921+
except ClientError as e:
922+
error_code = e.response['Error']['Code']
923+
# allow to pass if the job already stopped
924+
if error_code == 'ValidationException':
925+
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
926+
else:
927+
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
928+
raise
929+
909930
def _check_job_status(self, job, desc, status_key_name):
910931
"""Check to see if the job completed successfully and, if not, construct and
911932
raise a ValueError.

src/sagemaker/transformer.py

+9
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,12 @@ def wait(self):
165165
self._ensure_last_transform_job()
166166
self.latest_transform_job.wait()
167167

168+
def stop_transform_job(self):
169+
"""Stop latest running batch transform job.
170+
"""
171+
self._ensure_last_transform_job()
172+
self.latest_transform_job.stop()
173+
168174
def _ensure_last_transform_job(self):
169175
if self.latest_transform_job is None:
170176
raise ValueError('No transform job available')
@@ -242,6 +248,9 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type,
242248
def wait(self):
243249
self.sagemaker_session.wait_for_transform_job(self.job_name)
244250

251+
def stop(self):
252+
self.sagemaker_session.stop_transform_job(name=self.job_name)
253+
245254
@staticmethod
246255
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
247256
input_config = _TransformJob._format_inputs_to_input_config(data, data_type, content_type,

tests/integ/test_transformer.py

+40
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import time
1920

2021
import pytest
2122

@@ -236,6 +237,45 @@ def test_transform_byo_estimator(sagemaker_session):
236237
assert tags == model_tags
237238

238239

240+
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
241+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
242+
script_path = os.path.join(data_path, 'mnist.py')
243+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
244+
245+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
246+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
247+
framework_version=mxnet_full_version)
248+
249+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
250+
key_prefix='integ-test-data/mxnet_mnist/train')
251+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
252+
key_prefix='integ-test-data/mxnet_mnist/test')
253+
job_name = unique_name_from_base('test-mxnet-transform')
254+
255+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
256+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
257+
258+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
259+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
260+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
261+
key_prefix=transform_input_key_prefix)
262+
263+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
264+
transformer.transform(transform_input, content_type='text/csv')
265+
266+
time.sleep(15)
267+
268+
latest_transform_job_name = transformer.latest_transform_job.name
269+
270+
print('Attempting to stop {}'.format(latest_transform_job_name))
271+
272+
transformer.stop_transform_job()
273+
274+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \
275+
.describe_transform_job(TransformJobName=latest_transform_job_name)
276+
assert desc['TransformJobStatus'] == 'Stopping'
277+
278+
239279
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
240280
input_filter=None, output_filter=None, join_source=None):
241281
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)

tests/unit/test_transformer.py

+15
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,18 @@ def test_transform_job_wait(sagemaker_session):
420420
job.wait()
421421

422422
assert sagemaker_session.wait_for_transform_job.called_once
423+
424+
425+
def test_stop_transform_job(sagemaker_session, transformer):
426+
sagemaker_session.stop_transform_job = Mock(name='stop_transform_job')
427+
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
428+
429+
transformer.stop_transform_job()
430+
431+
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
432+
433+
434+
def test_stop_transform_job_no_transform_job(transformer):
435+
with pytest.raises(ValueError) as e:
436+
transformer.stop_transform_job()
437+
assert 'No transform job available' in str(e)

0 commit comments

Comments
 (0)