@@ -1278,7 +1278,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1278
1278
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1279
1279
1280
1280
Raises:
1281
- ValueError: If waiting and the training job fails.
1281
+ ValueError: If the training job fails.
1282
1282
"""
1283
1283
1284
1284
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
@@ -1326,52 +1326,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1326
1326
last_describe_job_call = time .time ()
1327
1327
last_description = description
1328
1328
while True :
1329
- if len (stream_names ) < instance_count :
1330
- # Log streams are created whenever a container starts writing to stdout/err, so this list
1331
- # may be dynamic until we have a stream for every instance.
1332
- try :
1333
- streams = client .describe_log_streams (
1334
- logGroupName = log_group ,
1335
- logStreamNamePrefix = job_name + "/" ,
1336
- orderBy = "LogStreamName" ,
1337
- limit = instance_count ,
1338
- )
1339
- stream_names = [s ["logStreamName" ] for s in streams ["logStreams" ]]
1340
- positions .update (
1341
- [
1342
- (s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1343
- for s in stream_names
1344
- if s not in positions
1345
- ]
1346
- )
1347
- except ClientError as e :
1348
- # On the very first training job run on an account, there's no log group until
1349
- # the container starts logging, so ignore any errors thrown about that
1350
- err = e .response .get ("Error" , {})
1351
- if err .get ("Code" , None ) != "ResourceNotFoundException" :
1352
- raise
1353
-
1354
- if len (stream_names ) > 0 :
1355
- if dot :
1356
- print ("" )
1357
- dot = False
1358
- for idx , event in sagemaker .logs .multi_stream_iter (
1359
- client , log_group , stream_names , positions
1360
- ):
1361
- color_wrap (idx , event ["message" ])
1362
- ts , count = positions [stream_names [idx ]]
1363
- if event ["timestamp" ] == ts :
1364
- positions [stream_names [idx ]] = sagemaker .logs .Position (
1365
- timestamp = ts , skip = count + 1
1366
- )
1367
- else :
1368
- positions [stream_names [idx ]] = sagemaker .logs .Position (
1369
- timestamp = event ["timestamp" ], skip = 1
1370
- )
1371
- else :
1372
- dot = True
1373
- print ("." , end = "" )
1374
- sys .stdout .flush ()
1329
+ _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap )
1375
1330
if state == LogState .COMPLETE :
1376
1331
break
1377
1332
@@ -1404,6 +1359,87 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1404
1359
) * instance_count
1405
1360
print ("Billable seconds:" , int (billable_time .total_seconds ()) + 1 )
1406
1361
1362
def logs_for_transform_job(self, job_name, wait=False, poll=10):  # noqa: C901 - suppress complexity warning
    """Display the logs for a given transform job, optionally tailing them until the
    job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        job_name (str): Name of the transform job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes (default: False).
        poll (int): The interval in seconds between polling for new log entries and job completion (default: 10).

    Raises:
        ValueError: If the transform job fails (the status is only checked when ``wait`` is True).
    """
    terminal_statuses = ('Completed', 'Failed', 'Stopped')
    # Minimum number of seconds between two describe_transform_job calls while tailing.
    describe_interval = 30

    description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
    instance_count = description['TransformResources']['InstanceCount']
    status = description['TransformJobStatus']

    stream_names = []  # The list of log streams
    positions = {}  # The current position in each stream, map of stream name -> position

    # Increase retries allowed (from default of 4), as we don't want waiting for a transform job
    # to be interrupted by a transient exception.
    config = botocore.config.Config(retries={'max_attempts': 15})
    client = self.boto_session.client('logs', config=config)
    log_group = '/aws/sagemaker/TransformJobs'

    job_already_completed = status in terminal_statuses

    state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
    dot = False

    color_wrap = sagemaker.logs.ColorWrap()

    # The loop below implements a state machine that alternates between checking the job status and
    # reading whatever is available in the logs at this point. Note, that if we were called with
    # wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                                      N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
    #   the job was marked complete.
    last_describe_job_call = time.time()
    while True:
        flushed_dot = _flush_log_streams(
            stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap
        )
        # Rebinding `dot` inside the helper is invisible here, so pick the flag up from the
        # return value; a helper that returns nothing leaves the local flag unchanged.
        if flushed_dot is not None:
            dot = flushed_dot

        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= describe_interval:
            description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
            last_describe_job_call = time.time()

            status = description['TransformJobStatus']

            if status in terminal_statuses:
                print()
                state = LogState.JOB_COMPLETE

    if wait:
        self._check_job_status(job_name, description, 'TransformJobStatus')
        if dot:
            print()
        # Customers are not billed for hardware provisioning, so billable time is less than total time.
        # Skip the summary when either timestamp is missing from the description instead of
        # raising KeyError.
        start_time = description.get('TransformStartTime')
        end_time = description.get('TransformEndTime')
        if start_time is not None and end_time is not None:
            billable_time = (end_time - start_time) * instance_count
            print('Billable seconds:', int(billable_time.total_seconds()) + 1)
1442
+
1407
1443
1408
1444
def container_def (image , model_data_url = None , env = None ):
1409
1445
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1795,3 +1831,37 @@ def _vpc_config_from_training_job(
1795
1831
return training_job_desc .get (vpc_utils .VPC_CONFIG_KEY )
1796
1832
else :
1797
1833
return vpc_utils .sanitize (vpc_config_override )
1834
+
1835
+
1836
+ def _flush_log_streams (stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap ):
1837
+ if len (stream_names ) < instance_count :
1838
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
1839
+ # may be dynamic until we have a stream for every instance.
1840
+ try :
1841
+ streams = client .describe_log_streams (logGroupName = log_group , logStreamNamePrefix = job_name + '/' ,
1842
+ orderBy = 'LogStreamName' , limit = instance_count )
1843
+ stream_names = [s ['logStreamName' ] for s in streams ['logStreams' ]]
1844
+ positions .update ([(s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1845
+ for s in stream_names if s not in positions ])
1846
+ except ClientError as e :
1847
+ # On the very first training job run on an account, there's no log group until
1848
+ # the container starts logging, so ignore any errors thrown about that
1849
+ err = e .response .get ('Error' , {})
1850
+ if err .get ('Code' , None ) != 'ResourceNotFoundException' :
1851
+ raise
1852
+
1853
+ if len (stream_names ) > 0 :
1854
+ if dot :
1855
+ print ('' )
1856
+ dot = False
1857
+ for idx , event in sagemaker .logs .multi_stream_iter (client , log_group , stream_names , positions ):
1858
+ color_wrap (idx , event ['message' ])
1859
+ ts , count = positions [stream_names [idx ]]
1860
+ if event ['timestamp' ] == ts :
1861
+ positions [stream_names [idx ]] = sagemaker .logs .Position (timestamp = ts , skip = count + 1 )
1862
+ else :
1863
+ positions [stream_names [idx ]] = sagemaker .logs .Position (timestamp = event ['timestamp' ], skip = 1 )
1864
+ else :
1865
+ dot = True
1866
+ print ('.' , end = '' )
1867
+ sys .stdout .flush ()
0 commit comments