34
34
secondary_training_status_changed ,
35
35
secondary_training_status_message ,
36
36
)
37
+ from sagemaker import exceptions
37
38
38
39
LOGGER = logging .getLogger ("sagemaker" )
39
40
@@ -792,10 +793,12 @@ def wait_for_model_package(self, model_package_name, poll=5):
792
793
793
794
if status != "Completed" :
794
795
reason = desc .get ("FailureReason" , None )
795
- raise ValueError (
796
- "Error creating model package {}: {} Reason: {}" .format (
797
- model_package_name , status , reason
798
- )
796
+ raise exceptions .UnexpectedStatusException (
797
+ message = "Error creating model package {package}: {status} Reason: {reason}" .format (
798
+ package = model_package_name , status = status , reason = reason
799
+ ),
800
+ allowed_statuses = ["Completed" ],
801
+ actual_status = status ,
799
802
)
800
803
return desc
801
804
@@ -947,7 +950,7 @@ def wait_for_job(self, job, poll=5):
947
950
(dict): Return value from the ``DescribeTrainingJob`` API.
948
951
949
952
Raises:
950
- ValueError : If the training job fails.
953
+ exceptions.UnexpectedStatusException : If the training job fails.
951
954
"""
952
955
desc = _wait_until_training_done (
953
956
lambda last_desc : _train_done (self .sagemaker_client , job , last_desc ), None , poll
@@ -966,7 +969,7 @@ def wait_for_compilation_job(self, job, poll=5):
966
969
(dict): Return value from the ``DescribeCompilationJob`` API.
967
970
968
971
Raises:
969
- ValueError : If the compilation job fails.
972
+ exceptions.UnexpectedStatusException : If the compilation job fails.
970
973
"""
971
974
desc = _wait_until (lambda : _compilation_job_status (self .sagemaker_client , job ), poll )
972
975
self ._check_job_status (job , desc , "CompilationJobStatus" )
@@ -983,7 +986,7 @@ def wait_for_tuning_job(self, job, poll=5):
983
986
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984
987
985
988
Raises:
986
- ValueError : If the hyperparameter tuning job fails.
989
+ exceptions.UnexpectedStatusException : If the hyperparameter tuning job fails.
987
990
"""
988
991
desc = _wait_until (lambda : _tuning_job_status (self .sagemaker_client , job ), poll )
989
992
self ._check_job_status (job , desc , "HyperParameterTuningJobStatus" )
@@ -1000,23 +1003,23 @@ def wait_for_transform_job(self, job, poll=5):
1000
1003
(dict): Return value from the ``DescribeTransformJob`` API.
1001
1004
1002
1005
Raises:
1003
- ValueError : If the transform job fails.
1006
+ exceptions.UnexpectedStatusException : If the transform job fails.
1004
1007
"""
1005
1008
desc = _wait_until (lambda : _transform_job_status (self .sagemaker_client , job ), poll )
1006
1009
self ._check_job_status (job , desc , "TransformJobStatus" )
1007
1010
return desc
1008
1011
1009
1012
def _check_job_status (self , job , desc , status_key_name ):
1010
1013
"""Check to see if the job completed successfully and, if not, construct and
1011
- raise a ValueError .
1014
+ raise an exceptions.UnexpectedStatusException.
1012
1015
1013
1016
Args:
1014
1017
job (str): The name of the job to check.
1015
1018
desc (dict[str, str]): The result of ``describe_training_job()``.
1016
1019
status_key_name (str): Status key name to check for.
1017
1020
1018
1021
Raises:
1019
- ValueError : If the training job fails.
1022
+ exceptions.UnexpectedStatusException : If the job ends in a status other than "Completed" or "Stopped".
1020
1023
"""
1021
1024
status = desc [status_key_name ]
1022
1025
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1028,13 @@ def _check_job_status(self, job, desc, status_key_name):
1025
1028
if status != "Completed" and status != "Stopped" :
1026
1029
reason = desc .get ("FailureReason" , "(No reason provided)" )
1027
1030
job_type = status_key_name .replace ("JobStatus" , " job" )
1028
- raise ValueError ("Error for {} {}: {} Reason: {}" .format (job_type , job , status , reason ))
1031
+ raise exceptions .UnexpectedStatusException (
1032
+ message = "Error for {job_type} {job_name}: {status}. Reason: {reason}" .format (
1033
+ job_type = job_type , job_name = job , status = status , reason = reason
1034
+ ),
1035
+ allowed_statuses = ["Completed" , "Stopped" ],
1036
+ actual_status = status ,
1037
+ )
1029
1038
1030
1039
def wait_for_endpoint (self , endpoint , poll = 5 ):
1031
1040
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1051,12 @@ def wait_for_endpoint(self, endpoint, poll=5):
1042
1051
1043
1052
if status != "InService" :
1044
1053
reason = desc .get ("FailureReason" , None )
1045
- raise ValueError (
1046
- "Error hosting endpoint {}: {} Reason: {}" .format (endpoint , status , reason )
1054
+ raise exceptions .UnexpectedStatusException (
1055
+ message = "Error hosting endpoint {endpoint}: {status}. Reason: {reason}." .format (
1056
+ endpoint = endpoint , status = status , reason = reason
1057
+ ),
1058
+ allowed_statuses = ["InService" ],
1059
+ actual_status = status ,
1047
1060
)
1048
1061
return desc
1049
1062
@@ -1276,7 +1289,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
1276
1289
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
1277
1290
1278
1291
Raises:
1279
- ValueError : If waiting and the training job fails.
1292
+ exceptions.UnexpectedStatusException : If waiting and the training job fails.
1280
1293
"""
1281
1294
1282
1295
description = self .sagemaker_client .describe_training_job (TrainingJobName = job_name )
0 commit comments