diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d12c404f1b..e59294c981 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,9 +2,11 @@ CHANGELOG ========= -1.7.2.dev -========= +1.7.2 +===== + * bug-fix: Prediction output for the TF_JSON_SERIALIZER +* enhancement: Add better training job status report 1.7.1 ===== diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5d4df05d1e..ec62e09ac0 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -26,7 +26,7 @@ from botocore.exceptions import ClientError from sagemaker.user_agent import prepend_user_agent -from sagemaker.utils import name_from_image +from sagemaker.utils import name_from_image, secondary_training_status_message, secondary_training_status_changed import sagemaker.logs @@ -556,7 +556,8 @@ def wait_for_job(self, job, poll=5): Raises: ValueError: If the training job fails. """ - desc = _wait_until(lambda: _train_done(self.sagemaker_client, job), poll) + desc = _wait_until_training_done(lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), + None, poll) self._check_job_status(job, desc, 'TrainingJobStatus') return desc @@ -795,6 +796,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress """ description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) + print(secondary_training_status_message(description, None), end='') instance_count = description['ResourceConfig']['InstanceCount'] status = description['TrainingJobStatus'] @@ -834,6 +836,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after # the job was marked complete. last_describe_job_call = time.time() + last_description = description while True: if len(stream_names) < instance_count: # Log streams are created whenever a container starts writing to stdout/err, so this list @@ -877,16 +880,21 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name) last_describe_job_call = time.time() + if secondary_training_status_changed(description, last_description): + print() + print(secondary_training_status_message(description, last_description), end='') + last_description = description + status = description['TrainingJobStatus'] if status == 'Completed' or status == 'Failed' or status == 'Stopped': + print() state = LogState.JOB_COMPLETE if wait: self._check_job_status(job_name, description, 'TrainingJobStatus') if dot: print() - print('===== Job Complete =====') # Customers are not billed for hardware provisioning, so billable time is less than total time billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count print('Billable seconds:', int(billable_time.total_seconds()) + 1) @@ -1009,30 +1017,25 @@ def _deployment_entity_exists(describe_fn): return False -def _train_done(sagemaker_client, job_name): - training_status_codes = { - 'Created': '-', - 'InProgress': '.', - 'Completed': '!', - 'Failed': '*', - 'Stopping': '>', - 'Stopped': 's', - 'Deleting': 'o', - 'Deleted': 'x' - } +def _train_done(sagemaker_client, job_name, last_desc): + in_progress_statuses = ['InProgress', 'Created'] desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) status = desc['TrainingJobStatus'] - print(training_status_codes.get(status, '?'), end='') + if secondary_training_status_changed(desc, last_desc): + print() + print(secondary_training_status_message(desc, last_desc), end='') + else: + print('.', end='') sys.stdout.flush() if status in in_progress_statuses: - return None + return desc, False - print('') - return desc + print() + return desc, True def _tuning_job_status(sagemaker_client, job_name): @@ -1102,6 +1105,14 @@ def _deploy_done(sagemaker_client, endpoint_name): return None if status in in_progress_statuses else desc +def _wait_until_training_done(callable, desc, poll=5): + job_desc, finished = callable(desc) + while not finished: + time.sleep(poll) + job_desc, finished = callable(job_desc) + return job_desc + + def _wait_until(callable, poll=5): result = callable() while result is None: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index fe6dc42264..428a7235f9 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -16,6 +16,7 @@ import time import re +from datetime import datetime from functools import wraps @@ -130,6 +131,61 @@ def extract_name_from_job_arn(arn): return arn[(slash_pos + 1):] +def secondary_training_status_changed(current_job_description, prev_job_description): + """Returns true if training job's secondary status message has changed. + + Args: + current_job_desc: Current job description, returned from DescribeTrainingJob call. + prev_job_desc: Previous job description, returned from DescribeTrainingJob call. + + Returns: + boolean: Whether the secondary status message of a training job changed or not. + + """ + current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions') + if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0: + return False + + last_message = prev_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']\ + if prev_job_description is not None else '' + message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage'] + + return message != last_message + + +def secondary_training_status_message(job_description, prev_description): + """Returns a string contains start time and the secondary training job status message. + + Args: + job_description: Returned response from DescribeTrainingJob call + prev_description: Previous job description from DescribeTrainingJob call + + Returns: + str: Job status string to be printed. + + """ + + if job_description is None or job_description.get('SecondaryStatusTransitions') is None\ + or len(job_description.get('SecondaryStatusTransitions')) == 0: + return '' + + prev_transitions_num = len(prev_description['SecondaryStatusTransitions']) if prev_description is not None else 0 + current_transitions = job_description['SecondaryStatusTransitions'] + + if len(current_transitions) == prev_transitions_num: + return current_transitions[-1]['StatusMessage'] + else: + transitions_to_print = current_transitions[prev_transitions_num - len(current_transitions):] + status_strs = [] + for transition in transitions_to_print: + message = transition['StatusMessage'] + time_str = datetime.utcfromtimestamp( + time.mktime(transition['StartTime'].timetuple())).strftime('%Y-%m-%d %H:%M:%S') + status_strs.append('{} {} - {}'.format(time_str, transition['Status'], message)) + + return '\n'.join(status_strs) + + class DeferredError(object): """Stores an exception and raises it at a later time anytime this object is accessed in any way. Useful to allow soft-dependencies on imports, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 29f9f9f3a3..5118289a54 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -23,7 +23,7 @@ import sagemaker from sagemaker import s3_input, Session, get_execution_role -from sagemaker.session import _tuning_job_status, _transform_job_status +from sagemaker.session import _tuning_job_status, _transform_job_status, _train_done REGION = 'us-west-2' @@ -711,3 +711,25 @@ def test_transform_job_status_none(sagemaker_session): result = _transform_job_status(sagemaker_session.sagemaker_client, JOB_NAME) assert result is None + + +def test_train_done_completed(sagemaker_session): + training_job_desc = {'TrainingJobStatus': 'Completed'} + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name='describe_training_job', return_value=training_job_desc) + + actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None) + + assert actual_job_desc['TrainingJobStatus'] == 'Completed' + assert training_finished is True + + +def test_train_done_in_progress(sagemaker_session): + training_job_desc = {'TrainingJobStatus': 'InProgress'} + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name='describe_training_job', return_value=training_job_desc) + + actual_job_desc, training_finished = _train_done(sagemaker_session.sagemaker_client, JOB_NAME, None) + + assert actual_job_desc['TrainingJobStatus'] == 'InProgress' + assert training_finished is False diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 1170d76475..ba14b9e1d1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -17,7 +17,12 @@ import pytest from mock import patch -from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn +from sagemaker.utils import get_config_value, name_from_base,\ + to_str, DeferredError, extract_name_from_job_arn, secondary_training_status_changed,\ + secondary_training_status_message + +from datetime import datetime +import time NAME = 'base_name' @@ -89,3 +94,48 @@ def test_name_from_training_arn(): arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b' name = extract_name_from_job_arn(arn) assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b' + + +MESSAGE = 'message' +STATUS = 'status' +TRAINING_JOB_DESCRIPTION_1 = { + 'SecondaryStatusTransitions': [{'StatusMessage': MESSAGE, 'Status': STATUS}] +} +TRAINING_JOB_DESCRIPTION_2 = { + 'SecondaryStatusTransitions': [{'StatusMessage': 'different message', 'Status': STATUS}] +} +TRAINING_JOB_DESCRIPTION_EMPTY = { + 'SecondaryStatusTransitions': [] +} + + +def test_secondary_training_status_changed_true(): + changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) + assert changed is True + + +def test_secondary_training_status_changed_false(): + changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_1) + assert changed is False + + +def test_secondary_training_status_changed_empty(): + changed = secondary_training_status_changed(TRAINING_JOB_DESCRIPTION_EMPTY, TRAINING_JOB_DESCRIPTION_1) + assert changed is False + + +def test_secondary_training_status_message_status_changed(): + now = datetime.now() + TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now + expected = '{} {} - {}'.format( + datetime.utcfromtimestamp(time.mktime(now.timetuple())).strftime('%Y-%m-%d %H:%M:%S'), + STATUS, + MESSAGE + ) + assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_EMPTY) == expected + + +def test_secondary_training_status_message_status_not_changed(): + now = datetime.now() + TRAINING_JOB_DESCRIPTION_1['SecondaryStatusTransitions'][-1]['StartTime'] = now + assert secondary_training_status_message(TRAINING_JOB_DESCRIPTION_1, TRAINING_JOB_DESCRIPTION_2) == MESSAGE