diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 0c2f36b414..4ea198433a 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -90,6 +90,7 @@ def __init__( self.base_transform_job_name = base_transform_job_name self._current_job_name = None self.latest_transform_job = None + self._reset_output_path = False self.sagemaker_session = sagemaker_session or Session() @@ -146,10 +147,11 @@ def transform( self._current_job_name = name_from_base(base_name) - if self.output_path is None: + if self.output_path is None or self._reset_output_path is True: self.output_path = "s3://{}/{}".format( self.sagemaker_session.default_bucket(), self._current_job_name ) + self._reset_output_path = True self.latest_transform_job = _TransformJob.start_new( self, diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 989b074380..c519cb6786 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -301,6 +301,53 @@ def test_transform_byo_estimator(sagemaker_session): assert tags == model_tags +def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version): + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "mnist.py") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + train_instance_count=1, + train_instance_type="ml.c4.xlarge", + sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + job_name = unique_name_from_base("test-mxnet-transform") + + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + mx.fit({"train": train_input, "test": test_input}, job_name=job_name) + + transform_input_path = os.path.join(data_path, "transform", "data.csv") + transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform" + transform_input = mx.sagemaker_session.upload_data( + path=transform_input_path, key_prefix=transform_input_key_prefix + ) + + transformer = mx.transformer(1, "ml.m4.xlarge") + + job_name = unique_name_from_base("test-mxnet-transform") + transformer.transform(transform_input, content_type="text/csv", job_name=job_name) + with timeout_and_delete_model_with_transformer( + transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES + ): + assert transformer.output_path == "s3://{}/{}".format( + sagemaker_session.default_bucket(), job_name + ) + job_name = unique_name_from_base("test-mxnet-transform") + transformer.transform(transform_input, content_type="text/csv", job_name=job_name) + assert transformer.output_path == "s3://{}/{}".format( + sagemaker_session.default_bucket(), job_name + ) + + def _create_transformer_and_transform_job( estimator, transform_input, diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 88c4495e37..325f6536f1 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -437,3 +437,15 @@ def test_transform_job_wait(sagemaker_session): job.wait() assert sagemaker_session.wait_for_transform_job.called_once + + +@patch("sagemaker.transformer._TransformJob.start_new") +def test_restart_output_path(start_new_job, transformer, sagemaker_session): + transformer.output_path = None + sagemaker_session.default_bucket.return_value = S3_BUCKET + + transformer.transform(DATA, job_name="job-1") + assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-1") + + transformer.transform(DATA, job_name="job-2") + assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2")