@@ -1428,24 +1428,12 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1428
1428
1429
1429
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
1430
1430
print (secondary_training_status_message (description , None ), end = "" )
1431
- instance_count = description ["ResourceConfig" ]["InstanceCount" ]
1432
- status = description ["TrainingJobStatus" ]
1433
-
1434
- stream_names = [] # The list of log streams
1435
- positions = {} # The current position in each stream, map of stream name -> position
1436
-
1437
- # Increase retries allowed (from default of 4), as we don't want waiting for a training job
1438
- # to be interrupted by a transient exception.
1439
- config = botocore .config .Config (retries = {"max_attempts" : 15 })
1440
- client = self .boto_session .client ("logs" , config = config )
1441
- log_group = "/aws/sagemaker/TrainingJobs"
1442
1431
1443
- job_already_completed = status in ("Completed" , "Failed" , "Stopped" )
1444
-
1445
- state = LogState .TAILING if wait and not job_already_completed else LogState .COMPLETE
1446
- dot = False
1432
+ instance_count , stream_names , positions , client , log_group , dot , color_wrap = _logs_init (
1433
+ self , description , job = "Training"
1434
+ )
1447
1435
1448
- color_wrap = sagemaker . logs . ColorWrap ( )
1436
+ state = _get_initial_job_state ( description , "TrainingJobStatus" , wait )
1449
1437
1450
1438
# The loop below implements a state machine that alternates between checking the job status
1451
1439
# and reading whatever is available in the logs at this point. Note, that if we were
@@ -1470,52 +1458,16 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1470
1458
last_describe_job_call = time .time ()
1471
1459
last_description = description
1472
1460
while True :
1473
- if len (stream_names ) < instance_count :
1474
- # Log streams are created whenever a container starts writing to stdout/err, so
1475
- # this list # may be dynamic until we have a stream for every instance.
1476
- try :
1477
- streams = client .describe_log_streams (
1478
- logGroupName = log_group ,
1479
- logStreamNamePrefix = job_name + "/" ,
1480
- orderBy = "LogStreamName" ,
1481
- limit = instance_count ,
1482
- )
1483
- stream_names = [s ["logStreamName" ] for s in streams ["logStreams" ]]
1484
- positions .update (
1485
- [
1486
- (s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1487
- for s in stream_names
1488
- if s not in positions
1489
- ]
1490
- )
1491
- except ClientError as e :
1492
- # On the very first training job run on an account, there's no log group until
1493
- # the container starts logging, so ignore any errors thrown about that
1494
- err = e .response .get ("Error" , {})
1495
- if err .get ("Code" , None ) != "ResourceNotFoundException" :
1496
- raise
1497
-
1498
- if len (stream_names ) > 0 :
1499
- if dot :
1500
- print ("" )
1501
- dot = False
1502
- for idx , event in sagemaker .logs .multi_stream_iter (
1503
- client , log_group , stream_names , positions
1504
- ):
1505
- color_wrap (idx , event ["message" ])
1506
- ts , count = positions [stream_names [idx ]]
1507
- if event ["timestamp" ] == ts :
1508
- positions [stream_names [idx ]] = sagemaker .logs .Position (
1509
- timestamp = ts , skip = count + 1
1510
- )
1511
- else :
1512
- positions [stream_names [idx ]] = sagemaker .logs .Position (
1513
- timestamp = event ["timestamp" ], skip = 1
1514
- )
1515
- else :
1516
- dot = True
1517
- print ("." , end = "" )
1518
- sys .stdout .flush ()
1461
+ _flush_log_streams (
1462
+ stream_names ,
1463
+ instance_count ,
1464
+ client ,
1465
+ log_group ,
1466
+ job_name ,
1467
+ positions ,
1468
+ dot ,
1469
+ color_wrap ,
1470
+ )
1519
1471
if state == LogState .COMPLETE :
1520
1472
break
1521
1473
@@ -1554,6 +1506,86 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1554
1506
saving = (1 - float (billable_time ) / training_time ) * 100
1555
1507
print ("Managed Spot Training savings: {:.1f}%" .format (saving ))
1556
1508
1509
def logs_for_transform_job(self, job_name, wait=False, poll=10):
    """Display the logs for a given transform job, optionally tailing them until the
    job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
    based on which instance the log entry is from.

    Args:
        job_name (str): Name of the transform job to display the logs for.
        wait (bool): Whether to keep looking for new log entries until the job completes
            (default: False).
        poll (int): The interval in seconds between polling for new log entries and job
            completion (default: 10).

    Raises:
        ValueError: If the transform job fails.
    """

    description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)

    instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
        self, description, job="Transform"
    )

    state = _get_initial_job_state(description, "TransformJobStatus", wait)

    # The loop below implements a state machine that alternates between checking the job status
    # and reading whatever is available in the logs at this point. Note, that if we were
    # called with wait == False, we never check the job status.
    #
    # If wait == TRUE and job is not completed, the initial state is TAILING
    # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
    # complete).
    #
    # The state table:
    #
    # STATE               ACTIONS                        CONDITION             NEW STATE
    # ----------------    ----------------               -----------------     ----------------
    # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
    #                                                    Else                  TAILING
    # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
    # COMPLETE            Read logs, Exit                                      N/A
    #
    # Notes:
    # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
    #   Cloudwatch after the job was marked complete.
    last_describe_job_call = time.time()
    while True:
        _flush_log_streams(
            stream_names,
            instance_count,
            client,
            log_group,
            job_name,
            positions,
            dot,
            color_wrap,
        )
        if state == LogState.COMPLETE:
            break

        time.sleep(poll)

        if state == LogState.JOB_COMPLETE:
            state = LogState.COMPLETE
        elif time.time() - last_describe_job_call >= 30:
            # Re-describe at most every 30s to avoid hammering the API while tailing.
            description = self.sagemaker_client.describe_transform_job(
                TransformJobName=job_name
            )
            last_describe_job_call = time.time()

            status = description["TransformJobStatus"]

            if status in ("Completed", "Failed", "Stopped"):
                print()
                state = LogState.JOB_COMPLETE

    if wait:
        # Raises ValueError if the job did not complete successfully.
        self._check_job_status(job_name, description, "TransformJobStatus")
        if dot:
            print()
1557
1589
1558
1590
def container_def (image , model_data_url = None , env = None ):
1559
1591
"""Create a definition for executing a container as part of a SageMaker model.
@@ -1892,3 +1924,83 @@ def _vpc_config_from_training_job(
1892
1924
if vpc_config_override is vpc_utils .VPC_CONFIG_DEFAULT :
1893
1925
return training_job_desc .get (vpc_utils .VPC_CONFIG_KEY )
1894
1926
return vpc_utils .sanitize (vpc_config_override )
1927
+
1928
+
1929
def _get_initial_job_state(description, status_key, wait):
    """Determine the starting LogState for a log-tailing loop.

    Tail only when the caller asked to wait and the job has not already
    reached a terminal status; otherwise do a single read and finish.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")
    if wait and description[status_key] not in terminal_statuses:
        return LogState.TAILING
    return LogState.COMPLETE
1934
+
1935
+
1936
+ def _logs_init (sagemaker_session , description , job ):
1937
+ """Placeholder docstring"""
1938
+ if job == "Training" :
1939
+ instance_count = description ["ResourceConfig" ]["InstanceCount" ]
1940
+ elif job == "Transform" :
1941
+ instance_count = description ["TransformResources" ]["InstanceCount" ]
1942
+
1943
+ stream_names = [] # The list of log streams
1944
+ positions = {} # The current position in each stream, map of stream name -> position
1945
+
1946
+ # Increase retries allowed (from default of 4), as we don't want waiting for a training job
1947
+ # to be interrupted by a transient exception.
1948
+ config = botocore .config .Config (retries = {"max_attempts" : 15 })
1949
+ client = sagemaker_session .boto_session .client ("logs" , config = config )
1950
+ log_group = "/aws/sagemaker/" + job + "Jobs"
1951
+
1952
+ dot = False
1953
+
1954
+ color_wrap = sagemaker .logs .ColorWrap ()
1955
+
1956
+ return instance_count , stream_names , positions , client , log_group , dot , color_wrap
1957
+
1958
+
1959
+ def _flush_log_streams (
1960
+ stream_names , instance_count , client , log_group , job_name , positions , dot , color_wrap
1961
+ ):
1962
+ """Placeholder docstring"""
1963
+ if len (stream_names ) < instance_count :
1964
+ # Log streams are created whenever a container starts writing to stdout/err, so this list
1965
+ # may be dynamic until we have a stream for every instance.
1966
+ try :
1967
+ streams = client .describe_log_streams (
1968
+ logGroupName = log_group ,
1969
+ logStreamNamePrefix = job_name + "/" ,
1970
+ orderBy = "LogStreamName" ,
1971
+ limit = instance_count ,
1972
+ )
1973
+ stream_names = [s ["logStreamName" ] for s in streams ["logStreams" ]]
1974
+ positions .update (
1975
+ [
1976
+ (s , sagemaker .logs .Position (timestamp = 0 , skip = 0 ))
1977
+ for s in stream_names
1978
+ if s not in positions
1979
+ ]
1980
+ )
1981
+ except ClientError as e :
1982
+ # On the very first training job run on an account, there's no log group until
1983
+ # the container starts logging, so ignore any errors thrown about that
1984
+ err = e .response .get ("Error" , {})
1985
+ if err .get ("Code" , None ) != "ResourceNotFoundException" :
1986
+ raise
1987
+
1988
+ if len (stream_names ) > 0 :
1989
+ if dot :
1990
+ print ("" )
1991
+ dot = False
1992
+ for idx , event in sagemaker .logs .multi_stream_iter (
1993
+ client , log_group , stream_names , positions
1994
+ ):
1995
+ color_wrap (idx , event ["message" ])
1996
+ ts , count = positions [stream_names [idx ]]
1997
+ if event ["timestamp" ] == ts :
1998
+ positions [stream_names [idx ]] = sagemaker .logs .Position (timestamp = ts , skip = count + 1 )
1999
+ else :
2000
+ positions [stream_names [idx ]] = sagemaker .logs .Position (
2001
+ timestamp = event ["timestamp" ], skip = 1
2002
+ )
2003
+ else :
2004
+ dot = True
2005
+ print ("." , end = "" )
2006
+ sys .stdout .flush ()
0 commit comments