@@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
472
472
473
473
# Initialize the variables used to loop through the contents of the S3 bucket.
474
474
keys = []
475
+ directories = []
475
476
next_token = ""
476
477
base_parameters = {"Bucket" : bucket , "Prefix" : key_prefix }
477
478
@@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
490
491
return []
491
492
# For each object, save its key or directory.
492
493
for s3_object in contents :
493
- key = s3_object .get ("Key" )
494
- keys .append (key )
494
+ key : str = s3_object .get ("Key" )
495
+ obj_size = s3_object .get ("Size" )
496
+ if key .endswith ("/" ) and int (obj_size ) == 0 :
497
+ directories .append (os .path .join (path , key ))
498
+ else :
499
+ keys .append (key )
495
500
next_token = response .get ("NextContinuationToken" )
496
501
497
502
# For each object key, create the directory on the local machine if needed, and then
498
503
# download the file.
499
504
downloaded_paths = []
505
+ for dir_path in directories :
506
+ os .makedirs (os .path .dirname (dir_path ), exist_ok = True )
500
507
for key in keys :
501
508
tail_s3_uri_path = os .path .basename (key )
502
509
if not os .path .splitext (key_prefix )[1 ]:
503
510
tail_s3_uri_path = os .path .relpath (key , key_prefix )
504
511
destination_path = os .path .join (path , tail_s3_uri_path )
505
512
if not os .path .exists (os .path .dirname (destination_path )):
506
- os .makedirs (os .path .dirname (destination_path ))
513
+ os .makedirs (os .path .dirname (destination_path ), exist_ok = True )
507
514
s3 .download_file (
508
515
Bucket = bucket , Key = key , Filename = destination_path , ExtraArgs = extra_args
509
516
)
@@ -5447,7 +5454,7 @@ def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=No
5447
5454
exceptions.CapacityError: If the training job fails with CapacityError.
5448
5455
exceptions.UnexpectedStatusException: If waiting and the training job fails.
5449
5456
"""
5450
- _logs_for_job (self . boto_session , job_name , wait , poll , log_type , timeout )
5457
+ _logs_for_job (self , job_name , wait , poll , log_type , timeout )
5451
5458
5452
5459
def logs_for_processing_job (self , job_name , wait = False , poll = 10 ):
5453
5460
"""Display logs for a given processing job, optionally tailing them until the job is complete.
@@ -7330,17 +7337,16 @@ def _rule_statuses_changed(current_statuses, last_statuses):
7330
7337
7331
7338
7332
7339
def _logs_for_job ( # noqa: C901 - suppress complexity warning for this method
7333
- boto_session , job_name , wait = False , poll = 10 , log_type = "All" , timeout = None
7340
+ sagemaker_session , job_name , wait = False , poll = 10 , log_type = "All" , timeout = None
7334
7341
):
7335
7342
"""Display logs for a given training job, optionally tailing them until job is complete.
7336
7343
7337
7344
If the output is a tty or a Jupyter cell, it will be color-coded
7338
7345
based on which instance the log entry is from.
7339
7346
7340
7347
Args:
7341
- boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
7342
- calls are delegated to (default: None). If not provided, one is created with
7343
- default AWS configuration chain.
7348
+ sagemaker_session (sagemaker.session.Session): A SageMaker Session
7349
+ object, used for SageMaker interactions.
7344
7350
job_name (str): Name of the training job to display the logs for.
7345
7351
wait (bool): Whether to keep looking for new log entries until the job completes
7346
7352
(default: False).
@@ -7357,13 +7363,13 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
7357
7363
exceptions.CapacityError: If the training job fails with CapacityError.
7358
7364
exceptions.UnexpectedStatusException: If waiting and the training job fails.
7359
7365
"""
7360
- sagemaker_client = boto_session . client ( "sagemaker" )
7366
+ sagemaker_client = sagemaker_session . sagemaker_client
7361
7367
request_end_time = time .time () + timeout if timeout else None
7362
7368
description = sagemaker_client .describe_training_job (TrainingJobName = job_name )
7363
7369
print (secondary_training_status_message (description , None ), end = "" )
7364
7370
7365
7371
instance_count , stream_names , positions , client , log_group , dot , color_wrap = _logs_init (
7366
- boto_session , description , job = "Training"
7372
+ sagemaker_session . boto_session , description , job = "Training"
7367
7373
)
7368
7374
7369
7375
state = _get_initial_job_state (description , "TrainingJobStatus" , wait )
0 commit comments