@@ -1295,11 +1295,8 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1295
1295
client = self .boto_session .client ("logs" , config = config )
1296
1296
log_group = "/aws/sagemaker/TrainingJobs"
1297
1297
1298
- job_already_completed = (
1299
- True if status == "Completed" or status == "Failed" or status == "Stopped" else False
1300
- )
1298
+ state = _get_initial_job_state (description , 'TrainingJobStatus' , wait )
1301
1299
1302
- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1303
1300
dot = False
1304
1301
1305
1302
color_wrap = sagemaker .logs .ColorWrap ()
@@ -1375,7 +1372,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1375
1372
1376
1373
description = self .sagemaker_client .describe_transform_job (TransformJobName = job_name )
1377
1374
instance_count = description ['TransformResources' ]['InstanceCount' ]
1378
- status = description ['TransformJobStatus' ]
1379
1375
1380
1376
stream_names = [] # The list of log streams
1381
1377
positions = {} # The current position in each stream, map of stream name -> position
@@ -1386,9 +1382,8 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1386
1382
client = self .boto_session .client ('logs' , config = config )
1387
1383
log_group = '/aws/sagemaker/TransformJobs'
1388
1384
1389
- job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1385
+ state = _get_initial_job_state ( description , 'TransformJobStatus' , wait )
1390
1386
1391
- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1392
1387
dot = False
1393
1388
1394
1389
color_wrap = sagemaker .logs .ColorWrap ()
@@ -1436,9 +1431,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1436
1431
self ._check_job_status (job_name , description , 'TransformJobStatus' )
1437
1432
if dot :
1438
1433
print ()
1439
- # Customers are not billed for hardware provisioning, so billable time is less than total time
1440
- billable_time = (description ['TransformEndTime' ] - description ['TransformStartTime' ]) * instance_count
1441
- print ('Billable seconds:' , int (billable_time .total_seconds ()) + 1 )
1442
1434
1443
1435
1444
1436
def container_def (image , model_data_url = None , env = None ):
@@ -1833,6 +1825,12 @@ def _vpc_config_from_training_job(
1833
1825
return vpc_utils .sanitize (vpc_config_override )
1834
1826
1835
1827
1828
+ def _get_initial_job_state (description , status_key , wait ):
1829
+ status = description [status_key ]
1830
+ job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1831
+ return LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1832
+
1833
+
1836
1834
def _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap ):
1837
1835
if len (stream_names ) < instance_count :
1838
1836
# Log streams are created whenever a container starts writing to stdout/err, so this list
0 commit comments