Skip to content

Commit 77b45e2

Browse files
committed
Refactor code
1 parent 956a733 commit 77b45e2

File tree

1 file changed

+27
-28
lines changed

1 file changed

+27
-28
lines changed

src/sagemaker/session.py

+27-28
Original file line numberDiff line numberDiff line change
@@ -1282,25 +1282,13 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12821282
"""
12831283

12841284
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
1285-
print(secondary_training_status_message(description, None), end="")
1286-
instance_count = description["ResourceConfig"]["InstanceCount"]
1287-
status = description["TrainingJobStatus"]
1285+
print(secondary_training_status_message(description, None), end='')
12881286

1289-
stream_names = [] # The list of log streams
1290-
positions = {} # The current position in each stream, map of stream name -> position
1291-
1292-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1293-
# to be interrupted by a transient exception.
1294-
config = botocore.config.Config(retries={"max_attempts": 15})
1295-
client = self.boto_session.client("logs", config=config)
1296-
log_group = "/aws/sagemaker/TrainingJobs"
1287+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1288+
_logs_initializer(self, description, job='Training')
12971289

12981290
state = _get_initial_job_state(description, 'TrainingJobStatus', wait)
12991291

1300-
dot = False
1301-
1302-
color_wrap = sagemaker.logs.ColorWrap()
1303-
13041292
# The loop below implements a state machine that alternates between checking the job status and
13051293
# reading whatever is available in the logs at this point. Note, that if we were called with
13061294
# wait == False, we never check the job status.
@@ -1371,23 +1359,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
13711359
"""
13721360

13731361
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1374-
instance_count = description['TransformResources']['InstanceCount']
13751362

1376-
stream_names = [] # The list of log streams
1377-
positions = {} # The current position in each stream, map of stream name -> position
1378-
1379-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1380-
# to be interrupted by a transient exception.
1381-
config = botocore.config.Config(retries={'max_attempts': 15})
1382-
client = self.boto_session.client('logs', config=config)
1383-
log_group = '/aws/sagemaker/TransformJobs'
1363+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1364+
_logs_initializer(self, description, job='Transform')
13841365

13851366
state = _get_initial_job_state(description, 'TransformJobStatus', wait)
13861367

1387-
dot = False
1388-
1389-
color_wrap = sagemaker.logs.ColorWrap()
1390-
13911368
# The loop below implements a state machine that alternates between checking the job status and
13921369
# reading whatever is available in the logs at this point. Note, that if we were called with
13931370
# wait == False, we never check the job status.
@@ -1831,6 +1808,28 @@ def _get_initial_job_state(description, status_key, wait):
18311808
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
18321809

18331810

1811+
def _logs_initializer(sagemaker_session, description, job):
1812+
if job == 'Training':
1813+
instance_count = description['ResourceConfig']['InstanceCount']
1814+
elif job == 'Transform':
1815+
instance_count = description['TransformResources']['InstanceCount']
1816+
1817+
stream_names = [] # The list of log streams
1818+
positions = {} # The current position in each stream, map of stream name -> position
1819+
1820+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1821+
# to be interrupted by a transient exception.
1822+
config = botocore.config.Config(retries={'max_attempts': 15})
1823+
client = sagemaker_session.boto_session.client('logs', config=config)
1824+
log_group = '/aws/sagemaker/' + job + 'Jobs'
1825+
1826+
dot = False
1827+
1828+
color_wrap = sagemaker.logs.ColorWrap()
1829+
1830+
return instance_count, stream_names, positions, client, log_group, dot, color_wrap
1831+
1832+
18341833
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
18351834
if len(stream_names) < instance_count:
18361835
# Log streams are created whenever a container starts writing to stdout/err, so this list

0 commit comments

Comments
 (0)