Skip to content

Commit 9ce672e

Browse files
trungleducsagemaker-bot
authored andcommitted
fix: Session.download_data can not download nested objects (aws#4277)
* Check for directory key * Add test * Lint
1 parent 8bb5157 commit 9ce672e

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
472472

473473
# Initialize the variables used to loop through the contents of the S3 bucket.
474474
keys = []
475+
directories = []
475476
next_token = ""
476477
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
477478

@@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
490491
return []
491492
# For each object, save its key or directory.
492493
for s3_object in contents:
493-
key = s3_object.get("Key")
494-
keys.append(key)
494+
key: str = s3_object.get("Key")
495+
obj_size = s3_object.get("Size")
496+
if key.endswith("/") and int(obj_size) == 0:
497+
directories.append(os.path.join(path, key))
498+
else:
499+
keys.append(key)
495500
next_token = response.get("NextContinuationToken")
496501

497502
# For each object key, create the directory on the local machine if needed, and then
498503
# download the file.
499504
downloaded_paths = []
505+
for dir_path in directories:
506+
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
500507
for key in keys:
501508
tail_s3_uri_path = os.path.basename(key)
502509
if not os.path.splitext(key_prefix)[1]:
503510
tail_s3_uri_path = os.path.relpath(key, key_prefix)
504511
destination_path = os.path.join(path, tail_s3_uri_path)
505512
if not os.path.exists(os.path.dirname(destination_path)):
506-
os.makedirs(os.path.dirname(destination_path))
513+
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
507514
s3.download_file(
508515
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
509516
)

tests/unit/test_session.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6113,3 +6113,75 @@ def test_update_inference_component(sagemaker_session):
61136113
)
61146114

61156115
sagemaker_session.sagemaker_client.update_inference_component.assert_called_with(**request)
6116+
6117+
6118+
@patch("os.makedirs")
6119+
def test_download_data_with_only_directory(makedirs, sagemaker_session):
6120+
sagemaker_session.s3_client = Mock()
6121+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6122+
return_value={
6123+
"Contents": [
6124+
{
6125+
"Key": "foo/bar/",
6126+
"Size": 0,
6127+
}
6128+
]
6129+
}
6130+
)
6131+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6132+
6133+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6134+
sagemaker_session.s3_client.download_file.assert_not_called()
6135+
6136+
6137+
@patch("os.makedirs")
6138+
def test_download_data_with_only_file(makedirs, sagemaker_session):
6139+
sagemaker_session.s3_client = Mock()
6140+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6141+
return_value={
6142+
"Contents": [
6143+
{
6144+
"Key": "foo/bar/mode.tar.gz",
6145+
"Size": 100,
6146+
}
6147+
]
6148+
}
6149+
)
6150+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6151+
6152+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6153+
sagemaker_session.s3_client.download_file.assert_called_with(
6154+
Bucket="foo-bucket",
6155+
Key="foo/bar/mode.tar.gz",
6156+
Filename="./foo/bar/mode.tar.gz",
6157+
ExtraArgs=None,
6158+
)
6159+
6160+
6161+
@patch("os.makedirs")
6162+
def test_download_data_with_file_and_directory(makedirs, sagemaker_session):
6163+
sagemaker_session.s3_client = Mock()
6164+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6165+
return_value={
6166+
"Contents": [
6167+
{
6168+
"Key": "foo/bar/",
6169+
"Size": 0,
6170+
},
6171+
{
6172+
"Key": "foo/bar/mode.tar.gz",
6173+
"Size": 100,
6174+
},
6175+
]
6176+
}
6177+
)
6178+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6179+
6180+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6181+
makedirs.assert_has_calls([call("./foo/bar", exist_ok=True), call("./foo/bar", exist_ok=True)])
6182+
sagemaker_session.s3_client.download_file.assert_called_with(
6183+
Bucket="foo-bucket",
6184+
Key="foo/bar/mode.tar.gz",
6185+
Filename="./foo/bar/mode.tar.gz",
6186+
ExtraArgs=None,
6187+
)

0 commit comments

Comments
 (0)