diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index a9729ddebb..e7d701ae8e 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -62,7 +62,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session): final_uri = destination elif parsed_uri.scheme == "s3": bucket = parsed_uri.netloc - path = "%s%s" % (parsed_uri.path.lstrip("/"), job_name) + path = _create_s3_prefix(parsed_uri.path, job_name) final_uri = "s3://%s/%s" % (bucket, path) sagemaker_session.upload_data(source, bucket, path) else: @@ -72,6 +72,22 @@ def move_to_destination(source, destination, job_name, sagemaker_session): return final_uri +def _create_s3_prefix(path, job_name): + """Constructs a path out of the given path and job name to be + used as an S3 prefix. + + Args: + path (str): the original path. If the path is only ``"/"``, + then it is ignored. + job_name (str): the job name to be appended to the path. + + Returns: + str: an S3 prefix of the form ``"path/job_name"`` + """ + path = path.strip("/") + return job_name if path == "" else "/".join((path, job_name)) + + def recursive_copy(source, destination): """A wrapper around distutils.dir_util.copy_tree but won't throw any exception when the source directory does not exist. diff --git a/tests/integ/test_chainer_train.py b/tests/integ/test_chainer_train.py index 56673de77b..add2efce1a 100644 --- a/tests/integ/test_chainer_train.py +++ b/tests/integ/test_chainer_train.py @@ -137,6 +137,8 @@ def _run_mnist_training_job( train_instance_type=instance_type, sagemaker_session=sagemaker_session, hyperparameters={"epochs": 1}, + # test output_path without trailing slash + output_path="s3://{}".format(sagemaker_session.default_bucket()), ) train_input = "file://" + os.path.join(data_path, "train") diff --git a/tests/unit/test_local_utils.py b/tests/unit/test_local_utils.py index 349a920237..6b2a0b0266 100644 --- a/tests/unit/test_local_utils.py +++ b/tests/unit/test_local_utils.py @@ -20,15 +20,33 @@ @patch("shutil.rmtree", Mock()) @patch("sagemaker.local.utils.recursive_copy") -def test_move_to_destination(recursive_copy): +def test_move_to_destination_local(recursive_copy): # local files will just be recursively copied sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None) recursive_copy.assert_called_with("/tmp/data", "/target/dir/") - # s3 destination will upload to S3 + +@patch("shutil.rmtree", Mock()) +@patch("sagemaker.local.utils.recursive_copy") +def test_move_to_destination_s3(recursive_copy): sms = Mock() + + # without trailing slash in prefix sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms) - sms.upload_data.assert_called() + sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job") + recursive_copy.assert_not_called() + + # with trailing slash in prefix + sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path/", "job", sms) + sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job") + + # without path, with trailing slash + sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/", "job", sms) + sms.upload_data.assert_called_with("/tmp/data", "bucket", "job") + + # without path, without trailing slash + sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket", "job", sms) + sms.upload_data.assert_called_with("/tmp/data", "bucket", "job") def test_move_to_destination_illegal_destination():