@@ -1225,6 +1225,116 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
1225
1225
billable_time = (description ['TrainingEndTime' ] - description ['TrainingStartTime' ]) * instance_count
1226
1226
print ('Billable seconds:' , int (billable_time .total_seconds ()) + 1 )
1227
1227
1228
def logs_for_transform_job(self, job_name, wait=False, poll=10):
    """Display the logs for a given transform job, optionally tailing them until the
    job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        job_name (str): Name of the transform job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
        poll (int): The interval in seconds between polling for new log entries and job completion (default: 10).

    Raises:
        ValueError: If waiting and the transform job fails.
    """
    # Statuses after which the job will make no further progress.
    terminal_statuses = ('Completed', 'Failed', 'Stopped')

    description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
    instance_count = description['TransformResources']['InstanceCount']
    status = description['TransformJobStatus']

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a transform job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={'max_attempts': 15})
    client = self.boto_session.client('logs', config=config)
    log_group = '/aws/sagemaker/TransformJobs'

    job_already_completed = status in terminal_statuses

    state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
    dot = False  # True while the last thing printed was a progress dot without a newline

    color_wrap = sagemaker.logs.ColorWrap()

    # The loop below implements a state machine that alternates between checking the job status and
    # reading whatever is available in the logs at this point. Note, that if we were called with
    # wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                                      N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
    #   the job was marked complete.
    last_describe_job_call = time.time()
    while True:
        if len(stream_names) < instance_count:
            # Log streams are created whenever a container starts writing to stdout/err, so this list
            # may be dynamic until we have a stream for every instance.
            try:
                streams = client.describe_log_streams(logGroupName=log_group, logStreamNamePrefix=job_name + '/',
                                                      orderBy='LogStreamName', limit=instance_count)
                stream_names = [s['logStreamName'] for s in streams['logStreams']]
                # Start newly discovered streams from the beginning; keep the position of known ones.
                for name in stream_names:
                    if name not in positions:
                        positions[name] = sagemaker.logs.Position(timestamp=0, skip=0)
            except ClientError as e:
                # On the very first transform job run on an account, there's no log group until
                # the container starts logging, so ignore any errors thrown about that
                err = e.response.get('Error', {})
                if err.get('Code', None) != 'ResourceNotFoundException':
                    raise

        if stream_names:
            if dot:
                # Terminate the line of progress dots before printing log output.
                print('')
                dot = False
            for idx, event in sagemaker.logs.multi_stream_iter(client, log_group, stream_names, positions):
                color_wrap(idx, event['message'])
                stream = stream_names[idx]
                ts, count = positions[stream]
                # Track how many events were already consumed at the latest timestamp so that the
                # next read can skip them.
                if event['timestamp'] == ts:
                    positions[stream] = sagemaker.logs.Position(timestamp=ts, skip=count + 1)
                else:
                    positions[stream] = sagemaker.logs.Position(timestamp=event['timestamp'], skip=1)
        else:
            # No log stream exists yet: show a progress dot instead.
            dot = True
            print('.', end='')
            sys.stdout.flush()
        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Refresh the job description at most once every 30 seconds.
            description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
            last_describe_job_call = time.time()

            status = description['TransformJobStatus']

            if status in terminal_statuses:
                print()
                state = LogState.JOB_COMPLETE

    if wait:
        self._check_job_status(job_name, description, 'TransformJobStatus')
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than total time
        # NOTE(review): a job that never reached its instances (e.g. stopped early) may have no
        # start/end timestamps in its description; skip the summary instead of raising KeyError.
        start_time = description.get('TransformStartTime')
        end_time = description.get('TransformEndTime')
        if start_time is not None and end_time is not None:
            billable_time = (end_time - start_time) * instance_count
            print('Billable seconds:', int(billable_time.total_seconds()) + 1)
1337
+
1228
1338
1229
1339
def container_def (image , model_data_url = None , env = None ):
1230
1340
"""Create a definition for executing a container as part of a SageMaker model.
0 commit comments