Skip to content

Commit 7e00970

Browse files
committed
fix: allow output_path without trailing slash in Local Mode training jobs
1 parent 06a00d4 commit 7e00970

File tree

3 files changed, with 29 additions and 4 deletions.

src/sagemaker/local/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6262
final_uri = destination
6363
elif parsed_uri.scheme == "s3":
6464
bucket = parsed_uri.netloc
65-
path = "%s%s" % (parsed_uri.path.lstrip("/"), job_name)
65+
path = _create_s3_path(parsed_uri.path, job_name)
6666
final_uri = "s3://%s/%s" % (bucket, path)
6767
sagemaker_session.upload_data(source, bucket, path)
6868
else:
@@ -72,6 +72,11 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
7272
return final_uri
7373

7474

75+
def _create_s3_path(prefix, job_name):
76+
prefix = prefix.strip("/")
77+
return job_name if prefix == "" else "/".join((prefix, job_name))
78+
79+
7580
def recursive_copy(source, destination):
7681
"""A wrapper around distutils.dir_util.copy_tree but won't throw any
7782
exception when the source directory does not exist.

tests/integ/test_chainer_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def _run_mnist_training_job(
137137
train_instance_type=instance_type,
138138
sagemaker_session=sagemaker_session,
139139
hyperparameters={"epochs": 1},
140+
# test output_path without trailing slash
141+
output_path="s3://{}".format(sagemaker_session.default_bucket()),
140142
)
141143

142144
train_input = "file://" + os.path.join(data_path, "train")

tests/unit/test_local_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,33 @@
2020

2121
@patch("shutil.rmtree", Mock())
@patch("sagemaker.local.utils.recursive_copy")
def test_move_to_destination_local(recursive_copy):
    """A ``file://`` destination is handled by recursively copying to the local path."""
    # local files will just be recursively copied
    sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
    recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
2727

28-
# s3 destination will upload to S3
28+
29+
@patch("shutil.rmtree", Mock())
@patch("sagemaker.local.utils.recursive_copy")
def test_move_to_destination_s3(recursive_copy):
    """An ``s3://`` destination uploads to ``<prefix>/<job name>`` for every prefix shape."""
    cases = (
        ("s3://bucket/path", "path/job"),  # without trailing slash in prefix
        ("s3://bucket/path/", "path/job"),  # with trailing slash in prefix
        ("s3://bucket/", "job"),  # without path, with trailing slash
        ("s3://bucket", "job"),  # without path, without trailing slash
    )
    for destination, expected_key in cases:
        # Use a fresh mock per case: assert_called_with only inspects the last
        # recorded call, so a shared mock would let a case that never uploads
        # pass whenever the previous case expected the same key.
        sms = Mock()
        sagemaker.local.utils.move_to_destination("/tmp/data", destination, "job", sms)
        sms.upload_data.assert_called_with("/tmp/data", "bucket", expected_key)

    # None of the S3 cases may fall back to a local copy.
    recursive_copy.assert_not_called()
3250

3351

3452
def test_move_to_destination_illegal_destination():

0 commit comments

Comments
 (0)