Skip to content

Commit 86d6f27

Browse files
icywang86ruiyangaws
authored andcommitted
Generate better training job status report (#77) (#322)
* Generate better training job status report With this feature while waiting for a training job to finish user will be able see more detailed status from the output. The output is in the format of '<StarTime> <SecondaryStatus> <StatusMessage>...'
1 parent fe5acbf commit 86d6f27

File tree

5 files changed

+163
-22
lines changed

5 files changed

+163
-22
lines changed

CHANGELOG.rst

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
CHANGELOG
33
=========
44

5-
1.7.2.dev
6-
=========
5+
1.7.2
6+
=====
7+
78
* bug-fix: Prediction output for the TF_JSON_SERIALIZER
9+
* enhancement: Add better training job status report
810

911
1.7.1
1012
=====

src/sagemaker/session.py

+29-18
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from botocore.exceptions import ClientError
2727

2828
from sagemaker.user_agent import prepend_user_agent
29-
from sagemaker.utils import name_from_image
29+
from sagemaker.utils import name_from_image, secondary_training_status_message, secondary_training_status_changed
3030
import sagemaker.logs
3131

3232

@@ -556,7 +556,8 @@ def wait_for_job(self, job, poll=5):
556556
Raises:
557557
ValueError: If the training job fails.
558558
"""
559-
desc = _wait_until(lambda: _train_done(self.sagemaker_client, job), poll)
559+
desc = _wait_until_training_done(lambda last_desc: _train_done(self.sagemaker_client, job, last_desc),
560+
None, poll)
560561
self._check_job_status(job, desc, 'TrainingJobStatus')
561562
return desc
562563

@@ -795,6 +796,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
795796
"""
796797

797798
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
799+
print(secondary_training_status_message(description, None), end='')
798800
instance_count = description['ResourceConfig']['InstanceCount']
799801
status = description['TrainingJobStatus']
800802

@@ -834,6 +836,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
834836
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
835837
# the job was marked complete.
836838
last_describe_job_call = time.time()
839+
last_description = description
837840
while True:
838841
if len(stream_names) < instance_count:
839842
# Log streams are created whenever a container starts writing to stdout/err, so this list
@@ -877,16 +880,21 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
877880
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
878881
last_describe_job_call = time.time()
879882

883+
if secondary_training_status_changed(description, last_description):
884+
print()
885+
print(secondary_training_status_message(description, last_description), end='')
886+
last_description = description
887+
880888
status = description['TrainingJobStatus']
881889

882890
if status == 'Completed' or status == 'Failed' or status == 'Stopped':
891+
print()
883892
state = LogState.JOB_COMPLETE
884893

885894
if wait:
886895
self._check_job_status(job_name, description, 'TrainingJobStatus')
887896
if dot:
888897
print()
889-
print('===== Job Complete =====')
890898
# Customers are not billed for hardware provisioning, so billable time is less than total time
891899
billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
892900
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
@@ -1009,30 +1017,25 @@ def _deployment_entity_exists(describe_fn):
10091017
return False
10101018

10111019

1012-
def _train_done(sagemaker_client, job_name):
1013-
training_status_codes = {
1014-
'Created': '-',
1015-
'InProgress': '.',
1016-
'Completed': '!',
1017-
'Failed': '*',
1018-
'Stopping': '>',
1019-
'Stopped': 's',
1020-
'Deleting': 'o',
1021-
'Deleted': 'x'
1022-
}
1020+
def _train_done(sagemaker_client, job_name, last_desc):
1021+
10231022
in_progress_statuses = ['InProgress', 'Created']
10241023

10251024
desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
10261025
status = desc['TrainingJobStatus']
10271026

1028-
print(training_status_codes.get(status, '?'), end='')
1027+
if secondary_training_status_changed(desc, last_desc):
1028+
print()
1029+
print(secondary_training_status_message(desc, last_desc), end='')
1030+
else:
1031+
print('.', end='')
10291032
sys.stdout.flush()
10301033

10311034
if status in in_progress_statuses:
1032-
return None
1035+
return desc, False
10331036

1034-
print('')
1035-
return desc
1037+
print()
1038+
return desc, True
10361039

10371040

10381041
def _tuning_job_status(sagemaker_client, job_name):
@@ -1102,6 +1105,14 @@ def _deploy_done(sagemaker_client, endpoint_name):
11021105
return None if status in in_progress_statuses else desc
11031106

11041107

1108+
def _wait_until_training_done(callable, desc, poll=5):
1109+
job_desc, finished = callable(desc)
1110+
while not finished:
1111+
time.sleep(poll)
1112+
job_desc, finished = callable(job_desc)
1113+
return job_desc
1114+
1115+
11051116
def _wait_until(callable, poll=5):
11061117
result = callable()
11071118
while result is None:

src/sagemaker/utils.py

+56
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717

1818
import re
19+
from datetime import datetime
1920
from functools import wraps
2021

2122

@@ -130,6 +131,61 @@ def extract_name_from_job_arn(arn):
130131
return arn[(slash_pos + 1):]
131132

132133

134+
def secondary_training_status_changed(current_job_description, prev_job_description):
135+
"""Returns true if training job's secondary status message has changed.
136+
137+
Args:
138+
current_job_desc: Current job description, returned from DescribeTrainingJob call.
139+
prev_job_desc: Previous job description, returned from DescribeTrainingJob call.
140+
141+
Returns:
142+
boolean: Whether the secondary status message of a training job changed or not.
143+
144+
"""
145+
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
146+
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
147+
return False
148+
149+
last_message = prev_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']\
150+
if prev_job_description is not None else ''
151+
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
152+
153+
return message != last_message
154+
155+
156+
def secondary_training_status_message(job_description, prev_description):
157+
"""Returns a string contains start time and the secondary training job status message.
158+
159+
Args:
160+
job_description: Returned response from DescribeTrainingJob call
161+
prev_description: Previous job description from DescribeTrainingJob call
162+
163+
Returns:
164+
str: Job status string to be printed.
165+
166+
"""
167+
168+
if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
169+
or len(job_description.get('SecondaryStatusTransitions')) == 0:
170+
return ''
171+
172+
prev_transitions_num = len(prev_description['SecondaryStatusTransitions']) if prev_description is not None else 0
173+
current_transitions = job_description['SecondaryStatusTransitions']
174+
175+
if len(current_transitions) == prev_transitions_num:
176+
return current_transitions[-1]['StatusMessage']
177+
else:
178+
transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):]
179+
status_strs = []
180+
for transition in transitions_to_print:
181+
message = transition['StatusMessage']
182+
time_str = datetime.utcfromtimestamp(
183+
time.mktime(transition['StartTime'].timetuple())).strftime('%Y-%m-%d %H:%M:%S')
184+
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))
185+
186+
return '\n'.join(status_strs)
187+
188+
133189
class DeferredError(object):
134190
"""Stores an exception and raises it at a later time anytime this
135191
object is accessed in any way. Useful to allow soft-dependencies on imports,

tests/unit/test_session.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import sagemaker
2525
from sagemaker import s3_input, Session, get_execution_role
26-
from sagemaker.session import _tuning_job_status, _transform_job_status
26+
from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done
2727

2828
REGION = 'us-west-2'
2929

@@ -711,3 +711,25 @@ def test_transform_job_status_none(sagemaker_session):
711711

712712
result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
713713
assert result is None
714+
715+
716+
def test_train_done_completed(sagemaker_session):
717+
training_job_desc = {'TrainingJobStatus': 'Completed'}
718+
sagemaker_session.sagemaker_client.describe_training_job = Mock(
719+
name='describe_training_job', return_value=training_job_desc)
720+
721+
actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)
722+
723+
assert actual_job_desc['TrainingJobStatus'] == 'Completed'
724+
assert training_finished is True
725+
726+
727+
def test_train_done_in_progress(sagemaker_session):
728+
training_job_desc = {'TrainingJobStatus': 'InProgress'}
729+
sagemaker_session.sagemaker_client.describe_training_job = Mock(
730+
name='describe_training_job', return_value=training_job_desc)
731+
732+
actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)
733+
734+
assert actual_job_desc['TrainingJobStatus'] == 'InProgress'
735+
assert training_finished is False

tests/unit/test_utils.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
import pytest
1818
from mock import patch
1919

20-
from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn
20+
from sagemaker.utils import get_config_value, name_from_base,\
21+
to_str, DeferredError, extract_name_from_job_arn, secondary_training_status_changed,\
22+
secondary_training_status_message
23+
24+
from datetime import datetime
25+
import time
2126

2227
NAME = 'base_name'
2328

@@ -89,3 +94,48 @@ def test_name_from_training_arn():
8994
arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b'
9095
name = extract_name_from_job_arn(arn)
9196
assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b'
97+
98+
99+
MESSAGE = 'message'
100+
STATUS = 'status'
101+
TRAINING_JOB_DESCRIPTION_1 = {
102+
'SecondaryStatusTransitions': [{'StatusMessage': MESSAGE, 'Status': STATUS}]
103+
}
104+
TRAINING_JOB_DESCRIPTION_2 = {
105+
'SecondaryStatusTransitions': [{'StatusMessage': 'different message', 'Status': STATUS}]
106+
}
107+
TRAINING_JOB_DESCRIPTION_EMPTY = {
108+
'SecondaryStatusTransitions': []
109+
}
110+
111+
112+
def test_secondary_training_status_changed_true():
113+
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2)
114+
assert changed is True
115+
116+
117+
def test_secondary_training_status_changed_false():
118+
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1)
119+
assert changed is False
120+
121+
122+
def test_secondary_training_status_changed_empty():
123+
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1)
124+
assert changed is False
125+
126+
127+
def test_secondary_training_status_message_status_changed():
128+
now = datetime.now()
129+
TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now
130+
expected = '{} {} - {}'.format(
131+
datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'),
132+
STATUS,
133+
MESSAGE
134+
)
135+
assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected
136+
137+
138+
def test_secondary_training_status_message_status_not_changed():
139+
now = datetime.now()
140+
TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now
141+
assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == MESSAGE

0 commit comments

Comments
 (0)