34
34
secondary_training_status_changed ,
35
35
secondary_training_status_message ,
36
36
)
37
+ from sagemaker import exceptions
37
38
38
39
LOGGER = logging .getLogger ("sagemaker" )
39
40
@@ -792,10 +793,10 @@ def wait_for_model_package(self, model_package_name, poll=5):
792
793
793
794
if status != "Completed" :
794
795
reason = desc .get ("FailureReason" , None )
795
- raise ValueError (
796
- "Error creating model package {}: {} Reason: {} " .format (
797
- model_package_name , status , reason
798
- )
796
+ raise exceptions . UnexpectedStatusException (
797
+ message = "Error creating model package {}: {} Reason: {} " .format (model_package_name , status , reason ),
798
+ allowed_statuses = [ "Completed" ],
799
+ actual_status = status
799
800
)
800
801
return desc
801
802
@@ -947,7 +948,7 @@ def wait_for_job(self, job, poll=5):
947
948
(dict): Return value from the ``DescribeTrainingJob`` API.
948
949
949
950
Raises:
950
- ValueError : If the training job fails.
951
+ exceptions.UnexpectedStatusException : If the training job fails.
951
952
"""
952
953
desc = _wait_until_training_done (
953
954
lambda last_desc : _train_done (self .sagemaker_client , job , last_desc ), None , poll
@@ -966,7 +967,7 @@ def wait_for_compilation_job(self, job, poll=5):
966
967
(dict): Return value from the ``DescribeCompilationJob`` API.
967
968
968
969
Raises:
969
- ValueError : If the compilation job fails.
970
+ exceptions.UnexpectedStatusException : If the compilation job fails.
970
971
"""
971
972
desc = _wait_until (lambda : _compilation_job_status (self .sagemaker_client , job ), poll )
972
973
self ._check_job_status (job , desc , "CompilationJobStatus" )
@@ -983,7 +984,7 @@ def wait_for_tuning_job(self, job, poll=5):
983
984
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984
985
985
986
Raises:
986
- ValueError : If the hyperparameter tuning job fails.
987
+ exceptions.UnexpectedStatusException : If the hyperparameter tuning job fails.
987
988
"""
988
989
desc = _wait_until (lambda : _tuning_job_status (self .sagemaker_client , job ), poll )
989
990
self ._check_job_status (job , desc , "HyperParameterTuningJobStatus" )
@@ -1000,23 +1001,23 @@ def wait_for_transform_job(self, job, poll=5):
1000
1001
(dict): Return value from the ``DescribeTransformJob`` API.
1001
1002
1002
1003
Raises:
1003
- ValueError : If the transform job fails.
1004
+ exceptions.UnexpectedStatusException : If the transform job fails.
1004
1005
"""
1005
1006
desc = _wait_until (lambda : _transform_job_status (self .sagemaker_client , job ), poll )
1006
1007
self ._check_job_status (job , desc , "TransformJobStatus" )
1007
1008
return desc
1008
1009
1009
1010
def _check_job_status (self , job , desc , status_key_name ):
1010
1011
"""Check to see if the job completed successfully and, if not, construct and
1011
- raise a ValueError .
1012
+ raise an exceptions.UnexpectedStatusException .
1012
1013
1013
1014
Args:
1014
1015
job (str): The name of the job to check.
1015
1016
desc (dict[str, str]): The result of ``describe_training_job()``.
1016
1017
status_key_name (str): Status key name to check for.
1017
1018
1018
1019
Raises:
1019
- ValueError : If the training job fails.
1020
+ exceptions.UnexpectedStatusException : If the job ends in a status other than Completed or Stopped.
1020
1021
"""
1021
1022
status = desc [status_key_name ]
1022
1023
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1026,11 @@ def _check_job_status(self, job, desc, status_key_name):
1025
1026
if status != "Completed" and status != "Stopped" :
1026
1027
reason = desc .get ("FailureReason" , "(No reason provided)" )
1027
1028
job_type = status_key_name .replace ("JobStatus" , " job" )
1028
- raise ValueError ("Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ))
1029
+ raise exceptions .UnexpectedStatusException (
1030
+ message = "Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ),
1031
+ allowed_statuses = ["Completed" , "Stopped" ],
1032
+ actual_status = status
1033
+ )
1029
1034
1030
1035
def wait_for_endpoint (self , endpoint , poll = 5 ):
1031
1036
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1047,10 @@ def wait_for_endpoint(self, endpoint, poll=5):
1042
1047
1043
1048
if status != "InService" :
1044
1049
reason = desc .get ("FailureReason" , None )
1045
- raise ValueError (
1046
- "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason )
1050
+ raise exceptions .UnexpectedStatusException (
1051
+ message = "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason ),
1052
+ allowed_statuses = ["InService" ],
1053
+ actual_status = status
1047
1054
)
1048
1055
return desc
1049
1056
@@ -1276,7 +1283,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1276
1283
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1277
1284
1278
1285
Raises:
1279
- ValueError : If waiting and the training job fails.
1286
+ exceptions.UnexpectedStatusException : If waiting and the training job fails.
1280
1287
"""
1281
1288
1282
1289
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
0 commit comments