@@ -1124,7 +1124,6 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
1124
1124
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
1125
1125
print (secondary_training_status_message (description , None ), end = '' )
1126
1126
instance_count = description ['ResourceConfig' ]['InstanceCount' ]
1127
- status = description ['TrainingJobStatus' ]
1128
1127
1129
1128
stream_names = [] # The list of log streams
1130
1129
positions = {} # The current position in each stream, map of stream name -> position
@@ -1135,9 +1134,8 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
1135
1134
client = self .boto_session .client ('logs' , config = config )
1136
1135
log_group = '/aws/sagemaker/TrainingJobs'
1137
1136
1138
- job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1137
+ state = _get_initial_job_state ( description , 'TrainingJobStatus' , wait )
1139
1138
1140
- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1141
1139
dot = False
1142
1140
1143
1141
color_wrap = sagemaker .logs .ColorWrap ()
@@ -1211,7 +1209,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1211
1209
1212
1210
description = self .sagemaker_client .describe_transform_job (TransformJobName = job_name )
1213
1211
instance_count = description ['TransformResources' ]['InstanceCount' ]
1214
- status = description ['TransformJobStatus' ]
1215
1212
1216
1213
stream_names = [] # The list of log streams
1217
1214
positions = {} # The current position in each stream, map of stream name -> position
@@ -1222,9 +1219,8 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1222
1219
client = self .boto_session .client ('logs' , config = config )
1223
1220
log_group = '/aws/sagemaker/TransformJobs'
1224
1221
1225
- job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1222
+ state = _get_initial_job_state ( description , 'TransformJobStatus' , wait )
1226
1223
1227
- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1228
1224
dot = False
1229
1225
1230
1226
color_wrap = sagemaker .logs .ColorWrap ()
@@ -1272,9 +1268,6 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
1272
1268
self ._check_job_status (job_name , description , 'TransformJobStatus' )
1273
1269
if dot :
1274
1270
print ()
1275
- # Customers are not billed for hardware provisioning, so billable time is less than total time
1276
- billable_time = (description ['TransformEndTime' ] - description ['TransformStartTime' ]) * instance_count
1277
- print ('Billable seconds:' , int (billable_time .total_seconds ()) + 1 )
1278
1271
1279
1272
1280
1273
def container_def (image , model_data_url = None , env = None ):
@@ -1644,6 +1637,12 @@ def _vpc_config_from_training_job(training_job_desc, vpc_config_override=vpc_uti
1644
1637
return vpc_utils .sanitize (vpc_config_override )
1645
1638
1646
1639
1640
def _get_initial_job_state(description, status_key, wait):
    """Pick the log-streaming state to start in for an already-described job.

    Args:
        description: Job description mapping; it is indexed with ``status_key``.
        status_key (str): Key under which the job status is stored, e.g.
            'TrainingJobStatus' or 'TransformJobStatus'.
        wait (bool): Whether the caller wants to keep following the job.

    Returns:
        ``LogState.TAILING`` when ``wait`` is truthy and the status is not one
        of 'Completed', 'Failed' or 'Stopped'; ``LogState.COMPLETE`` otherwise.

    Raises:
        KeyError: If ``description`` is a dict that lacks ``status_key``.
    """
    status = description[status_key]
    # A job in one of these statuses has already finished, so there is nothing to tail.
    job_already_completed = status in ('Completed', 'Failed', 'Stopped')
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
1644
+
1645
+
1647
1646
def _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap ):
1648
1647
if len (stream_names ) < instance_count :
1649
1648
# Log streams are created whenever a container starts writing to stdout/err, so this list
0 commit comments