From 103a06dd0c9661803998fa5fd90bed01b9dc3643 Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Thu, 13 Jun 2019 18:35:23 +0000 Subject: [PATCH 1/2] feature: handler for stopping transform job --- src/sagemaker/session.py | 21 +++++++++++++++ src/sagemaker/transformer.py | 9 +++++++ tests/integ/test_transformer.py | 47 +++++++++++++++++++++++++++++++++ tests/unit/test_transformer.py | 15 +++++++++++ 4 files changed, 92 insertions(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index fcbe6f3735..91fa040eac 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1052,6 +1052,27 @@ def wait_for_transform_job(self, job, poll=5): self._check_job_status(job, desc, "TransformJobStatus") return desc + def stop_transform_job(self, name): + """Stop the Amazon SageMaker hyperparameter tuning job with the specified name. + + Args: + name (str): Name of the Amazon SageMaker batch transform job. + + Raises: + ClientError: If an error occurs while trying to stop the batch transform job. + """ + try: + LOGGER.info('Stopping transform job: {}'.format(name)) + self.sagemaker_client.stop_transform_job(TransformJobName=name) + except ClientError as e: + error_code = e.response['Error']['Code'] + # allow to pass if the job already stopped + if error_code == 'ValidationException': + LOGGER.info('Transform job: {} is already stopped or not running.'.format(name)) + else: + LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name)) + raise + def _check_job_status(self, job, desc, status_key_name): """Check to see if the job completed successfully and, if not, construct and raise a exceptions.UnexpectedStatusException. diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index ce442e0986..6262b1ac28 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -228,6 +228,12 @@ def wait(self): self._ensure_last_transform_job() self.latest_transform_job.wait() + def stop_transform_job(self): + """Stop latest running batch transform job. + """ + self._ensure_last_transform_job() + self.latest_transform_job.stop() + def _ensure_last_transform_job(self): """Placeholder docstring""" if self.latest_transform_job is None: @@ -345,6 +351,9 @@ def start_new( def wait(self): self.sagemaker_session.wait_for_transform_job(self.job_name) + def stop(self): + self.sagemaker_session.stop_transform_job(name=self.job_name) + @staticmethod def _load_config(data, data_type, content_type, compression_type, split_type, transformer): """ diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index ad3fd65c2d..965aeaabdf 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -16,6 +16,7 @@ import os import pickle import sys +import time import pytest @@ -366,4 +367,50 @@ def _create_transformer_and_transform_job( output_filter=output_filter, join_source=join_source, ) + + +def test_stop_transform_job(sagemaker_session, mxnet_full_version): + data_path = os.path.join(DATA_DIR, 'mxnet_mnist') + script_path = os.path.join(data_path, 'mnist.py') + tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] + + mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version) + + train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), + key_prefix='integ-test-data/mxnet_mnist/train') + test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), + key_prefix='integ-test-data/mxnet_mnist/test') + job_name = unique_name_from_base('test-mxnet-transform') + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, 'transform', 'data.csv') + transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' + transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, + key_prefix=transform_input_key_prefix) + + transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags) + transformer.transform(transform_input, content_type='text/csv') + + time.sleep(15) + + latest_transform_job_name = transformer.latest_transform_job.name + + print('Attempting to stop {}'.format(latest_transform_job_name)) + + transformer.stop_transform_job() + + desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \ + .describe_transform_job(TransformJobName=latest_transform_job_name) + assert desc['TransformJobStatus'] == 'Stopping' + + +def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, + input_filter=None, output_filter=None, join_source=None): + transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) + transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, + output_filter=output_filter, join_source=join_source) return transformer diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 325f6536f1..e165abb85a 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -449,3 +449,18 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session): transformer.transform(DATA, job_name="job-2") assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2") + + +def test_stop_transform_job(sagemaker_session, transformer): + sagemaker_session.stop_transform_job = Mock(name='stop_transform_job') + transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME) + + transformer.stop_transform_job() + + sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME) + + +def test_stop_transform_job_no_transform_job(transformer): + with pytest.raises(ValueError) as e: + transformer.stop_transform_job() + assert 'No transform job available' in str(e) From 108c5294f4b1d8dfaa051810e447600bfd1b7ab4 Mon Sep 17 00:00:00 2001 From: Ujjwal Bhardwaj Date: Fri, 5 Jul 2019 05:52:56 +0000 Subject: [PATCH 2/2] black check --- src/sagemaker/session.py | 10 ++-- src/sagemaker/transformer.py | 5 +- tests/integ/test_transformer.py | 92 +++++++++++++++++---------------- tests/unit/test_transformer.py | 4 +- 4 files changed, 58 insertions(+), 53 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91fa040eac..58e6341408 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1062,15 +1062,15 @@ def stop_transform_job(self, name): ClientError: If an error occurs while trying to stop the batch transform job. """ try: - LOGGER.info('Stopping transform job: {}'.format(name)) + LOGGER.info("Stopping transform job: %s", name) self.sagemaker_client.stop_transform_job(TransformJobName=name) except ClientError as e: - error_code = e.response['Error']['Code'] + error_code = e.response["Error"]["Code"] # allow to pass if the job already stopped - if error_code == 'ValidationException': - LOGGER.info('Transform job: {} is already stopped or not running.'.format(name)) + if error_code == "ValidationException": + LOGGER.info("Transform job: %s is already stopped or not running.", name) else: - LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name)) + LOGGER.error("Error occurred while attempting to stop transform job: %s.", name) raise def _check_job_status(self, job, desc, status_key_name): diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 6262b1ac28..19cfaf61ad 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -228,11 +228,13 @@ def wait(self): self._ensure_last_transform_job() self.latest_transform_job.wait() - def stop_transform_job(self): + def stop_transform_job(self, wait=True): """Stop latest running batch transform job. """ self._ensure_last_transform_job() self.latest_transform_job.stop() + if wait: + self.latest_transform_job.wait() def _ensure_last_transform_job(self): """Placeholder docstring""" @@ -352,6 +354,7 @@ def wait(self): self.sagemaker_session.wait_for_transform_job(self.job_name) def stop(self): + """Placeholder docstring""" self.sagemaker_session.stop_transform_job(name=self.job_name) @staticmethod diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 965aeaabdf..cb7327bbed 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -351,66 +351,68 @@ def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version) ) -def _create_transformer_and_transform_job( - estimator, - transform_input, - volume_kms_key=None, - input_filter=None, - output_filter=None, - join_source=None, -): - transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key) - transformer.transform( - transform_input, - content_type="text/csv", - input_filter=input_filter, - output_filter=output_filter, - join_source=join_source, - ) - - def test_stop_transform_job(sagemaker_session, mxnet_full_version): - data_path = os.path.join(DATA_DIR, 'mxnet_mnist') - script_path = os.path.join(data_path, 'mnist.py') - tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}] + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + tags = [{"Key": "some-tag", "Value": "value-for-tag"}] - mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, - framework_version=mxnet_full_version) + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version, + ) - train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'), - key_prefix='integ-test-data/mxnet_mnist/train') - test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'), - key_prefix='integ-test-data/mxnet_mnist/test') - job_name = unique_name_from_base('test-mxnet-transform') + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + job_name = unique_name_from_base("test-mxnet-transform") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - mx.fit({'train': train_input, 'test': test_input}, job_name=job_name) + mx.fit({"train": train_input, "test": test_input}, job_name=job_name) - transform_input_path = os.path.join(data_path, 'transform', 'data.csv') - transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform' - transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, - key_prefix=transform_input_key_prefix) + transform_input_path = os.path.join(data_path, "transform", "data.csv") + transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) - transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags) - transformer.transform(transform_input, content_type='text/csv') + transformer = mx.transformer(1, "ml.m4.xlarge", tags=tags) + transformer.transform(transform_input, content_type="text/csv") time.sleep(15) latest_transform_job_name = transformer.latest_transform_job.name - print('Attempting to stop {}'.format(latest_transform_job_name)) + print("Attempting to stop {}".format(latest_transform_job_name)) transformer.stop_transform_job() - desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \ - .describe_transform_job(TransformJobName=latest_transform_job_name) - assert desc['TransformJobStatus'] == 'Stopping' + desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=latest_transform_job_name + ) + assert desc["TransformJobStatus"] == "Stopped" -def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, - input_filter=None, output_filter=None, join_source=None): - transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) - transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, - output_filter=output_filter, join_source=join_source) +def _create_transformer_and_transform_job( + estimator, + transform_input, + volume_kms_key=None, + input_filter=None, + output_filter=None, + join_source=None, +): + transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key) + transformer.transform( + transform_input, + content_type="text/csv", + input_filter=input_filter, + output_filter=output_filter, + join_source=join_source, + ) return transformer diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index e165abb85a..6104f91789 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -452,7 +452,7 @@ def test_restart_output_path(start_new_job, transformer, sagemaker_session): def test_stop_transform_job(sagemaker_session, transformer): - sagemaker_session.stop_transform_job = Mock(name='stop_transform_job') + sagemaker_session.stop_transform_job = Mock(name="stop_transform_job") transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME) transformer.stop_transform_job() @@ -463,4 +463,4 @@ def test_stop_transform_job(sagemaker_session, transformer): def test_stop_transform_job_no_transform_job(transformer): with pytest.raises(ValueError) as e: transformer.stop_transform_job() - assert 'No transform job available' in str(e) + assert "No transform job available" in str(e)