Skip to content

Commit a6585e2

Browse files
imujjwal96chuyang-deng
authored andcommitted
fix: reset default output path in Transformer.transform (#905)
* fix: reset default output path on create transform job * Unit and integration tests
1 parent 5d8f2b7 commit a6585e2

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

src/sagemaker/transformer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090
self.base_transform_job_name = base_transform_job_name
9191
self._current_job_name = None
9292
self.latest_transform_job = None
93+
self._reset_output_path = False
9394

9495
self.sagemaker_session = sagemaker_session or Session()
9596

@@ -146,10 +147,11 @@ def transform(
146147

147148
self._current_job_name = name_from_base(base_name)
148149

149-
if self.output_path is None:
150+
if self.output_path is None or self._reset_output_path is True:
150151
self.output_path = "s3://{}/{}".format(
151152
self.sagemaker_session.default_bucket(), self._current_job_name
152153
)
154+
self._reset_output_path = True
153155

154156
self.latest_transform_job = _TransformJob.start_new(
155157
self,

tests/integ/test_transformer.py

+47
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,53 @@ def test_transform_byo_estimator(sagemaker_session):
301301
assert tags == model_tags
302302

303303

304+
def test_single_transformer_multiple_jobs(sagemaker_session, mxnet_full_version):
305+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
306+
script_path = os.path.join(data_path, "mnist.py")
307+
308+
mx = MXNet(
309+
entry_point=script_path,
310+
role="SageMakerRole",
311+
train_instance_count=1,
312+
train_instance_type="ml.c4.xlarge",
313+
sagemaker_session=sagemaker_session,
314+
framework_version=mxnet_full_version,
315+
)
316+
317+
train_input = mx.sagemaker_session.upload_data(
318+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
319+
)
320+
test_input = mx.sagemaker_session.upload_data(
321+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
322+
)
323+
job_name = unique_name_from_base("test-mxnet-transform")
324+
325+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
326+
mx.fit({"train": train_input, "test": test_input}, job_name=job_name)
327+
328+
transform_input_path = os.path.join(data_path, "transform", "data.csv")
329+
transform_input_key_prefix = "integ-test-data/mxnet_mnist/transform"
330+
transform_input = mx.sagemaker_session.upload_data(
331+
path=transform_input_path, key_prefix=transform_input_key_prefix
332+
)
333+
334+
transformer = mx.transformer(1, "ml.m4.xlarge")
335+
336+
job_name = unique_name_from_base("test-mxnet-transform")
337+
transformer.transform(transform_input, content_type="text/csv", job_name=job_name)
338+
with timeout_and_delete_model_with_transformer(
339+
transformer, sagemaker_session, minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES
340+
):
341+
assert transformer.output_path == "s3://{}/{}".format(
342+
sagemaker_session.default_bucket(), job_name
343+
)
344+
job_name = unique_name_from_base("test-mxnet-transform")
345+
transformer.transform(transform_input, content_type="text/csv", job_name=job_name)
346+
assert transformer.output_path == "s3://{}/{}".format(
347+
sagemaker_session.default_bucket(), job_name
348+
)
349+
350+
304351
def _create_transformer_and_transform_job(
305352
estimator,
306353
transform_input,

tests/unit/test_transformer.py

+12
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,15 @@ def test_transform_job_wait(sagemaker_session):
437437
job.wait()
438438

439439
assert sagemaker_session.wait_for_transform_job.called_once
440+
441+
442+
@patch("sagemaker.transformer._TransformJob.start_new")
443+
def test_restart_output_path(start_new_job, transformer, sagemaker_session):
444+
transformer.output_path = None
445+
sagemaker_session.default_bucket.return_value = S3_BUCKET
446+
447+
transformer.transform(DATA, job_name="job-1")
448+
assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-1")
449+
450+
transformer.transform(DATA, job_name="job-2")
451+
assert transformer.output_path == "s3://{}/{}".format(S3_BUCKET, "job-2")

0 commit comments

Comments
 (0)