34
34
secondary_training_status_changed ,
35
35
secondary_training_status_message ,
36
36
)
37
+ from sagemaker import exceptions
37
38
38
39
LOGGER = logging .getLogger ("sagemaker" )
39
40
@@ -792,10 +793,13 @@ def wait_for_model_package(self, model_package_name, poll=5):
792
793
793
794
if status != "Completed" :
794
795
reason = desc .get ("FailureReason" , None )
795
- raise ValueError (
796
- "Error creating model package {}: {} Reason: {}" .format (
797
- model_package_name , status , reason
798
- )
796
+ raise exceptions .UnexpectedStatusException (
797
+ message = "Error creating model package {package}: {status} Reason: {reason}" .format (
798
+ package = model_package_name ,
799
+ status = status ,
800
+ reason = reason ),
801
+ allowed_statuses = ["Completed" ],
802
+ actual_status = status
799
803
)
800
804
return desc
801
805
@@ -947,7 +951,7 @@ def wait_for_job(self, job, poll=5):
947
951
(dict): Return value from the ``DescribeTrainingJob`` API.
948
952
949
953
Raises:
950
- ValueError : If the training job fails.
954
+ exceptions.UnexpectedStatusException : If the training job fails.
951
955
"""
952
956
desc = _wait_until_training_done (
953
957
lambda last_desc : _train_done (self .sagemaker_client , job , last_desc ), None , poll
@@ -966,7 +970,7 @@ def wait_for_compilation_job(self, job, poll=5):
966
970
(dict): Return value from the ``DescribeCompilationJob`` API.
967
971
968
972
Raises:
969
- ValueError : If the compilation job fails.
973
+ exceptions.UnexpectedStatusException : If the compilation job fails.
970
974
"""
971
975
desc = _wait_until (lambda : _compilation_job_status (self .sagemaker_client , job ), poll )
972
976
self ._check_job_status (job , desc , "CompilationJobStatus" )
@@ -983,7 +987,7 @@ def wait_for_tuning_job(self, job, poll=5):
983
987
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984
988
985
989
Raises:
986
- ValueError : If the hyperparameter tuning job fails.
990
+ exceptions.UnexpectedStatusException : If the hyperparameter tuning job fails.
987
991
"""
988
992
desc = _wait_until (lambda : _tuning_job_status (self .sagemaker_client , job ), poll )
989
993
self ._check_job_status (job , desc , "HyperParameterTuningJobStatus" )
@@ -1000,23 +1004,23 @@ def wait_for_transform_job(self, job, poll=5):
1000
1004
(dict): Return value from the ``DescribeTransformJob`` API.
1001
1005
1002
1006
Raises:
1003
- ValueError : If the transform job fails.
1007
+ exceptions.UnexpectedStatusException : If the transform job fails.
1004
1008
"""
1005
1009
desc = _wait_until (lambda : _transform_job_status (self .sagemaker_client , job ), poll )
1006
1010
self ._check_job_status (job , desc , "TransformJobStatus" )
1007
1011
return desc
1008
1012
1009
1013
def _check_job_status (self , job , desc , status_key_name ):
1010
1014
"""Check to see if the job completed successfully and, if not, construct and
1011
- raise a ValueError .
1015
+ raise an exceptions.UnexpectedStatusException .
1012
1016
1013
1017
Args:
1014
1018
job (str): The name of the job to check.
1015
1019
desc (dict[str, str]): The result of ``describe_training_job()``.
1016
1020
status_key_name (str): Status key name to check for.
1017
1021
1018
1022
Raises:
1019
- ValueError : If the training job fails.
1023
+ exceptions.UnexpectedStatusException : If the training job fails.
1020
1024
"""
1021
1025
status = desc [status_key_name ]
1022
1026
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1029,16 @@ def _check_job_status(self, job, desc, status_key_name):
1025
1029
if status != "Completed" and status != "Stopped" :
1026
1030
reason = desc .get ("FailureReason" , "(No reason provided)" )
1027
1031
job_type = status_key_name .replace ("JobStatus" , " job" )
1028
- raise ValueError ("Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ))
1032
+ raise exceptions .UnexpectedStatusException (
1033
+ message = "Error for {job_type} {job_name}: {status}. Reason: {reason}" .format (
1034
+ job_type = job_type ,
1035
+ job_name = job ,
1036
+ status = status ,
1037
+ reason = reason
1038
+ ),
1039
+ allowed_statuses = ["Completed" , "Stopped" ],
1040
+ actual_status = status
1041
+ )
1029
1042
1030
1043
def wait_for_endpoint (self , endpoint , poll = 5 ):
1031
1044
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1055,13 @@ def wait_for_endpoint(self, endpoint, poll=5):
1042
1055
1043
1056
if status != "InService" :
1044
1057
reason = desc .get ("FailureReason" , None )
1045
- raise ValueError (
1046
- "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason )
1058
+ raise exceptions .UnexpectedStatusException (
1059
+ message = "Error hosting endpoint {endpoint}: {status}. Reason: {reason}" .format (
1060
+ endpoint = endpoint ,
1061
+ status = status ,
1062
+ reason = reason ),
1063
+ allowed_statuses = ["InService" ],
1064
+ actual_status = status
1047
1065
)
1048
1066
return desc
1049
1067
@@ -1276,7 +1294,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1276
1294
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1277
1295
1278
1296
Raises:
1279
- ValueError : If waiting and the training job fails.
1297
+ exceptions.UnexpectedStatusException : If waiting and the training job fails.
1280
1298
"""
1281
1299
1282
1300
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
0 commit comments