diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index c8c37f87ff..cc1a14e6d6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # Initialize the variables used to loop through the contents of the S3 bucket. keys = [] + directories = [] next_token = "" base_parameters = {"Bucket": bucket, "Prefix": key_prefix} @@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): return [] # For each object, save its key or directory. for s3_object in contents: - key = s3_object.get("Key") - keys.append(key) + key: str = s3_object.get("Key") + obj_size = s3_object.get("Size") + if key.endswith("/") and int(obj_size) == 0: + directories.append(os.path.join(path, key)) + else: + keys.append(key) next_token = response.get("NextContinuationToken") # For each object key, create the directory on the local machine if needed, and then # download the file. downloaded_paths = [] + for dir_path in directories: + os.makedirs(os.path.dirname(dir_path), exist_ok=True) for key in keys: tail_s3_uri_path = os.path.basename(key) if not os.path.splitext(key_prefix)[1]: tail_s3_uri_path = os.path.relpath(key, key_prefix) destination_path = os.path.join(path, tail_s3_uri_path) if not os.path.exists(os.path.dirname(destination_path)): - os.makedirs(os.path.dirname(destination_path)) + os.makedirs(os.path.dirname(destination_path), exist_ok=True) s3.download_file( Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d6079a098e..f36f933f3e 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -6113,3 +6113,75 @@ def test_update_inference_component(sagemaker_session): ) sagemaker_session.sagemaker_client.update_inference_component.assert_called_with(**request) + + +@patch("os.makedirs") +def test_download_data_with_only_directory(makedirs, sagemaker_session): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + { + "Key": "foo/bar/", + "Size": 0, + } + ] + } + ) + sagemaker_session.download_data(path=".", bucket="foo-bucket") + + makedirs.assert_called_with("./foo/bar", exist_ok=True) + sagemaker_session.s3_client.download_file.assert_not_called() + + +@patch("os.makedirs") +def test_download_data_with_only_file(makedirs, sagemaker_session): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + { + "Key": "foo/bar/mode.tar.gz", + "Size": 100, + } + ] + } + ) + sagemaker_session.download_data(path=".", bucket="foo-bucket") + + makedirs.assert_called_with("./foo/bar", exist_ok=True) + sagemaker_session.s3_client.download_file.assert_called_with( + Bucket="foo-bucket", + Key="foo/bar/mode.tar.gz", + Filename="./foo/bar/mode.tar.gz", + ExtraArgs=None, + ) + + +@patch("os.makedirs") +def test_download_data_with_file_and_directory(makedirs, sagemaker_session): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + { + "Key": "foo/bar/", + "Size": 0, + }, + { + "Key": "foo/bar/mode.tar.gz", + "Size": 100, + }, + ] + } + ) + sagemaker_session.download_data(path=".", bucket="foo-bucket") + + makedirs.assert_called_with("./foo/bar", exist_ok=True) + makedirs.assert_has_calls([call("./foo/bar", exist_ok=True), call("./foo/bar", exist_ok=True)]) + sagemaker_session.s3_client.download_file.assert_called_with( + Bucket="foo-bucket", + Key="foo/bar/mode.tar.gz", + Filename="./foo/bar/mode.tar.gz", + ExtraArgs=None, + )