Skip to content

Commit f4f3e51

Browse files
committed
Refactor code
1 parent 956a733 commit f4f3e51

File tree

3 files changed

+27
-33
lines changed

3 files changed

+27
-33
lines changed

src/sagemaker/session.py

+27-28
Original file line numberDiff line numberDiff line change
@@ -1282,25 +1282,13 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12821282
"""
12831283

12841284
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
1285-
print(secondary_training_status_message(description, None), end="")
1286-
instance_count = description["ResourceConfig"]["InstanceCount"]
1287-
status = description["TrainingJobStatus"]
1285+
print(secondary_training_status_message(description, None), end='')
12881286

1289-
stream_names = [] # The list of log streams
1290-
positions = {} # The current position in each stream, map of stream name -> position
1291-
1292-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1293-
# to be interrupted by a transient exception.
1294-
config = botocore.config.Config(retries={"max_attempts": 15})
1295-
client = self.boto_session.client("logs", config=config)
1296-
log_group = "/aws/sagemaker/TrainingJobs"
1287+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1288+
_logs_initializer(self, description, job='Training')
12971289

12981290
state = _get_initial_job_state(description, 'TrainingJobStatus', wait)
12991291

1300-
dot = False
1301-
1302-
color_wrap = sagemaker.logs.ColorWrap()
1303-
13041292
# The loop below implements a state machine that alternates between checking the job status and
13051293
# reading whatever is available in the logs at this point. Note, that if we were called with
13061294
# wait == False, we never check the job status.
@@ -1371,23 +1359,12 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): # noqa: C901 -
13711359
"""
13721360

13731361
description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
1374-
instance_count = description['TransformResources']['InstanceCount']
13751362

1376-
stream_names = [] # The list of log streams
1377-
positions = {} # The current position in each stream, map of stream name -> position
1378-
1379-
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1380-
# to be interrupted by a transient exception.
1381-
config = botocore.config.Config(retries={'max_attempts': 15})
1382-
client = self.boto_session.client('logs', config=config)
1383-
log_group = '/aws/sagemaker/TransformJobs'
1363+
instance_count, stream_names, positions, client, log_group, dot, color_wrap = \
1364+
_logs_initializer(self, description, job='Transform')
13841365

13851366
state = _get_initial_job_state(description, 'TransformJobStatus', wait)
13861367

1387-
dot = False
1388-
1389-
color_wrap = sagemaker.logs.ColorWrap()
1390-
13911368
# The loop below implements a state machine that alternates between checking the job status and
13921369
# reading whatever is available in the logs at this point. Note, that if we were called with
13931370
# wait == False, we never check the job status.
@@ -1831,6 +1808,28 @@ def _get_initial_job_state(description, status_key, wait):
18311808
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
18321809

18331810

1811+
def _logs_initializer(sagemaker_session, description, job):
1812+
if job == 'Training':
1813+
instance_count = description['ResourceConfig']['InstanceCount']
1814+
elif job == 'Transform':
1815+
instance_count = description['TransformResources']['InstanceCount']
1816+
1817+
stream_names = [] # The list of log streams
1818+
positions = {} # The current position in each stream, map of stream name -> position
1819+
1820+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
1821+
# to be interrupted by a transient exception.
1822+
config = botocore.config.Config(retries={'max_attempts': 15})
1823+
client = sagemaker_session.boto_session.client('logs', config=config)
1824+
log_group = '/aws/sagemaker/' + job + 'Jobs'
1825+
1826+
dot = False
1827+
1828+
color_wrap = sagemaker.logs.ColorWrap()
1829+
1830+
return instance_count, stream_names, positions, client, log_group, dot, color_wrap
1831+
1832+
18341833
def _flush_log_streams(stream_names, instance_count, client, log_group, job_name, positions, dot, color_wrap):
18351834
if len(stream_names) < instance_count:
18361835
# Log streams are created whenever a container starts writing to stdout/err, so this list

tests/integ/test_transformer.py

-1
Original file line numberDiff line numberDiff line change
@@ -352,4 +352,3 @@ def _create_transformer_and_transform_job(
352352
logs=logs,
353353
)
354354
return transformer
355-

tests/unit/test_session.py

-4
Original file line numberDiff line numberDiff line change
@@ -935,9 +935,6 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
935935
]
936936

937937

938-
MODEL_NAME = "some-model"
939-
940-
941938
@patch('sagemaker.logs.ColorWrap')
942939
def test_logs_for_transform_job_no_wait(cw, sagemaker_session_complete):
943940
ims = sagemaker_session_complete
@@ -989,7 +986,6 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_
989986

990987

991988
MODEL_NAME = 'some-model'
992-
>>>>>>> feature: Estimator.fit like logs for transformer
993989
PRIMARY_CONTAINER = {
994990
"Environment": {},
995991
"Image": IMAGE,

0 commit comments

Comments
 (0)