fix: allow output_path without trailing slash in Local Mode training jobs #1439

Merged
merged 3 commits into from
Apr 24, 2020
18 changes: 17 additions & 1 deletion src/sagemaker/local/utils.py
@@ -62,7 +62,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
         final_uri = destination
     elif parsed_uri.scheme == "s3":
         bucket = parsed_uri.netloc
-        path = "%s%s" % (parsed_uri.path.lstrip("/"), job_name)
+        path = _create_s3_prefix(parsed_uri.path, job_name)
         final_uri = "s3://%s/%s" % (bucket, path)
         sagemaker_session.upload_data(source, bucket, path)
     else:
@@ -72,6 +72,22 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
     return final_uri


+def _create_s3_prefix(path, job_name):
+    """Constructs a path out of the given path and job name to be
+    used as an S3 prefix.
+
+    Args:
+        path (str): the original path. If the path is only ``"/"``,
+            then it is ignored.
+        job_name (str): the job name to be appended to the path.
+
+    Returns:
+        str: an S3 prefix of the form ``"path/job_name"``
+    """
+    path = path.strip("/")
+    return job_name if path == "" else "/".join((path, job_name))
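
To make the effect of this change concrete, the sketch below compares the pre-change expression with the new helper for the four destination shapes covered by the unit tests. It is illustrative only: `old_prefix` and `create_s3_prefix` are local copies written for this example (those names are not part of the SDK), and the `urllib.parse` import assumes Python 3.

```python
from urllib.parse import urlparse  # assumption: Python 3 import path


def old_prefix(path, job_name):
    # Copy of the expression that this PR removes from move_to_destination.
    return "%s%s" % (path.lstrip("/"), job_name)


def create_s3_prefix(path, job_name):
    # Copy of the logic in _create_s3_prefix, renamed for this example.
    path = path.strip("/")
    return job_name if path == "" else "/".join((path, job_name))


destinations = ("s3://bucket/path", "s3://bucket/path/", "s3://bucket/", "s3://bucket")
for destination in destinations:
    parsed = urlparse(destination)
    before = old_prefix(parsed.path, "job")
    after = create_s3_prefix(parsed.path, "job")
    print("%-20s old=%-10s new=%s" % (destination, before, after))
```

Only the first case differs: without a trailing slash, the old expression fuses the prefix and the job name into `pathjob`, while the helper yields `path/job`.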

Just curious: instead of doing string manipulation, wouldn't os.path.join or pathlib.Path be a better option for building paths consistently in both local and non-local modes?

Contributor Author

@laurenyu, Apr 24, 2020

It would, except that S3 always uses "/" as a separator, so os.path.join would produce the wrong prefix for Windows users, where it joins with "\".

As far as I can tell, there isn't an easy way to override the separator used by os.path.join (I wrote a little about this in #1435). The other option would be to use something like urlunparse, but I think that would be a little heavy-handed for this use case.

Edit: looking into pathlib.
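
The separator concern can be reproduced on any operating system, because the standard library exposes the Windows flavour of `os.path` as `ntpath` and the POSIX flavour as `posixpath`. The snippet below is a small illustration written for this discussion, not code from the PR:

```python
import ntpath      # what os.path resolves to on Windows
import posixpath   # what os.path resolves to on POSIX systems

# os.path.join delegates to one of these depending on the platform.
windows_result = ntpath.join("path", "job")
posix_result = posixpath.join("path", "job")

# The PR's approach: join with an explicit "/" regardless of platform.
explicit_result = "/".join(("path", "job"))

print(repr(windows_result))
print(repr(posix_result))
print(repr(explicit_result))
```

On Windows, `os.path.join` would therefore build a key containing a backslash, which S3 treats as an ordinary character rather than a separator.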

Contributor Author

PurePosixPath does look like it would solve the problem, except that it was introduced in Python 3.4 and never backported to Python 2.7. Since this SDK currently still supports Python 2.7, we can't use it here either. However, I'll make a note in our internal backlog to use PurePosixPath once we drop Python 2.7 support. Thanks for sharing that! We've had quite a few conversations on our team about finding a module that would work regardless of the OS, but hadn't yet found pathlib.
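
For readers wondering what the deferred pathlib approach might look like, here is a minimal sketch under the assumption of Python 3.4 or later. The function name `create_s3_prefix_pathlib` is hypothetical and does not appear in the SDK; the PR itself keeps the string-based helper.

```python
from pathlib import PurePosixPath  # available from Python 3.4


def create_s3_prefix_pathlib(path, job_name):
    """Hypothetical pathlib-based variant of _create_s3_prefix (not in this PR)."""
    # Strip surrounding slashes so the result is a relative S3 prefix;
    # PurePosixPath always joins with "/" regardless of the host OS.
    return str(PurePosixPath(path.strip("/"), job_name))


for raw_path in ("/path", "/path/", "/", ""):
    print(repr(raw_path), "->", create_s3_prefix_pathlib(raw_path, "job"))
```

The explicit `strip("/")` is still needed in this sketch, since joining onto a path that starts with "/" would yield an absolute path with a leading slash.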



 def recursive_copy(source, destination):
     """A wrapper around distutils.dir_util.copy_tree but won't throw any
     exception when the source directory does not exist.
2 changes: 2 additions & 0 deletions tests/integ/test_chainer_train.py
@@ -137,6 +137,8 @@ def _run_mnist_training_job(
             train_instance_type=instance_type,
             sagemaker_session=sagemaker_session,
             hyperparameters={"epochs": 1},
+            # test output_path without trailing slash
+            output_path="s3://{}".format(sagemaker_session.default_bucket()),
         )

         train_input = "file://" + os.path.join(data_path, "train")
24 changes: 21 additions & 3 deletions tests/unit/test_local_utils.py
@@ -20,15 +20,33 @@

 @patch("shutil.rmtree", Mock())
 @patch("sagemaker.local.utils.recursive_copy")
-def test_move_to_destination(recursive_copy):
+def test_move_to_destination_local(recursive_copy):
     # local files will just be recursively copied
     sagemaker.local.utils.move_to_destination("/tmp/data", "file:///target/dir/", "job", None)
     recursive_copy.assert_called_with("/tmp/data", "/target/dir/")

-    # s3 destination will upload to S3
+
+@patch("shutil.rmtree", Mock())
+@patch("sagemaker.local.utils.recursive_copy")
+def test_move_to_destination_s3(recursive_copy):
     sms = Mock()
+
+    # without trailing slash in prefix
     sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path", "job", sms)
-    sms.upload_data.assert_called()
+    sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job")
+    recursive_copy.assert_not_called()
+
+    # with trailing slash in prefix
+    sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/path/", "job", sms)
+    sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job")
+
+    # without path, with trailing slash
+    sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket/", "job", sms)
+    sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")
+
+    # without path, without trailing slash
+    sagemaker.local.utils.move_to_destination("/tmp/data", "s3://bucket", "job", sms)
+    sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")


 def test_move_to_destination_illegal_destination():