Skip to content

Commit 379ceac

Browse files
imujjwal96laurenyu
authored andcommitted
feature: handler for stopping transform job (#850)
1 parent c701100 commit 379ceac

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

src/sagemaker/session.py

+21
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,27 @@ def wait_for_transform_job(self, job, poll=5):
11031103
self._check_job_status(job, desc, "TransformJobStatus")
11041104
return desc
11051105

1106+
def stop_transform_job(self, name):
1107+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
1108+
1109+
Args:
1110+
name (str): Name of the Amazon SageMaker batch transform job.
1111+
1112+
Raises:
1113+
ClientError: If an error occurs while trying to stop the batch transform job.
1114+
"""
1115+
try:
1116+
LOGGER.info("Stopping transform job: %s", name)
1117+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
1118+
except ClientError as e:
1119+
error_code = e.response["Error"]["Code"]
1120+
# allow to pass if the job already stopped
1121+
if error_code == "ValidationException":
1122+
LOGGER.info("Transform job: %s is already stopped or not running.", name)
1123+
else:
1124+
LOGGER.error("Error occurred while attempting to stop transform job: %s.", name)
1125+
raise
1126+
11061127
def _check_job_status(self, job, desc, status_key_name):
11071128
"""Check to see if the job completed successfully and, if not, construct and
11081129
raise a exceptions.UnexpectedStatusException.

src/sagemaker/transformer.py

+12
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,14 @@ def wait(self):
229229
self._ensure_last_transform_job()
230230
self.latest_transform_job.wait()
231231

232+
def stop_transform_job(self, wait=True):
233+
"""Stop latest running batch transform job.
234+
"""
235+
self._ensure_last_transform_job()
236+
self.latest_transform_job.stop()
237+
if wait:
238+
self.latest_transform_job.wait()
239+
232240
def _ensure_last_transform_job(self):
233241
"""Placeholder docstring"""
234242
if self.latest_transform_job is None:
@@ -346,6 +354,10 @@ def start_new(
346354
def wait(self):
347355
self.sagemaker_session.wait_for_transform_job(self.job_name)
348356

357+
def stop(self):
358+
"""Placeholder docstring"""
359+
self.sagemaker_session.stop_transform_job(name=self.job_name)
360+
349361
@staticmethod
350362
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
351363
"""

tests/integ/test_transformer.py

+49
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import time
1920

2021
import pytest
2122

@@ -349,6 +350,54 @@ def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version,
349350
)
350351

351352

353+
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
354+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
355+
script_path = os.path.join(data_path, "mnist.py")
356+
tags = [{"Key": "some-tag", "Value": "value-for-tag"}]
357+
358+
mx = MXNet(
359+
entry_point=script_path,
360+
role="SageMakerRole",
361+
train_instance_count=1,
362+
train_instance_type="ml.c4.xlarge",
363+
sagemaker_session=sagemaker_session,
364+
framework_version=mxnet_full_version,
365+
)
366+
367+
train_input = mx.sagemaker_session.upload_data(
368+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
369+
)
370+
test_input = mx.sagemaker_session.upload_data(
371+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
372+
)
373+
job_name = unique_name_from_base("test-mxnet-transform")
374+
375+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
376+
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
377+
378+
transform_input_path = os.path.join(data_path, "transform", "data.csv")
379+
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
380+
transform_input = mx.sagemaker_session.upload_data(
381+
path=transform_input_path, key_prefix=transform_input_key_prefix
382+
)
383+
384+
transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags)
385+
transformer.transform(transform_input, content_type="text/csv")
386+
387+
time.sleep(15)
388+
389+
latest_transform_job_name = transformer.latest_transform_job.name
390+
391+
print("Attempting to stop {}".format(latest_transform_job_name))
392+
393+
transformer.stop_transform_job()
394+
395+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job(
396+
TransformJobName=latest_transform_job_name
397+
)
398+
assert desc["TransformJobStatus"] == "Stopped"
399+
400+
352401
def _create_transformer_and_transform_job(
353402
estimator,
354403
transform_input,

tests/unit/test_transformer.py

+15
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,18 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session):
449449

450450
transformer.transform(DATA, job_name="job-2")
451451
assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2")
452+
453+
454+
def test_stop_transform_job(sagemaker_session, transformer):
455+
sagemaker_session.stop_transform_job = Mock(name="stop_transform_job")
456+
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
457+
458+
transformer.stop_transform_job()
459+
460+
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
461+
462+
463+
def test_stop_transform_job_no_transform_job(transformer):
464+
with pytest.raises(ValueError) as e:
465+
transformer.stop_transform_job()
466+
assert "No transform job available" in str(e)

0 commit comments

Comments
 (0)