From 3903f19e0574c02208518ab9af147e08b120f471 Mon Sep 17 00:00:00 2001 From: Duc Trung Le Date: Mon, 4 Dec 2023 12:04:39 +0100 Subject: [PATCH 1/2] Update model path in local mode --- src/sagemaker/local/image.py | 1 + src/sagemaker/local/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 98a5a7c629..8f6410c9b4 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -430,6 +430,7 @@ def retrieve_artifacts(self, compose_data, output_data_config, job_name): output_data_config["S3OutputPath"], job_name, self.sagemaker_session, + prefix="output", ) _delete_tree(model_artifacts) diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 298c95acb6..16375de7d4 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -53,7 +53,7 @@ def copy_directory_structure(destination_directory, relative_path): os.makedirs(destination_directory, relative_path) -def move_to_destination(source, destination, job_name, sagemaker_session): +def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""): """Move source to destination. Can handle uploading to S3. @@ -64,6 +64,8 @@ def move_to_destination(source, destination, job_name, sagemaker_session): job_name (str): SageMaker job name. sagemaker_session (sagemaker.Session): a sagemaker_session to interact with S3 if needed + prefix (str, optional): the directory on S3 used to save files, default + to the root of ``destination`` Returns: (str): destination URI @@ -75,7 +77,7 @@ def move_to_destination(source, destination, job_name, sagemaker_session): final_uri = destination elif parsed_uri.scheme == "s3": bucket = parsed_uri.netloc - path = s3.s3_path_join(parsed_uri.path, job_name) + path = s3.s3_path_join(parsed_uri.path, job_name, prefix) final_uri = s3.s3_path_join("s3://", bucket, path) sagemaker_session.upload_data(source, bucket, path) else: From 8ad869cf64b8c2195e64469fdd5d31730cfcd8a0 Mon Sep 17 00:00:00 2001 From: Duc Trung Le Date: Mon, 4 Dec 2023 13:44:56 +0100 Subject: [PATCH 2/2] Add test --- tests/unit/sagemaker/local/test_local_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 2db8c83351..39b9e2b392 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -66,6 +66,18 @@ def test_move_to_destination_s3(recursive_copy): sms.upload_data.assert_called_with("/tmp/data", "bucket", "job") +@patch("shutil.rmtree", Mock()) +def test_move_to_destination_s3_with_prefix(): + sms = Mock( + settings=SessionSettings(), + ) + uri = sagemaker.local.utils.move_to_destination( + "/tmp/data", "s3://bucket/path", "job", sms, "foo_prefix" + ) + sms.upload_data.assert_called_with("/tmp/data", "bucket", "path/job/foo_prefix") + assert uri == "s3://bucket/path/job/foo_prefix" + + def test_move_to_destination_illegal_destination(): with pytest.raises(ValueError): sagemaker.local.utils.move_to_destination("/tmp/data", "ftp://ftp/in/2018", "job", None)