Skip to content

Commit b54216d

Browse files
committed
Check for directory key
1 parent ef8dd31 commit b54216d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
460460

461461
# Initialize the variables used to loop through the contents of the S3 bucket.
462462
keys = []
463+
directories = []
463464
next_token = ""
464465
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
465466

@@ -478,20 +479,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
478479
return []
479480
# For each object, save its key or directory.
480481
for s3_object in contents:
481-
key = s3_object.get("Key")
482-
keys.append(key)
482+
key: str = s3_object.get("Key")
483+
obj_size = s3_object.get("Size")
484+
if key.endswith("/") and int(obj_size) == 0:
485+
directories.append(key)
486+
else:
487+
keys.append(key)
483488
next_token = response.get("NextContinuationToken")
484489

485490
# For each object key, create the directory on the local machine if needed, and then
486491
# download the file.
487492
downloaded_paths = []
493+
for dir_path in directories:
494+
os.makedirs(os.path.join(path, dir_path), exist_ok=True)
488495
for key in keys:
489496
tail_s3_uri_path = os.path.basename(key)
490497
if not os.path.splitext(key_prefix)[1]:
491498
tail_s3_uri_path = os.path.relpath(key, key_prefix)
492499
destination_path = os.path.join(path, tail_s3_uri_path)
493500
if not os.path.exists(os.path.dirname(destination_path)):
494-
os.makedirs(os.path.dirname(destination_path))
501+
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
495502
s3.download_file(
496503
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
497504
)

0 commit comments

Comments
 (0)