Skip to content

Commit 924b060

Browse files
feature: pluggable instance fallback mechanism, add CapacityError (#3033)
1 parent 30b4ce2 commit 924b060

File tree

6 files changed

+191
-77
lines changed

6 files changed

+191
-77
lines changed

src/sagemaker/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def __init__(self, message, allowed_statuses, actual_status):
2323
super(UnexpectedStatusException, self).__init__(message)
2424

2525

26+
class CapacityError(UnexpectedStatusException):
27+
"""Raised when resource status is not expected and the failure reason is a CapacityError."""
28+
29+
2630
class AsyncInferenceError(Exception):
2731
"""The base exception class for Async Inference exceptions."""
2832

src/sagemaker/session.py

+51-12
Original file line numberDiff line numberDiff line change
@@ -1721,6 +1721,7 @@ def wait_for_auto_ml_job(self, job, poll=5):
17211721
(dict): Return value from the ``DescribeAutoMLJob`` API.
17221722
17231723
Raises:
1724+
exceptions.CapacityError: If the auto ml job fails with CapacityError.
17241725
exceptions.UnexpectedStatusException: If the auto ml job fails.
17251726
"""
17261727
desc = _wait_until(lambda: _auto_ml_job_status(self.sagemaker_client, job), poll)
@@ -1743,7 +1744,8 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
17431744
completion (default: 5).
17441745
17451746
Raises:
1746-
exceptions.UnexpectedStatusException: If waiting and the training job fails.
1747+
exceptions.CapacityError: If waiting and auto ml job fails with CapacityError.
1748+
exceptions.UnexpectedStatusException: If waiting and auto ml job fails.
17471749
"""
17481750

17491751
description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name)
@@ -2845,6 +2847,10 @@ def wait_for_model_package(self, model_package_name, poll=5):
28452847
28462848
Returns:
28472849
dict: Return value from the ``DescribeModelPackage`` API.
2850+
2851+
Raises:
2852+
exceptions.CapacityError: If the Model Package job fails with CapacityError.
2853+
exceptions.UnexpectedStatusException: If waiting and the Model Package job fails.
28482854
"""
28492855
desc = _wait_until(
28502856
lambda: _create_model_package_status(self.sagemaker_client, model_package_name), poll
@@ -2853,10 +2859,17 @@ def wait_for_model_package(self, model_package_name, poll=5):
28532859

28542860
if status != "Completed":
28552861
reason = desc.get("FailureReason", None)
2862+
message = "Error creating model package {package}: {status} Reason: {reason}".format(
2863+
package=model_package_name, status=status, reason=reason
2864+
)
2865+
if "CapacityError" in str(reason):
2866+
raise exceptions.CapacityError(
2867+
message=message,
2868+
allowed_statuses=["InService"],
2869+
actual_status=status,
2870+
)
28562871
raise exceptions.UnexpectedStatusException(
2857-
message="Error creating model package {package}: {status} Reason: {reason}".format(
2858-
package=model_package_name, status=status, reason=reason
2859-
),
2872+
message=message,
28602873
allowed_statuses=["Completed"],
28612874
actual_status=status,
28622875
)
@@ -3147,6 +3160,7 @@ def wait_for_job(self, job, poll=5):
31473160
(dict): Return value from the ``DescribeTrainingJob`` API.
31483161
31493162
Raises:
3163+
exceptions.CapacityError: If the training job fails with CapacityError.
31503164
exceptions.UnexpectedStatusException: If the training job fails.
31513165
"""
31523166
desc = _wait_until_training_done(
@@ -3166,7 +3180,8 @@ def wait_for_processing_job(self, job, poll=5):
31663180
(dict): Return value from the ``DescribeProcessingJob`` API.
31673181
31683182
Raises:
3169-
exceptions.UnexpectedStatusException: If the compilation job fails.
3183+
exceptions.CapacityError: If the processing job fails with CapacityError.
3184+
exceptions.UnexpectedStatusException: If the processing job fails.
31703185
"""
31713186
desc = _wait_until(lambda: _processing_job_status(self.sagemaker_client, job), poll)
31723187
self._check_job_status(job, desc, "ProcessingJobStatus")
@@ -3183,6 +3198,7 @@ def wait_for_compilation_job(self, job, poll=5):
31833198
(dict): Return value from the ``DescribeCompilationJob`` API.
31843199
31853200
Raises:
3201+
exceptions.CapacityError: If the compilation job fails with CapacityError.
31863202
exceptions.UnexpectedStatusException: If the compilation job fails.
31873203
"""
31883204
desc = _wait_until(lambda: _compilation_job_status(self.sagemaker_client, job), poll)
@@ -3200,7 +3216,8 @@ def wait_for_edge_packaging_job(self, job, poll=5):
32003216
(dict): Return value from the ``DescribeEdgePackagingJob`` API.
32013217
32023218
Raises:
3203-
exceptions.UnexpectedStatusException: If the compilation job fails.
3219+
exceptions.CapacityError: If the edge packaging job fails with CapacityError.
3220+
exceptions.UnexpectedStatusException: If the edge packaging job fails.
32043221
"""
32053222
desc = _wait_until(lambda: _edge_packaging_job_status(self.sagemaker_client, job), poll)
32063223
self._check_job_status(job, desc, "EdgePackagingJobStatus")
@@ -3217,6 +3234,7 @@ def wait_for_tuning_job(self, job, poll=5):
32173234
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
32183235
32193236
Raises:
3237+
exceptions.CapacityError: If the hyperparameter tuning job fails with CapacityError.
32203238
exceptions.UnexpectedStatusException: If the hyperparameter tuning job fails.
32213239
"""
32223240
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
@@ -3245,6 +3263,7 @@ def wait_for_transform_job(self, job, poll=5):
32453263
(dict): Return value from the ``DescribeTransformJob`` API.
32463264
32473265
Raises:
3266+
exceptions.CapacityError: If the transform job fails with CapacityError.
32483267
exceptions.UnexpectedStatusException: If the transform job fails.
32493268
"""
32503269
desc = _wait_until(lambda: _transform_job_status(self.sagemaker_client, job), poll)
@@ -3283,6 +3302,7 @@ def _check_job_status(self, job, desc, status_key_name):
32833302
status_key_name (str): Status key name to check for.
32843303
32853304
Raises:
3305+
exceptions.CapacityError: If the job fails with CapacityError.
32863306
exceptions.UnexpectedStatusException: If the job fails.
32873307
"""
32883308
status = desc[status_key_name]
@@ -3298,10 +3318,17 @@ def _check_job_status(self, job, desc, status_key_name):
32983318
elif status != "Completed":
32993319
reason = desc.get("FailureReason", "(No reason provided)")
33003320
job_type = status_key_name.replace("JobStatus", " job")
3321+
message = "Error for {job_type} {job_name}: {status}. Reason: {reason}".format(
3322+
job_type=job_type, job_name=job, status=status, reason=reason
3323+
)
3324+
if "CapacityError" in str(reason):
3325+
raise exceptions.CapacityError(
3326+
message=message,
3327+
allowed_statuses=["Completed", "Stopped"],
3328+
actual_status=status,
3329+
)
33013330
raise exceptions.UnexpectedStatusException(
3302-
message="Error for {job_type} {job_name}: {status}. Reason: {reason}".format(
3303-
job_type=job_type, job_name=job, status=status, reason=reason
3304-
),
3331+
message=message,
33053332
allowed_statuses=["Completed", "Stopped"],
33063333
actual_status=status,
33073334
)
@@ -3313,6 +3340,10 @@ def wait_for_endpoint(self, endpoint, poll=30):
33133340
endpoint (str): Name of the ``Endpoint`` to wait for.
33143341
poll (int): Polling interval in seconds (default: 30).
33153342
3343+
Raises:
3344+
exceptions.CapacityError: If the endpoint creation job fails with CapacityError.
3345+
exceptions.UnexpectedStatusException: If the endpoint creation job fails.
3346+
33163347
Returns:
33173348
dict: Return value from the ``DescribeEndpoint`` API.
33183349
"""
@@ -3321,10 +3352,17 @@ def wait_for_endpoint(self, endpoint, poll=30):
33213352

33223353
if status != "InService":
33233354
reason = desc.get("FailureReason", None)
3355+
message = "Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format(
3356+
endpoint=endpoint, status=status, reason=reason
3357+
)
3358+
if "CapacityError" in str(reason):
3359+
raise exceptions.CapacityError(
3360+
message=message,
3361+
allowed_statuses=["InService"],
3362+
actual_status=status,
3363+
)
33243364
raise exceptions.UnexpectedStatusException(
3325-
message="Error hosting endpoint {endpoint}: {status}. Reason: {reason}.".format(
3326-
endpoint=endpoint, status=status, reason=reason
3327-
),
3365+
message=message,
33283366
allowed_statuses=["InService"],
33293367
actual_status=status,
33303368
)
@@ -3649,6 +3687,7 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
36493687
completion (default: 5).
36503688
36513689
Raises:
3690+
exceptions.CapacityError: If waiting and the training job fails with CapacityError.
36523691
exceptions.UnexpectedStatusException: If waiting and the training job fails.
36533692
"""
36543693

tests/integ/test_huggingface.py

+23-33
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,14 @@
1515
import os
1616

1717
import pytest
18-
import logging
1918

2019
from sagemaker.huggingface import HuggingFace, HuggingFaceProcessor
2120
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor
2221
from sagemaker.utils import unique_name_from_base
2322
from tests import integ
23+
from tests.integ.utils import gpu_list, retry_with_instance_list
2424
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2525
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
26-
from sagemaker.exceptions import UnexpectedStatusException
2726

2827
ROLE = "SageMakerRole"
2928

@@ -34,43 +33,34 @@
3433
and integ.test_region() in integ.TRAINING_NO_P3_REGIONS,
3534
reason="no ml.p2 or ml.p3 instances in this region",
3635
)
36+
@retry_with_instance_list(gpu_list(integ.test_region()))
3737
def test_framework_processing_job_with_deps(
3838
sagemaker_session,
39-
gpu_instance_type_list,
4039
huggingface_training_latest_version,
4140
huggingface_training_pytorch_latest_version,
4241
huggingface_pytorch_latest_training_py_version,
42+
**kwargs,
4343
):
44-
for i_type in gpu_instance_type_list:
45-
logging.info("Using the instance type: {}".format(i_type))
46-
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
47-
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
48-
entry_point = "main_script.py"
49-
50-
processor = HuggingFaceProcessor(
51-
transformers_version=huggingface_training_latest_version,
52-
pytorch_version=huggingface_training_pytorch_latest_version,
53-
py_version=huggingface_pytorch_latest_training_py_version,
54-
role=ROLE,
55-
instance_count=1,
56-
instance_type=i_type,
57-
sagemaker_session=sagemaker_session,
58-
base_job_name="test-huggingface",
59-
)
60-
try:
61-
processor.run(
62-
code=entry_point,
63-
source_dir=code_path,
64-
inputs=[],
65-
wait=True,
66-
)
67-
except UnexpectedStatusException as e:
68-
if "CapacityError" in str(e) and i_type != gpu_instance_type_list[-1]:
69-
logging.warning("Failure using instance type: {}. {}".format(i_type, str(e)))
70-
continue
71-
else:
72-
raise
73-
break
44+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
45+
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
46+
entry_point = "main_script.py"
47+
48+
processor = HuggingFaceProcessor(
49+
transformers_version=huggingface_training_latest_version,
50+
pytorch_version=huggingface_training_pytorch_latest_version,
51+
py_version=huggingface_pytorch_latest_training_py_version,
52+
role=ROLE,
53+
instance_count=1,
54+
instance_type=kwargs["instance_type"],
55+
sagemaker_session=sagemaker_session,
56+
base_job_name="test-huggingface",
57+
)
58+
processor.run(
59+
code=entry_point,
60+
source_dir=code_path,
61+
inputs=[],
62+
wait=True,
63+
)
7464

7565

7666
@pytest.mark.release

tests/integ/test_tf.py

+22-32
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import numpy as np
1616
import os
1717
import time
18-
import logging
1918

2019
import pytest
2120

@@ -25,8 +24,8 @@
2524
import tests.integ
2625
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, kms_utils, timeout
2726
from tests.integ.retry import retries
27+
from tests.integ.utils import gpu_list, retry_with_instance_list
2828
from tests.integ.s3_utils import assert_s3_file_patterns_exist
29-
from sagemaker.exceptions import UnexpectedStatusException
3029

3130

3231
ROLE = "SageMakerRole"
@@ -48,41 +47,32 @@
4847
and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
4948
reason="no ml.p2 or ml.p3 instances in this region",
5049
)
50+
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
5151
def test_framework_processing_job_with_deps(
5252
sagemaker_session,
53-
gpu_instance_type_list,
5453
tensorflow_training_latest_version,
5554
tensorflow_training_latest_py_version,
55+
**kwargs,
5656
):
57-
for i_type in gpu_instance_type_list:
58-
logging.info("Using the instance type: {}".format(i_type))
59-
with timeout.timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
60-
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
61-
entry_point = "main_script.py"
62-
63-
processor = TensorFlowProcessor(
64-
framework_version=tensorflow_training_latest_version,
65-
py_version=tensorflow_training_latest_py_version,
66-
role=ROLE,
67-
instance_count=1,
68-
instance_type=i_type,
69-
sagemaker_session=sagemaker_session,
70-
base_job_name="test-tensorflow",
71-
)
72-
try:
73-
processor.run(
74-
code=entry_point,
75-
source_dir=code_path,
76-
inputs=[],
77-
wait=True,
78-
)
79-
except UnexpectedStatusException as e:
80-
if "CapacityError" in str(e) and i_type != gpu_instance_type_list[-1]:
81-
logging.warning("Failure using instance type: {}. {}".format(i_type, str(e)))
82-
continue
83-
else:
84-
raise
85-
break
57+
with timeout.timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
58+
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
59+
entry_point = "main_script.py"
60+
61+
processor = TensorFlowProcessor(
62+
framework_version=tensorflow_training_latest_version,
63+
py_version=tensorflow_training_latest_py_version,
64+
role=ROLE,
65+
instance_count=1,
66+
instance_type=kwargs["instance_type"],
67+
sagemaker_session=sagemaker_session,
68+
base_job_name="test-tensorflow",
69+
)
70+
processor.run(
71+
code=entry_point,
72+
source_dir=code_path,
73+
inputs=[],
74+
wait=True,
75+
)
8676

8777

8878
def test_mnist_with_checkpoint_config(

0 commit comments

Comments
 (0)