Skip to content

Commit edd10aa

Browse files
authored
fix: allow output_path without trailing slash in Local Mode training jobs (#1439)
1 parent 90f8b0f commit edd10aa

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

src/sagemaker/local/utils.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6262
final_uri = destination
6363
elif parsed_uri.scheme == "s3":
6464
bucket = parsed_uri.netloc
65-
path = "%s%s" % (parsed_uri.path.lstrip("/"), job_name)
65+
path = _create_s3_prefix(parsed_uri.path, job_name)
6666
final_uri = "s3://%s/%s" % (bucket, path)
6767
sagemaker_session.upload_data(source, bucket, path)
6868
else:
@@ -72,6 +72,22 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
7272
return final_uri
7373

7474

75+
def _create_s3_prefix(path, job_name):
76+
"""Constructs a path out of the given path and job name to be
77+
used as an S3 prefix.
78+
79+
Args:
80+
path (str): the original path. If the path is only ``"/"``,
81+
then it is ignored.
82+
job_name (str): the job name to be appended to the path.
83+
84+
Returns:
85+
str: an S3 prefix of the form ``"path/job_name"``
86+
"""
87+
path = path.strip("/")
88+
return job_name if path == "" else "/".join((path, job_name))
89+
90+
7591
def recursive_copy(source, destination):
7692
"""A wrapper around distutils.dir_util.copy_tree but won't throw any
7793
exception when the source directory does not exist.

tests/integ/test_chainer_train.py

+2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def _run_mnist_training_job(
137137
train_instance_type=instance_type,
138138
sagemaker_session=sagemaker_session,
139139
hyperparameters={"epochs": 1},
140+
# test output_path without trailing slash
141+
output_path="s3://{}".format(sagemaker_session.default_bucket()),
140142
)
141143

142144
train_input = "file://" + os.path.join(data_path, "train")

tests/unit/test_local_utils.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,33 @@
2020

2121
@patch("shutil.rmtree", Mock())
2222
@patch("sagemaker.local.utils.recursive_copy")
23-
def test_move_to_destination(recursive_copy):
23+
def test_move_to_destination_local(recursive_copy):
2424
# local files will just be recursively copied
2525
sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
2626
recursive_copy.assert_called_with("/tmp/data", "/target/dir/")
2727

28-
# s3 destination will upload to S3
28+
29+
@patch("shutil.rmtree", Mock())
30+
@patch("sagemaker.local.utils.recursive_copy")
31+
def test_move_to_destination_s3(recursive_copy):
2932
sms = Mock()
33+
34+
# without trailing slash in prefix
3035
sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms)
31-
sms.upload_data.assert_called()
36+
sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job")
37+
recursive_copy.assert_not_called()
38+
39+
# with trailing slash in prefix
40+
sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path/", "job", sms)
41+
sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job")
42+
43+
# without path, with trailing slash
44+
sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/", "job", sms)
45+
sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")
46+
47+
# without path, without trailing slash
48+
sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket", "job", sms)
49+
sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")
3250

3351

3452
def test_move_to_destination_illegal_destination():

0 commit comments

Comments
 (0)