Skip to content

Commit 9638520

Browse files
committed
change: Replaced generic ValueError with custom subclass when reporting unexpected resource status
1 parent 2d7bff8 commit 9638520

File tree

3 files changed

+135
-14
lines changed

3 files changed

+135
-14
lines changed

src/sagemaker/exceptions.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
class UnexpectedStatusException(ValueError):
17+
"""Raised when resource status is not expected and thus not allowed for further execution"""
18+
def __init__(self, message, allowed_statuses, actual_status):
19+
self.allowed_statuses = allowed_statuses
20+
self.actual_status = actual_status
21+
super(UnexpectedStatusException, self).__init__(message)

src/sagemaker/session.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
secondary_training_status_changed,
3535
secondary_training_status_message,
3636
)
37+
from sagemaker import exceptions
3738

3839
LOGGER = logging.getLogger("sagemaker")
3940

@@ -792,10 +793,10 @@ def wait_for_model_package(self, model_package_name, poll=5):
792793

793794
if status != "Completed":
794795
reason = desc.get("FailureReason", None)
795-
raise ValueError(
796-
"Error creating model package {}: {} Reason: {}".format(
797-
model_package_name, status, reason
798-
)
796+
raise exceptions.UnexpectedStatusException(
797+
message="Error creating model package %s: %s Reason: %s".format(model_package_name, status, reason),
798+
allowed_statuses=["Completed"],
799+
actual_status=status
799800
)
800801
return desc
801802

@@ -947,7 +948,7 @@ def wait_for_job(self, job, poll=5):
947948
(dict): Return value from the ``DescribeTrainingJob`` API.
948949
949950
Raises:
950-
ValueError: If the training job fails.
951+
exceptions.UnexpectedStatusException: If the training job fails.
951952
"""
952953
desc = _wait_until_training_done(
953954
lambda last_desc: _train_done(self.sagemaker_client, job, last_desc), None, poll
@@ -966,7 +967,7 @@ def wait_for_compilation_job(self, job, poll=5):
966967
(dict): Return value from the ``DescribeCompilationJob`` API.
967968
968969
Raises:
969-
ValueError: If the compilation job fails.
970+
exceptions.UnexpectedStatusException: If the compilation job fails.
970971
"""
971972
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
972973
self._check_job_status(job, desc, "CompilationJobStatus")
@@ -983,7 +984,7 @@ def wait_for_tuning_job(self, job, poll=5):
983984
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
984985
985986
Raises:
986-
ValueError: If the hyperparameter tuning job fails.
987+
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
987988
"""
988989
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
989990
self._check_job_status(job, desc, "HyperParameterTuningJobStatus")
@@ -1000,23 +1001,23 @@ def wait_for_transform_job(self, job, poll=5):
10001001
(dict): Return value from the ``DescribeTransformJob`` API.
10011002
10021003
Raises:
1003-
ValueError: If the transform job fails.
1004+
exceptions.UnexpectedStatusException: If the transform job fails.
10041005
"""
10051006
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
10061007
self._check_job_status(job, desc, "TransformJobStatus")
10071008
return desc
10081009

10091010
def _check_job_status(self, job, desc, status_key_name):
10101011
"""Check to see if the job completed successfully and, if not, construct and
1011-
raise a ValueError.
1012+
raise a exceptions.UnexpectedStatusException.
10121013
10131014
Args:
10141015
job (str): The name of the job to check.
10151016
desc (dict[str, str]): The result of ``describe_training_job()``.
10161017
status_key_name (str): Status key name to check for.
10171018
10181019
Raises:
1019-
ValueError: If the training job fails.
1020+
exceptions.UnexpectedStatusException: If the training job fails.
10201021
"""
10211022
status = desc[status_key_name]
10221023
# If the status is capital case, then convert it to Camel case
@@ -1025,7 +1026,11 @@ def _check_job_status(self, job, desc, status_key_name):
10251026
if status != "Completed" and status != "Stopped":
10261027
reason = desc.get("FailureReason", "(No reason provided)")
10271028
job_type = status_key_name.replace("JobStatus", " job")
1028-
raise ValueError("Error for {} {}: {} Reason: {}".format(job_type, job, status, reason))
1029+
raise exceptions.UnexpectedStatusException(
1030+
message="Error for %s %s: %s Reason: %s".format(job_type, job, status, reason),
1031+
allowed_statuses=["Completed", "Stopped"],
1032+
actual_status=status
1033+
)
10291034

10301035
def wait_for_endpoint(self, endpoint, poll=5):
10311036
"""Wait for an Amazon SageMaker endpoint deployment to complete.
@@ -1042,8 +1047,10 @@ def wait_for_endpoint(self, endpoint, poll=5):
10421047

10431048
if status != "InService":
10441049
reason = desc.get("FailureReason", None)
1045-
raise ValueError(
1046-
"Error hosting endpoint {}: {} Reason: {}".format(endpoint, status, reason)
1050+
raise exceptions.UnexpectedStatusException(
1051+
message="Error hosting endpoint %s: %s Reason: %s".format(endpoint, status, reason),
1052+
allowed_statuses=["InService"],
1053+
actual_status=status
10471054
)
10481055
return desc
10491056

@@ -1276,7 +1283,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
12761283
poll (int): The interval in seconds between polling for new log entries and job completion (default: 5).
12771284
12781285
Raises:
1279-
ValueError: If waiting and the training job fails.
1286+
exceptions.UnexpectedStatusException: If waiting and the training job fails.
12801287
"""
12811288

12821289
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
from mock import Mock, MagicMock
17+
import sagemaker
18+
19+
EXPANDED_ROLE = 'arn:aws:iam::111111111111:role/ExpandedRole'
20+
REGION = 'us-west-2'
21+
MODEL_PACKAGE_NAME = 'my_model_package'
22+
JOB_NAME = 'my_job_name'
23+
ENDPOINT_NAME = 'the_point_of_end'
24+
25+
26+
def get_sagemaker_session(returns_status):
27+
boto_mock = Mock(name='boto_session', region_name=REGION)
28+
client_mock = Mock()
29+
client_mock.describe_model_package = MagicMock(return_value={'ModelPackageStatus': returns_status})
30+
client_mock.describe_endpoint = MagicMock(return_value={'EndpointStatus': returns_status})
31+
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
32+
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
33+
return ims
34+
35+
36+
def test_does_not_raise_when_successfully_created_package():
37+
try:
38+
sagemaker_session = get_sagemaker_session(returns_status='Completed')
39+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
40+
except sagemaker.exceptions.UnexpectedStatusException:
41+
pytest.fail("UnexpectedStatusException was thrown while it should not")
42+
43+
44+
def test_raise_when_failed_created_package():
45+
try:
46+
sagemaker_session = get_sagemaker_session(returns_status='EnRoute')
47+
sagemaker_session.wait_for_model_package(MODEL_PACKAGE_NAME)
48+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
49+
except Exception as e:
50+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
51+
assert e.actual_status == 'EnRoute'
52+
assert 'Completed' in e.allowed_statuses
53+
54+
55+
def test_does_not_raise_when_correct_job_status():
56+
try:
57+
job = Mock()
58+
sagemaker_session = get_sagemaker_session(returns_status='Stopped')
59+
sagemaker_session._check_job_status(job, {'TransformationJobStatus': 'Stopped'}, 'TransformationJobStatus')
60+
except sagemaker.exceptions.UnexpectedStatusException:
61+
pytest.fail("UnexpectedStatusException was thrown while it should not")
62+
63+
64+
def test_does_raise_when_incorrect_job_status():
65+
try:
66+
job = Mock()
67+
sagemaker_session = get_sagemaker_session(returns_status='Failed')
68+
sagemaker_session._check_job_status(job, {'TransformationJobStatus': 'Failed'}, 'TransformationJobStatus')
69+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
70+
except Exception as e:
71+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
72+
assert e.actual_status == 'Failed'
73+
assert 'Completed' in e.allowed_statuses
74+
assert 'Stopped' in e.allowed_statuses
75+
76+
77+
def test_does_not_raise_when_successfully_deployed_endpoint():
78+
try:
79+
sagemaker_session = get_sagemaker_session(returns_status='InService')
80+
sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
81+
except sagemaker.exceptions.UnexpectedStatusException:
82+
pytest.fail("UnexpectedStatusException was thrown while it should not")
83+
84+
85+
def test_raise_when_failed_to_deploy_endpoint():
86+
try:
87+
sagemaker_session = get_sagemaker_session(returns_status='Failed')
88+
assert sagemaker_session.wait_for_endpoint(ENDPOINT_NAME)
89+
assert False, 'sagemaker.exceptions.UnexpectedStatusException should have been raised but was not'
90+
except Exception as e:
91+
assert type(e) == sagemaker.exceptions.UnexpectedStatusException
92+
assert e.actual_status == 'Failed'
93+
assert 'InService' in e.allowed_statuses

0 commit comments

Comments
 (0)