Skip to content

Commit 675ae02

Browse files
committed
feature: handler for stopping transform job
1 parent 686569e commit 675ae02

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/sagemaker/session.py

+21
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,27 @@ def wait_for_transform_job(self, job, poll=5):
902902
self._check_job_status(job, desc, 'TransformJobStatus')
903903
return desc
904904

905+
def stop_transform_job(self, name):
906+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
907+
908+
Args:
909+
name (str): Name of the Amazon SageMaker batch transform job.
910+
911+
Raises:
912+
ClientError: If an error occurs while trying to stop the batch transform job.
913+
"""
914+
try:
915+
LOGGER.info('Stopping transform job: {}'.format(name))
916+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
917+
except ClientError as e:
918+
error_code = e.response['Error']['Code']
919+
# allow to pass if the job already stopped
920+
if error_code == 'ValidationException':
921+
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
922+
else:
923+
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
924+
raise
925+
905926
def _check_job_status(self, job, desc, status_key_name):
906927
"""Check to see if the job completed successfully and, if not, construct and
907928
raise a ValueError.

src/sagemaker/transformer.py

+9
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def wait(self):
156156
self._ensure_last_transform_job()
157157
self.latest_transform_job.wait()
158158

159+
def stop_transform_job(self):
160+
"""Stop latest running batch transform job.
161+
"""
162+
self._ensure_last_transform_job()
163+
self.latest_transform_job.stop()
164+
159165
def _ensure_last_transform_job(self):
160166
if self.latest_transform_job is None:
161167
raise ValueError('No transform job available')
@@ -230,6 +236,9 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type,
230236
def wait(self):
231237
self.sagemaker_session.wait_for_transform_job(self.job_name)
232238

239+
def stop(self):
240+
self.sagemaker_session.stop_transform_job(name=self.job_name)
241+
233242
@staticmethod
234243
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
235244
input_config = _TransformJob._format_inputs_to_input_config(data, data_type, content_type,

0 commit comments

Comments
 (0)