Skip to content

Commit 1dcb705

Browse files
committed
Refactor code
1 parent c7a7251 commit 1dcb705

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

src/sagemaker/session.py

+26-26
Original file line numberDiff line numberDiff line change
@@ -1123,23 +1123,12 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
11231123

11241124
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
11251125
print(secondary_training_status_message(description, None), end='')
1126-
instance_count = description['ResourceConfig']['InstanceCount']
1127-
1128-
stream_names = [] # The list of log streams
1129-
positions = {} # The current position in each stream, map of stream name -> position
11301126

1131-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1132-
# to be interrupted by a transient exception.
1133-
config = botocore.config.Config(retries={'max_attempts': 15})
1134-
client = self.boto_session.client('logs', config=config)
1135-
log_group = '/aws/sagemaker/TrainingJobs'
1127+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1128+
_logs_initializer(self, description, job='Training')
11361129

11371130
state = _get_initial_job_state(description, 'TrainingJobStatus', wait)
11381131

1139-
dot = False
1140-
1141-
color_wrap = sagemaker.logs.ColorWrap()
1142-
11431132
# The loop below implements a state machine that alternates between checking the job status and
11441133
# reading whatever is available in the logs at this point. Note, that if we were called with
11451134
# wait == False, we never check the job status.
@@ -1208,23 +1197,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
12081197
"""
12091198

12101199
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1211-
instance_count = description['TransformResources']['InstanceCount']
12121200

1213-
stream_names = [] # The list of log streams
1214-
positions = {} # The current position in each stream, map of stream name -> position
1215-
1216-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1217-
# to be interrupted by a transient exception.
1218-
config = botocore.config.Config(retries={'max_attempts': 15})
1219-
client = self.boto_session.client('logs', config=config)
1220-
log_group = '/aws/sagemaker/TransformJobs'
1201+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1202+
_logs_initializer(self, description, job='Transform')
12211203

12221204
state = _get_initial_job_state(description, 'TransformJobStatus', wait)
12231205

1224-
dot = False
1225-
1226-
color_wrap = sagemaker.logs.ColorWrap()
1227-
12281206
# The loop below implements a state machine that alternates between checking the job status and
12291207
# reading whatever is available in the logs at this point. Note, that if we were called with
12301208
# wait == False, we never check the job status.
@@ -1643,6 +1621,28 @@ def _get_initial_job_state(description, status_key, wait):
16431621
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
16441622

16451623

1624+
def _logs_initializer(sagemaker_session, description, job):
1625+
if job == 'Training':
1626+
instance_count = description['ResourceConfig']['InstanceCount']
1627+
elif job == 'Transform':
1628+
instance_count = description['TransformResources']['InstanceCount']
1629+
1630+
stream_names = [] # The list of log streams
1631+
positions = {} # The current position in each stream, map of stream name -> position
1632+
1633+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1634+
# to be interrupted by a transient exception.
1635+
config = botocore.config.Config(retries={'max_attempts': 15})
1636+
client = sagemaker_session.boto_session.client('logs', config=config)
1637+
log_group = '/aws/sagemaker/' + job + 'Jobs'
1638+
1639+
dot = False
1640+
1641+
color_wrap = sagemaker.logs.ColorWrap()
1642+
1643+
return instance_count, stream_names, positions, client, log_group, dot, color_wrap
1644+
1645+
16461646
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
16471647
if len(stream_names) < instance_count:
16481648
# Log streams are created whenever a container starts writing to stdout/err, so this list

0 commit comments

Comments (0)