Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e08e5fb

Browse files
trungleducmufiAmazon
authored andcommittedDec 22, 2023
change: update model path in local mode (aws#4296)
* Update model path in local mode * Add test
1 parent 27f9c88 commit e08e5fb

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed
 

‎src/sagemaker/local/image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name):
430430
output_data_config["S3OutputPath"],
431431
job_name,
432432
self.sagemaker_session,
433+
prefix="output",
433434
)
434435

435436
_delete_tree(model_artifacts)

‎src/sagemaker/local/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def copy_directory_structure(destination_directory, relative_path):
5353
os.makedirs(destination_directory, relative_path)
5454

5555

56-
def move_to_destination(source, destination, job_name, sagemaker_session):
56+
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
5757
"""Move source to destination.
5858
5959
Can handle uploading to S3.
@@ -64,6 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
6464
job_name (str): SageMaker job name.
6565
sagemaker_session (sagemaker.Session): a sagemaker_session to interact
6666
with S3 if needed
67+
prefix (str, optional): the directory on S3 used to save files, default
68+
to the root of ``destination``
6769
6870
Returns:
6971
(str): destination URI
@@ -75,7 +77,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session):
7577
final_uri = destination
7678
elif parsed_uri.scheme == "s3":
7779
bucket = parsed_uri.netloc
78-
path = s3.s3_path_join(parsed_uri.path, job_name)
80+
path = s3.s3_path_join(parsed_uri.path, job_name, prefix)
7981
final_uri = s3.s3_path_join("s3://", bucket, path)
8082
sagemaker_session.upload_data(source, bucket, path)
8183
else:

‎tests/unit/sagemaker/local/test_local_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ def test_move_to_destination_s3(recursive_copy):
6666
sms.upload_data.assert_called_with("/tmp/data", "bucket", "job")
6767

6868

69+
@patch("shutil.rmtree", Mock())
70+
def test_move_to_destination_s3_with_prefix():
71+
sms = Mock(
72+
settings=SessionSettings(),
73+
)
74+
uri = sagemaker.local.utils.move_to_destination(
75+
"/tmp/data", "s3://bucket/path", "job", sms, "foo_prefix"
76+
)
77+
sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job/foo_prefix")
78+
assert uri == "s3://bucket/path/job/foo_prefix"
79+
80+
6981
def test_move_to_destination_illegal_destination():
7082
with pytest.raises(ValueError):
7183
sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)

0 commit comments

Comments
 (0)
Please sign in to comment.