Skip to content

Commit 7680e26

Browse files
committed
Add test
1 parent 8f70931 commit 7680e26

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

src/sagemaker/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
494494
key: str = s3_object.get("Key")
495495
obj_size = s3_object.get("Size")
496496
if key.endswith("/") and int(obj_size) == 0:
497-
directories.append(key)
497+
directories.append(os.path.join(path, key))
498498
else:
499499
keys.append(key)
500500
next_token = response.get("NextContinuationToken")
@@ -503,7 +503,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
503503
# download the file.
504504
downloaded_paths = []
505505
for dir_path in directories:
506-
os.makedirs(os.path.join(path, dir_path), exist_ok=True)
506+
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
507507
for key in keys:
508508
tail_s3_uri_path = os.path.basename(key)
509509
if not os.path.splitext(key_prefix)[1]:

tests/unit/test_session.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6113,3 +6113,73 @@ def test_update_inference_component(sagemaker_session):
61136113
)
61146114

61156115
sagemaker_session.sagemaker_client.update_inference_component.assert_called_with(**request)
6116+
@patch("os.makedirs")
6117+
def test_download_data_with_only_directory(makedirs, sagemaker_session):
6118+
sagemaker_session.s3_client = Mock()
6119+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6120+
return_value={
6121+
"Contents": [
6122+
{
6123+
"Key": "foo/bar/",
6124+
"Size": 0,
6125+
}
6126+
]
6127+
}
6128+
)
6129+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6130+
6131+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6132+
sagemaker_session.s3_client.download_file.assert_not_called()
6133+
6134+
6135+
@patch("os.makedirs")
6136+
def test_download_data_with_only_file(makedirs, sagemaker_session):
6137+
sagemaker_session.s3_client = Mock()
6138+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6139+
return_value={
6140+
"Contents": [
6141+
{
6142+
"Key": "foo/bar/mode.tar.gz",
6143+
"Size": 100,
6144+
}
6145+
]
6146+
}
6147+
)
6148+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6149+
6150+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6151+
sagemaker_session.s3_client.download_file.assert_called_with(
6152+
Bucket="foo-bucket",
6153+
Key="foo/bar/mode.tar.gz",
6154+
Filename="./foo/bar/mode.tar.gz",
6155+
ExtraArgs=None,
6156+
)
6157+
6158+
6159+
@patch("os.makedirs")
6160+
def test_download_data_with_file_and_directory(makedirs, sagemaker_session):
6161+
sagemaker_session.s3_client = Mock()
6162+
sagemaker_session.s3_client.list_objects_v2 = Mock(
6163+
return_value={
6164+
"Contents": [
6165+
{
6166+
"Key": "foo/bar/",
6167+
"Size": 0,
6168+
},
6169+
{
6170+
"Key": "foo/bar/mode.tar.gz",
6171+
"Size": 100,
6172+
},
6173+
]
6174+
}
6175+
)
6176+
sagemaker_session.download_data(path=".", bucket="foo-bucket")
6177+
6178+
makedirs.assert_called_with("./foo/bar", exist_ok=True)
6179+
makedirs.assert_has_calls([call("./foo/bar", exist_ok=True), call("./foo/bar", exist_ok=True)])
6180+
sagemaker_session.s3_client.download_file.assert_called_with(
6181+
Bucket="foo-bucket",
6182+
Key="foo/bar/mode.tar.gz",
6183+
Filename="./foo/bar/mode.tar.gz",
6184+
ExtraArgs=None,
6185+
)

0 commit comments

Comments
 (0)