Skip to content

Generate better training job status report #322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 31, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
CHANGELOG
=========

1.7.2.dev
=========
1.7.2
=====

* bug-fix: Prediction output for the TF_JSON_SERIALIZER
* enhancement: Add better training job status report

1.7.1
=====
Expand Down
47 changes: 29 additions & 18 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from botocore.exceptions import ClientError

from sagemaker.user_agent import prepend_user_agent
from sagemaker.utils import name_from_image
from sagemaker.utils import name_from_image, secondary_training_status_message, secondary_training_status_changed
import sagemaker.logs


Expand Down Expand Up @@ -556,7 +556,8 @@ def wait_for_job(self, job, poll=5):
Raises:
ValueError: If the training job fails.
"""
desc = _wait_until(lambda: _train_done(self.sagemaker_client, job), poll)
desc = _wait_until_training_done(lambda last_desc: _train_done(self.sagemaker_client, job, last_desc),
None, poll)
self._check_job_status(job, desc, 'TrainingJobStatus')
return desc

Expand Down Expand Up @@ -795,6 +796,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
"""

description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
print(secondary_training_status_message(description, None), end='')
instance_count = description['ResourceConfig']['InstanceCount']
status = description['TrainingJobStatus']

Expand Down Expand Up @@ -834,6 +836,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
# the job was marked complete.
last_describe_job_call = time.time()
last_description = description
while True:
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
Expand Down Expand Up @@ -877,16 +880,21 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
last_describe_job_call = time.time()

if secondary_training_status_changed(description, last_description):
print()
print(secondary_training_status_message(description, last_description), end='')
last_description = description

status = description['TrainingJobStatus']

if status == 'Completed' or status == 'Failed' or status == 'Stopped':
print()
state = LogState.JOB_COMPLETE

if wait:
self._check_job_status(job_name, description, 'TrainingJobStatus')
if dot:
print()
print('===== Job Complete =====')
# Customers are not billed for hardware provisioning, so billable time is less than total time
billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
print('Billable seconds:', int(billable_time.total_seconds()) + 1)
Expand Down Expand Up @@ -1009,30 +1017,25 @@ def _deployment_entity_exists(describe_fn):
return False


def _train_done(sagemaker_client, job_name):
training_status_codes = {
'Created': '-',
'InProgress': '.',
'Completed': '!',
'Failed': '*',
'Stopping': '>',
'Stopped': 's',
'Deleting': 'o',
'Deleted': 'x'
}
def _train_done(sagemaker_client, job_name, last_desc):

in_progress_statuses = ['InProgress', 'Created']

desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
status = desc['TrainingJobStatus']

print(training_status_codes.get(status, '?'), end='')
if secondary_training_status_changed(desc, last_desc):
print()
print(secondary_training_status_message(desc, last_desc), end='')
else:
print('.', end='')
sys.stdout.flush()

if status in in_progress_statuses:
return None
return desc, False

print('')
return desc
print()
return desc, True


def _tuning_job_status(sagemaker_client, job_name):
Expand Down Expand Up @@ -1102,6 +1105,14 @@ def _deploy_done(sagemaker_client, endpoint_name):
return None if status in in_progress_statuses else desc


def _wait_until_training_done(callable, desc, poll=5):
job_desc, finished = callable(desc)
while not finished:
time.sleep(poll)
job_desc, finished = callable(job_desc)
return job_desc


def _wait_until(callable, poll=5):
result = callable()
while result is None:
Expand Down
56 changes: 56 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time

import re
from datetime import datetime
from functools import wraps


Expand Down Expand Up @@ -130,6 +131,61 @@ def extract_name_from_job_arn(arn):
return arn[(slash_pos + 1):]


def secondary_training_status_changed(current_job_description, prev_job_description):
"""Returns true if training job's secondary status message has changed.

Args:
current_job_desc: Current job description, returned from DescribeTrainingJob call.
prev_job_desc: Previous job description, returned from DescribeTrainingJob call.

Returns:
boolean: Whether the secondary status message of a training job changed or not.

"""
current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
return False

last_message = prev_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']\
if prev_job_description is not None else ''
message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']

return message != last_message


def secondary_training_status_message(job_description, prev_description):
"""Returns a string contains start time and the secondary training job status message.

Args:
job_description: Returned response from DescribeTrainingJob call
prev_description: Previous job description from DescribeTrainingJob call

Returns:
str: Job status string to be printed.

"""

if job_description is None or job_description.get('SecondaryStatusTransitions') is None\
or len(job_description.get('SecondaryStatusTransitions')) == 0:
return ''

prev_transitions_num = len(prev_description['SecondaryStatusTransitions']) if prev_description is not None else 0
current_transitions = job_description['SecondaryStatusTransitions']

if len(current_transitions) == prev_transitions_num:
return current_transitions[-1]['StatusMessage']
else:
transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):]
status_strs = []
for transition in transitions_to_print:
message = transition['StatusMessage']
time_str = datetime.utcfromtimestamp(
time.mktime(transition['StartTime'].timetuple())).strftime('%Y-%m-%d %H:%M:%S')
status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message))

return '\n'.join(status_strs)


class DeferredError(object):
"""Stores an exception and raises it at a later time anytime this
object is accessed in any way. Useful to allow soft-dependencies on imports,
Expand Down
24 changes: 23 additions & 1 deletion tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import sagemaker
from sagemaker import s3_input, Session, get_execution_role
from sagemaker.session import _tuning_job_status, _transform_job_status
from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done

REGION = 'us-west-2'

Expand Down Expand Up @@ -711,3 +711,25 @@ def test_transform_job_status_none(sagemaker_session):

result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
assert result is None


def test_train_done_completed(sagemaker_session):
training_job_desc = {'TrainingJobStatus': 'Completed'}
sagemaker_session.sagemaker_client.describe_training_job = Mock(
name='describe_training_job', return_value=training_job_desc)

actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)

assert actual_job_desc['TrainingJobStatus'] == 'Completed'
assert training_finished is True


def test_train_done_in_progress(sagemaker_session):
training_job_desc = {'TrainingJobStatus': 'InProgress'}
sagemaker_session.sagemaker_client.describe_training_job = Mock(
name='describe_training_job', return_value=training_job_desc)

actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None)

assert actual_job_desc['TrainingJobStatus'] == 'InProgress'
assert training_finished is False
52 changes: 51 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pytest
from mock import patch

from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn
from sagemaker.utils import get_config_value, name_from_base,\
to_str, DeferredError, extract_name_from_job_arn, secondary_training_status_changed,\
secondary_training_status_message

from datetime import datetime
import time

NAME = 'base_name'

Expand Down Expand Up @@ -89,3 +94,48 @@ def test_name_from_training_arn():
arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b'
name = extract_name_from_job_arn(arn)
assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b'


MESSAGE = 'message'
STATUS = 'status'
TRAINING_JOB_DESCRIPTION_1 = {
'SecondaryStatusTransitions': [{'StatusMessage': MESSAGE, 'Status': STATUS}]
}
TRAINING_JOB_DESCRIPTION_2 = {
'SecondaryStatusTransitions': [{'StatusMessage': 'different message', 'Status': STATUS}]
}
TRAINING_JOB_DESCRIPTION_EMPTY = {
'SecondaryStatusTransitions': []
}


def test_secondary_training_status_changed_true():
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2)
assert changed is True


def test_secondary_training_status_changed_false():
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1)
assert changed is False


def test_secondary_training_status_changed_empty():
changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1)
assert changed is False


def test_secondary_training_status_message_status_changed():
now = datetime.now()
TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now
expected = '{} {} - {}'.format(
datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'),
STATUS,
MESSAGE
)
assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected


def test_secondary_training_status_message_status_not_changed():
now = datetime.now()
TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now
assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == MESSAGE