Skip to content

Commit 1adcb76

Browse files
committed
Add test
1 parent b54216d commit 1adcb76

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

src/sagemaker/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
482482
key: str = s3_object.get("Key")
483483
obj_size = s3_object.get("Size")
484484
if key.endswith("/") and int(obj_size) == 0:
485-
directories.append(key)
485+
directories.append(os.path.join(path, key))
486486
else:
487487
keys.append(key)
488488
next_token = response.get("NextContinuationToken")
@@ -491,7 +491,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
491491
# download the file.
492492
downloaded_paths = []
493493
for dir_path in directories:
494-
os.makedirs(os.path.join(path, dir_path), exist_ok=True)
494+
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
495495
for key in keys:
496496
tail_s3_uri_path = os.path.basename(key)
497497
if not os.path.splitext(key_prefix)[1]:

tests/unit/test_session.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5984,3 +5984,75 @@ def test_upload_data_default_bucket_and_prefix_combinations(
59845984
expected__with_user_input__with_default_bucket_only=expected__with_user_input__with_default_bucket_only,
59855985
)
59865986
assert actual == expected
5987+
5988+
5989+
@patch("os.makedirs")
5990+
def test_download_data_with_only_directory(makedirs, sagemaker_session):
5991+
sagemaker_session.s3_client = Mock()
5992+
sagemaker_session.s3_client.list_objects_v2 = Mock(
5993+
return_value={
5994+
"Contents": [
5995+
{
5996+
"Key": "foo/bar/",
5997+
"Size": 0,
5998+
}
5999+
]
6000+
}
6001+
)
6002+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6003+
6004+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6005+
sagemaker_session.s3_client.download_file.assert_not_called()
6006+
6007+
6008+
@patch("os.makedirs")
6009+
def test_download_data_with_only_file(makedirs, sagemaker_session):
6010+
sagemaker_session.s3_client = Mock()
6011+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6012+
return_value={
6013+
"Contents": [
6014+
{
6015+
"Key": "foo/bar/mode.tar.gz",
6016+
"Size": 100,
6017+
}
6018+
]
6019+
}
6020+
)
6021+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6022+
6023+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6024+
sagemaker_session.s3_client.download_file.assert_called_with(
6025+
Bucket="foo-bucket",
6026+
Key="foo/bar/mode.tar.gz",
6027+
Filename="./foo/bar/mode.tar.gz",
6028+
ExtraArgs=None,
6029+
)
6030+
6031+
6032+
@patch("os.makedirs")
6033+
def test_download_data_with_file_and_directory(makedirs, sagemaker_session):
6034+
sagemaker_session.s3_client = Mock()
6035+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6036+
return_value={
6037+
"Contents": [
6038+
{
6039+
"Key": "foo/bar/",
6040+
"Size": 0,
6041+
},
6042+
{
6043+
"Key": "foo/bar/mode.tar.gz",
6044+
"Size": 100,
6045+
},
6046+
]
6047+
}
6048+
)
6049+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6050+
6051+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6052+
makedirs.assert_has_calls([call("./foo/bar", exist_ok=True), call("./foo/bar", exist_ok=True)])
6053+
sagemaker_session.s3_client.download_file.assert_called_with(
6054+
Bucket="foo-bucket",
6055+
Key="foo/bar/mode.tar.gz",
6056+
Filename="./foo/bar/mode.tar.gz",
6057+
ExtraArgs=None,
6058+
)

0 commit comments

Comments
 (0)