Skip to content

Commit 103a06d

Browse files
committed
feature: handler for stopping transform job
1 parent d0c784a commit 103a06d

File tree

4 files changed

+92
-0
lines changed

4 files changed

+92
-0
lines changed

src/sagemaker/session.py

+21
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,27 @@ def wait_for_transform_job(self, job, poll=5):
10521052
self._check_job_status(job, desc, "TransformJobStatus")
10531053
return desc
10541054

1055+
def stop_transform_job(self, name):
1056+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
1057+
1058+
Args:
1059+
name (str): Name of the Amazon SageMaker batch transform job.
1060+
1061+
Raises:
1062+
ClientError: If an error occurs while trying to stop the batch transform job.
1063+
"""
1064+
try:
1065+
LOGGER.info('Stopping transform job: {}'.format(name))
1066+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
1067+
except ClientError as e:
1068+
error_code = e.response['Error']['Code']
1069+
# allow to pass if the job already stopped
1070+
if error_code == 'ValidationException':
1071+
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
1072+
else:
1073+
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
1074+
raise
1075+
10551076
def _check_job_status(self, job, desc, status_key_name):
10561077
"""Check to see if the job completed successfully and, if not, construct and
10571078
raise a exceptions.UnexpectedStatusException.

src/sagemaker/transformer.py

+9
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def wait(self):
228228
self._ensure_last_transform_job()
229229
self.latest_transform_job.wait()
230230

231+
def stop_transform_job(self):
232+
"""Stop latest running batch transform job.
233+
"""
234+
self._ensure_last_transform_job()
235+
self.latest_transform_job.stop()
236+
231237
def _ensure_last_transform_job(self):
232238
"""Placeholder docstring"""
233239
if self.latest_transform_job is None:
@@ -345,6 +351,9 @@ def start_new(
345351
def wait(self):
346352
self.sagemaker_session.wait_for_transform_job(self.job_name)
347353

354+
def stop(self):
355+
self.sagemaker_session.stop_transform_job(name=self.job_name)
356+
348357
@staticmethod
349358
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
350359
"""

tests/integ/test_transformer.py

+47
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import time
1920

2021
import pytest
2122

@@ -366,4 +367,50 @@ def _create_transformer_and_transform_job(
366367
output_filter=output_filter,
367368
join_source=join_source,
368369
)
370+
371+
372+
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
373+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
374+
script_path = os.path.join(data_path, 'mnist.py')
375+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
376+
377+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
378+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
379+
framework_version=mxnet_full_version)
380+
381+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
382+
key_prefix='integ-test-data/mxnet_mnist/train')
383+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
384+
key_prefix='integ-test-data/mxnet_mnist/test')
385+
job_name = unique_name_from_base('test-mxnet-transform')
386+
387+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
388+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
389+
390+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
391+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
392+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
393+
key_prefix=transform_input_key_prefix)
394+
395+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
396+
transformer.transform(transform_input, content_type='text/csv')
397+
398+
time.sleep(15)
399+
400+
latest_transform_job_name = transformer.latest_transform_job.name
401+
402+
print('Attempting to stop {}'.format(latest_transform_job_name))
403+
404+
transformer.stop_transform_job()
405+
406+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \
407+
.describe_transform_job(TransformJobName=latest_transform_job_name)
408+
assert desc['TransformJobStatus'] == 'Stopping'
409+
410+
411+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
412+
input_filter=None, output_filter=None, join_source=None):
413+
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
414+
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
415+
output_filter=output_filter, join_source=join_source)
369416
return transformer

tests/unit/test_transformer.py

+15
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,18 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session):
449449

450450
transformer.transform(DATA, job_name="job-2")
451451
assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2")
452+
453+
454+
def test_stop_transform_job(sagemaker_session, transformer):
455+
sagemaker_session.stop_transform_job = Mock(name='stop_transform_job')
456+
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
457+
458+
transformer.stop_transform_job()
459+
460+
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
461+
462+
463+
def test_stop_transform_job_no_transform_job(transformer):
464+
with pytest.raises(ValueError) as e:
465+
transformer.stop_transform_job()
466+
assert 'No transform job available' in str(e)

0 commit comments

Comments
 (0)