@@ -1164,37 +1164,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
1164
1164
last_describe_job_call = time .time ()
1165
1165
last_description = description
1166
1166
while True :
1167
- if len (stream_names ) < instance_count :
1168
- # Log streams are created whenever a container starts writing to stdout/err, so this list
1169
- # may be dynamic until we have a stream for every instance.
1170
- try :
1171
- streams = client .describe_log_streams (logGroupName = log_group , logStreamNamePrefix = job_name + '/' ,
1172
- orderBy = 'LogStreamName' , limit = instance_count )
1173
- stream_names = [s ['logStreamName' ] for s in streams ['logStreams' ]]
1174
- positions .update ([(s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1175
- for s in stream_names if s not in positions ])
1176
- except ClientError as e :
1177
- # On the very first training job run on an account, there's no log group until
1178
- # the container starts logging, so ignore any errors thrown about that
1179
- err = e .response .get ('Error' , {})
1180
- if err .get ('Code' , None ) != 'ResourceNotFoundException' :
1181
- raise
1182
-
1183
- if len (stream_names ) > 0 :
1184
- if dot :
1185
- print ('' )
1186
- dot = False
1187
- for idx , event in sagemaker .logs .multi_stream_iter (client , log_group , stream_names , positions ):
1188
- color_wrap (idx , event ['message' ])
1189
- ts , count = positions [stream_names [idx ]]
1190
- if event ['timestamp' ] == ts :
1191
- positions [stream_names [idx ]] = sagemaker .logs .Position (timestamp = ts , skip = count + 1 )
1192
- else :
1193
- positions [stream_names [idx ]] = sagemaker .logs .Position (timestamp = event ['timestamp' ], skip = 1 )
1194
- else :
1195
- dot = True
1196
- print ('.' , end = '' )
1197
- sys .stdout .flush ()
1167
+ _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap )
1198
1168
if state == LogState .COMPLETE :
1199
1169
break
1200
1170
@@ -1225,6 +1195,87 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
1225
1195
billable_time = (description ['TrainingEndTime' ] - description ['TrainingStartTime' ]) * instance_count
1226
1196
print ('Billable seconds:' , int (billable_time .total_seconds ()) + 1 )
1227
1197
1198
+ def logs_for_transform_job (self , job_name , wait = False , poll = 10 ): # noqa: C901 - suppress complexity warning
1199
+ """Display the logs for a given transform job, optionally tailing them until the
1200
+ job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
1201
+ based on which instance the log entry is from.
1202
+
1203
+ Args:
1204
+ job_name (str): Name of the transform job to display the logs for.
1205
+ wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
1206
+ poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1207
+
1208
+ Raises:
1209
+ ValueError: If the transform job fails.
1210
+ """
1211
+
1212
+ description = self .sagemaker_client .describe_transform_job (TransformJobName = job_name )
1213
+ instance_count = description ['TransformResources' ]['InstanceCount' ]
1214
+ status = description ['TransformJobStatus' ]
1215
+
1216
+ stream_names = [] # The list of log streams
1217
+ positions = {} # The current position in each stream, map of stream name -> position
1218
+
1219
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
1220
+ # to be interrupted by a transient exception.
1221
+ config = botocore .config .Config (retries = {'max_attempts' : 15 })
1222
+ client = self .boto_session .client ('logs' , config = config )
1223
+ log_group = '/aws/sagemaker/TransformJobs'
1224
+
1225
+ job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
1226
+
1227
+ state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1228
+ dot = False
1229
+
1230
+ color_wrap = sagemaker .logs .ColorWrap ()
1231
+
1232
+ # The loop below implements a state machine that alternates between checking the job status and
1233
+ # reading whatever is available in the logs at this point. Note, that if we were called with
1234
+ # wait == False, we never check the job status.
1235
+ #
1236
+ # If wait == TRUE and job is not completed, the initial state is TAILING
1237
+ # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
1238
+ #
1239
+ # The state table:
1240
+ #
1241
+ # STATE ACTIONS CONDITION NEW STATE
1242
+ # ---------------- ---------------- ----------------- ----------------
1243
+ # TAILING Read logs, Pause, Get status Job complete JOB_COMPLETE
1244
+ # Else TAILING
1245
+ # JOB_COMPLETE Read logs, Pause Any COMPLETE
1246
+ # COMPLETE Read logs, Exit N/A
1247
+ #
1248
+ # Notes:
1249
+ # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
1250
+ # the job was marked complete.
1251
+ last_describe_job_call = time .time ()
1252
+ while True :
1253
+ _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap )
1254
+ if state == LogState .COMPLETE :
1255
+ break
1256
+
1257
+ time .sleep (poll )
1258
+
1259
+ if state == LogState .JOB_COMPLETE :
1260
+ state = LogState .COMPLETE
1261
+ elif time .time () - last_describe_job_call >= 30 :
1262
+ description = self .sagemaker_client .describe_transform_job (TransformJobName = job_name )
1263
+ last_describe_job_call = time .time ()
1264
+
1265
+ status = description ['TransformJobStatus' ]
1266
+
1267
+ if status == 'Completed' or status == 'Failed' or status == 'Stopped' :
1268
+ print ()
1269
+ state = LogState .JOB_COMPLETE
1270
+
1271
+ if wait :
1272
+ self ._check_job_status (job_name , description , 'TransformJobStatus' )
1273
+ if dot :
1274
+ print ()
1275
+ # Customers are not billed for hardware provisioning, so billable time is less than total time
1276
+ billable_time = (description ['TransformEndTime' ] - description ['TransformStartTime' ]) * instance_count
1277
+ print ('Billable seconds:' , int (billable_time .total_seconds ()) + 1 )
1278
+
1228
1279
1229
1280
def container_def (image , model_data_url = None , env = None ):
1230
1281
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1591,3 +1642,37 @@ def _vpc_config_from_training_job(training_job_desc, vpc_config_override=vpc_uti
1591
1642
return training_job_desc .get (vpc_utils .VPC_CONFIG_KEY )
1592
1643
else :
1593
1644
return vpc_utils .sanitize (vpc_config_override )
1645
+
1646
+
1647
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
    """Print any new CloudWatch log events for a job, tracking read positions between calls.

    Args:
        stream_names (list): Known log stream names. Updated in place so the caller's
            list accumulates discovered streams across calls.
        instance_count (int): Number of instances, i.e. the expected number of log streams.
        client: A boto3 CloudWatch Logs client.
        log_group (str): Name of the log group containing the job's streams.
        job_name (str): Job name; streams are matched by the prefix ``job_name + '/'``.
        positions (dict): Map of stream name -> ``sagemaker.logs.Position``; mutated
            in place to remember how far each stream has been read.
        dot (bool): Whether progress dots have been printed without a trailing newline.
        color_wrap (sagemaker.logs.ColorWrap): Callable used to print color-coded log lines.

    Returns:
        bool: The updated ``dot`` state (True if a progress dot was just printed).

    Raises:
        ClientError: For any CloudWatch error other than a missing log group.
    """
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to stdout/err, so this list
        # may be dynamic until we have a stream for every instance.
        try:
            streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
                                                  orderBy='LogStreamName', limit=instance_count)
            # Assign via slice (not rebinding) so the caller's list is updated too;
            # otherwise the streams are re-described on every poll even once all
            # instance_count streams have been found.
            stream_names[:] = [s['logStreamName'] for s in streams['logStreams']]
            positions.update([(s, sagemaker.logs.Position(timestamp=0, skip=0))
                              for s in stream_names if s not in positions])
        except ClientError as e:
            # On the very first training job run on an account, there's no log group until
            # the container starts logging, so ignore any errors thrown about that
            err = e.response.get('Error', {})
            if err.get('Code', None) != 'ResourceNotFoundException':
                raise

    if len(stream_names) > 0:
        if dot:
            # Terminate the line of progress dots before printing log lines.
            print('')
            dot = False
        for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
            color_wrap(idx, event['message'])
            ts, count = positions[stream_names[idx]]
            # Within a single timestamp, count how many events we've consumed so a
            # later read can skip exactly the ones already printed.
            if event['timestamp'] == ts:
                positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
            else:
                positions[stream_names[idx]] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
    else:
        # No streams yet: show a progress dot so the user knows we're still polling.
        dot = True
        print('.', end='')
        sys.stdout.flush()
    return dot
0 commit comments