@@ -1282,25 +1282,13 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
        """

        description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
-        print(secondary_training_status_message(description, None), end="")
-        instance_count = description["ResourceConfig"]["InstanceCount"]
-        status = description["TrainingJobStatus"]
+        print(secondary_training_status_message(description, None), end='')

-        stream_names = []  # The list of log streams
-        positions = {}  # The current position in each stream, map of stream name -> position
-
-        # Increase retries allowed (from default of 4), as we don't want waiting for a training job
-        # to be interrupted by a transient exception.
-        config = botocore.config.Config(retries={"max_attempts": 15})
-        client = self.boto_session.client("logs", config=config)
-        log_group = "/aws/sagemaker/TrainingJobs"
+        instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
+            _logs_initializer(self, description, job='Training')

        state = _get_initial_job_state(description, 'TrainingJobStatus', wait)

-        dot = False
-
-        color_wrap = sagemaker.logs.ColorWrap()
-
        # The loop below implements a state machine that alternates between checking the job status and
        # reading whatever is available in the logs at this point. Note, that if we were called with
        # wait == False, we never check the job status.
@@ -1371,23 +1359,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
        """

        description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
-        instance_count = description['TransformResources']['InstanceCount']

-        stream_names = []  # The list of log streams
-        positions = {}  # The current position in each stream, map of stream name -> position
-
-        # Increase retries allowed (from default of 4), as we don't want waiting for a training job
-        # to be interrupted by a transient exception.
-        config = botocore.config.Config(retries={'max_attempts': 15})
-        client = self.boto_session.client('logs', config=config)
-        log_group = '/aws/sagemaker/TransformJobs'
+        instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
+            _logs_initializer(self, description, job='Transform')

        state = _get_initial_job_state(description, 'TransformJobStatus', wait)

-        dot = False
-
-        color_wrap = sagemaker.logs.ColorWrap()
-
        # The loop below implements a state machine that alternates between checking the job status and
        # reading whatever is available in the logs at this point. Note, that if we were called with
        # wait == False, we never check the job status.
@@ -1831,6 +1808,28 @@ def _get_initial_job_state(description, status_key, wait):
    return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


+def _logs_initializer(sagemaker_session, description, job):
+    if job == 'Training':
+        instance_count = description['ResourceConfig']['InstanceCount']
+    elif job == 'Transform':
+        instance_count = description['TransformResources']['InstanceCount']
+
+    stream_names = []  # The list of log streams
+    positions = {}  # The current position in each stream, map of stream name -> position
+
+    # Increase retries allowed (from default of 4), as we don't want waiting for a training job
+    # to be interrupted by a transient exception.
+    config = botocore.config.Config(retries={'max_attempts': 15})
+    client = sagemaker_session.boto_session.client('logs', config=config)
+    log_group = '/aws/sagemaker/' + job + 'Jobs'
+
+    dot = False
+
+    color_wrap = sagemaker.logs.ColorWrap()
+
+    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
+
+
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
    if len(stream_names) < instance_count:
        # Log streams are created whenever a container starts writing to stdout/err, so this list
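
Below is a minimal usage sketch (not part of the diff) of how the consolidated helper is meant to be consumed by both log-tailing paths. It assumes it runs inside sagemaker/session.py where _logs_initializer is defined; the Session construction and `job_name` are hypothetical placeholders that need real AWS credentials and an existing training job.

import sagemaker

# Placeholder setup: a real Session and the name of an existing training job.
session = sagemaker.session.Session()
description = session.sagemaker_client.describe_training_job(TrainingJobName=job_name)

# One unpacking replaces the per-method setup removed above: the CloudWatch Logs client,
# stream bookkeeping, and the colorized output helper all come from _logs_initializer.
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
    _logs_initializer(session, description, job='Training')

# log_group is '/aws/sagemaker/TrainingJobs' here; passing job='Transform' together with a
# describe_transform_job() description yields '/aws/sagemaker/TransformJobs' instead.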