Skip to content

Commit fb69406

Browse files
trungleducakrishna1995
authored andcommitted
Check for directory key
1 parent 8c2012b commit fb69406

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
472472

473473
# Initialize the variables used to loop through the contents of the S3 bucket.
474474
keys = []
475+
directories = []
475476
next_token = ""
476477
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
477478

@@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
490491
return []
491492
# For each object, save its key or directory.
492493
for s3_object in contents:
493-
key = s3_object.get("Key")
494-
keys.append(key)
494+
key: str = s3_object.get("Key")
495+
obj_size = s3_object.get("Size")
496+
if key.endswith("/") and int(obj_size) == 0:
497+
directories.append(key)
498+
else:
499+
keys.append(key)
495500
next_token = response.get("NextContinuationToken")
496501

497502
# For each object key, create the directory on the local machine if needed, and then
498503
# download the file.
499504
downloaded_paths = []
505+
for dir_path in directories:
506+
os.makedirs(os.path.join(path, dir_path), exist_ok=True)
500507
for key in keys:
501508
tail_s3_uri_path = os.path.basename(key)
502509
if not os.path.splitext(key_prefix)[1]:
503510
tail_s3_uri_path = os.path.relpath(key, key_prefix)
504511
destination_path = os.path.join(path, tail_s3_uri_path)
505512
if not os.path.exists(os.path.dirname(destination_path)):
506-
os.makedirs(os.path.dirname(destination_path))
513+
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
507514
s3.download_file(
508515
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
509516
)

0 commit comments

Comments
 (0)