|
26 | 26 | from botocore.exceptions import ClientError
|
27 | 27 |
|
28 | 28 | from sagemaker.user_agent import prepend_user_agent
|
29 |
| -from sagemaker.utils import name_from_image |
| 29 | +from sagemaker.utils import name_from_image, secondary_training_status_message, secondary_training_status_changed |
30 | 30 | import sagemaker.logs
|
31 | 31 |
|
32 | 32 |
|
@@ -556,7 +556,8 @@ def wait_for_job(self, job, poll=5):
|
556 | 556 | Raises:
|
557 | 557 | ValueError: If the training job fails.
|
558 | 558 | """
|
559 |
| - desc = _wait_until(lambda: _train_done(self.sagemaker_client, job), poll) |
| 559 | + desc = _wait_until_training_done(lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), |
| 560 | + None, poll) |
560 | 561 | self._check_job_status(job, desc, 'TrainingJobStatus')
|
561 | 562 | return desc
|
562 | 563 |
|
@@ -795,6 +796,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
|
795 | 796 | """
|
796 | 797 |
|
797 | 798 | description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
|
| 799 | + print(secondary_training_status_message(description, None)) |
798 | 800 | instance_count = description['ResourceConfig']['InstanceCount']
|
799 | 801 | status = description['TrainingJobStatus']
|
800 | 802 |
|
@@ -834,6 +836,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
|
834 | 836 | # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
|
835 | 837 | # the job was marked complete.
|
836 | 838 | last_describe_job_call = time.time()
|
| 839 | + last_description = description |
837 | 840 | while True:
|
838 | 841 | if len(stream_names) < instance_count:
|
839 | 842 | # Log streams are created whenever a container starts writing to stdout/err, so this list
|
@@ -877,16 +880,21 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
|
877 | 880 | description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
|
878 | 881 | last_describe_job_call = time.time()
|
879 | 882 |
|
| 883 | + if secondary_training_status_changed(description, last_description): |
| 884 | + print() |
| 885 | + print(secondary_training_status_message(description, last_description), end='') |
| 886 | + last_description = description |
| 887 | + |
880 | 888 | status = description['TrainingJobStatus']
|
881 | 889 |
|
882 | 890 | if status == 'Completed' or status == 'Failed' or status == 'Stopped':
|
| 891 | + print() |
883 | 892 | state = LogState.JOB_COMPLETE
|
884 | 893 |
|
885 | 894 | if wait:
|
886 | 895 | self._check_job_status(job_name, description, 'TrainingJobStatus')
|
887 | 896 | if dot:
|
888 | 897 | print()
|
889 |
| - print('===== Job Complete =====') |
890 | 898 | # Customers are not billed for hardware provisioning, so billable time is less than total time
|
891 | 899 | billable_time = (description['TrainingEndTime'] - description['TrainingStartTime']) * instance_count
|
892 | 900 | print('Billable seconds:', int(billable_time.total_seconds()) + 1)
|
@@ -1009,30 +1017,25 @@ def _deployment_entity_exists(describe_fn):
|
1009 | 1017 | return False
|
1010 | 1018 |
|
1011 | 1019 |
|
1012 |
| -def _train_done(sagemaker_client, job_name): |
1013 |
| - training_status_codes = { |
1014 |
| - 'Created': '-', |
1015 |
| - 'InProgress': '.', |
1016 |
| - 'Completed': '!', |
1017 |
| - 'Failed': '*', |
1018 |
| - 'Stopping': '>', |
1019 |
| - 'Stopped': 's', |
1020 |
| - 'Deleting': 'o', |
1021 |
| - 'Deleted': 'x' |
1022 |
| - } |
| 1020 | +def _train_done(sagemaker_client, job_name, last_desc): |
| 1021 | + |
1023 | 1022 | in_progress_statuses = ['InProgress', 'Created']
|
1024 | 1023 |
|
1025 | 1024 | desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
|
1026 | 1025 | status = desc['TrainingJobStatus']
|
1027 | 1026 |
|
1028 |
| - print(training_status_codes.get(status, '?'), end='') |
| 1027 | + if secondary_training_status_changed(desc, last_desc): |
| 1028 | + print() |
| 1029 | + print(secondary_training_status_message(desc, last_desc), end='') |
| 1030 | + else: |
| 1031 | + print('.', end='') |
1029 | 1032 | sys.stdout.flush()
|
1030 | 1033 |
|
1031 | 1034 | if status in in_progress_statuses:
|
1032 |
| - return None |
| 1035 | + return desc, False |
1033 | 1036 |
|
1034 |
| - print('') |
1035 |
| - return desc |
| 1037 | + print() |
| 1038 | + return desc, True |
1036 | 1039 |
|
1037 | 1040 |
|
1038 | 1041 | def _tuning_job_status(sagemaker_client, job_name):
|
@@ -1102,6 +1105,14 @@ def _deploy_done(sagemaker_client, endpoint_name):
|
1102 | 1105 | return None if status in in_progress_statuses else desc
|
1103 | 1106 |
|
1104 | 1107 |
|
| 1108 | +def _wait_until_training_done(callable, desc, poll=5): |
| 1109 | + job_desc, finished = callable(desc) |
| 1110 | + while not finished: |
| 1111 | + time.sleep(poll) |
| 1112 | + job_desc, finished = callable(job_desc) |
| 1113 | + return job_desc |
| 1114 | + |
| 1115 | + |
1105 | 1116 | def _wait_until(callable, poll=5):
|
1106 | 1117 | result = callable()
|
1107 | 1118 | while result is None:
|
|
0 commit comments