@@ -1123,23 +1123,12 @@ def logs_for_job(self, job_name, wait=False, poll=10):  # noqa: C901 - suppress

         description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
         print(secondary_training_status_message(description, None), end='')
-        instance_count = description['ResourceConfig']['InstanceCount']
-
-        stream_names = []  # The list of log streams
-        positions = {}  # The current position in each stream, map of stream name -> position

-        # Increase retries allowed (from default of 4), as we don't want waiting for a training job
-        # to be interrupted by a transient exception.
-        config = botocore.config.Config(retries={'max_attempts': 15})
-        client = self.boto_session.client('logs', config=config)
-        log_group = '/aws/sagemaker/TrainingJobs'
+        instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
+            _logs_initializer(self, description, job='Training')

         state = _get_initial_job_state(description, 'TrainingJobStatus', wait)

-        dot = False
-
-        color_wrap = sagemaker.logs.ColorWrap()
-
         # The loop below implements a state machine that alternates between checking the job status and
         # reading whatever is available in the logs at this point. Note, that if we were called with
         # wait == False, we never check the job status.
@@ -1208,23 +1197,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10):  # noqa: C901 -
         """

         description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
-        instance_count = description['TransformResources']['InstanceCount']

-        stream_names = []  # The list of log streams
-        positions = {}  # The current position in each stream, map of stream name -> position
-
-        # Increase retries allowed (from default of 4), as we don't want waiting for a training job
-        # to be interrupted by a transient exception.
-        config = botocore.config.Config(retries={'max_attempts': 15})
-        client = self.boto_session.client('logs', config=config)
-        log_group = '/aws/sagemaker/TransformJobs'
+        instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
+            _logs_initializer(self, description, job='Transform')

         state = _get_initial_job_state(description, 'TransformJobStatus', wait)

-        dot = False
-
-        color_wrap = sagemaker.logs.ColorWrap()
-
         # The loop below implements a state machine that alternates between checking the job status and
         # reading whatever is available in the logs at this point. Note, that if we were called with
         # wait == False, we never check the job status.
@@ -1643,6 +1621,28 @@ def _get_initial_job_state(description, status_key, wait):
     return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE


+def _logs_initializer(sagemaker_session, description, job):
+    if job == 'Training':
+        instance_count = description['ResourceConfig']['InstanceCount']
+    elif job == 'Transform':
+        instance_count = description['TransformResources']['InstanceCount']
+
+    stream_names = []  # The list of log streams
+    positions = {}  # The current position in each stream, map of stream name -> position
+
+    # Increase retries allowed (from default of 4), as we don't want waiting for a job
+    # to be interrupted by a transient exception.
+    config = botocore.config.Config(retries={'max_attempts': 15})
+    client = sagemaker_session.boto_session.client('logs', config=config)
+    log_group = '/aws/sagemaker/' + job + 'Jobs'
+
+    dot = False
+
+    color_wrap = sagemaker.logs.ColorWrap()
+
+    return instance_count, stream_names, positions, client, log_group, dot, color_wrap
+
+
 def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
     if len(stream_names) < instance_count:
         # Log streams are created whenever a container starts writing to stdout/err, so this list
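
For reference, a minimal sketch (not part of the diff) of how the new helper would be exercised once this change lands in `sagemaker.session`; the region, describe-job response, and asserted values below are made-up placeholders for illustration:

    import boto3
    import sagemaker
    from sagemaker.session import _logs_initializer  # helper added by this change

    # Hypothetical session and DescribeTrainingJob response, for illustration only.
    session = sagemaker.Session(boto_session=boto3.Session(region_name='us-west-2'))
    description = {'ResourceConfig': {'InstanceCount': 2}, 'TrainingJobStatus': 'InProgress'}

    # One call now replaces the per-method setup of the CloudWatch Logs client,
    # stream bookkeeping, and color wrapper in both log-tailing methods.
    instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
        _logs_initializer(session, description, job='Training')

    assert instance_count == 2
    assert log_group == '/aws/sagemaker/TrainingJobs'
    assert stream_names == [] and positions == {} and dot is False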