Skip to content

Commit 620bb36

Browse files
knikureNamrata Madan
authored and
Namrata Madan
committed
fix: Fix image_uris.retrieve() function to return ValueError when framework is not allowed for an instance_type (aws#3716)
1 parent 5a877f8 commit 620bb36

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

src/sagemaker/image_uris.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,23 @@ def _validate_instance_deprecation(framework, instance_type, version):
381381
)
382382

383383

384-
def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
384+
def _validate_for_suppported_frameworks_and_instance_type(framework, instance_type):
385385
"""Validate if framework is supported for the instance_type"""
386+
# Validate for Trainium allowed frameworks
386387
if (
387-
instace_type is not None
388-
and "trn" in instace_type
388+
instance_type is not None
389+
and "trn" in instance_type
389390
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
390391
):
391-
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
392+
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")
393+
394+
# Validate for Graviton allowed frameowrks
395+
if (
396+
instance_type is not None
397+
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
398+
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
399+
):
400+
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")
392401

393402

394403
def config_for_framework(framework):
@@ -572,12 +581,12 @@ def _validate_arg(arg, available_options, arg_name):
572581
)
573582

574583

575-
def _validate_framework(framework, allowed_frameworks, arg_name):
584+
def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
576585
"""Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
577586
if framework not in allowed_frameworks:
578587
raise ValueError(
579588
f"Unsupported {arg_name}: {framework}. "
580-
f"Supported {arg_name}(s) for trainium instances: {allowed_frameworks}."
589+
f"Supported {arg_name}(s) for {hardware_name} instances: {allowed_frameworks}."
581590
)
582591

583592

tests/unit/sagemaker/image_uris/test_graviton.py

+19
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from sagemaker import image_uris
1616
from tests.unit.sagemaker.image_uris import expected_uris
17+
from sagemaker.fw_utils import GRAVITON_ALLOWED_FRAMEWORKS
1718

1819
import pytest
1920

@@ -90,6 +91,24 @@ def test_graviton_pytorch(graviton_pytorch_version):
9091
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)
9192

9293

94+
def _test_graviton_unsupported_framework(framework, framework_version):
95+
for region in GRAVITON_REGIONS:
96+
for instance_type in GRAVITON_INSTANCE_TYPES:
97+
with pytest.raises(ValueError) as error:
98+
image_uris.retrieve(
99+
framework, region, version=framework_version, instance_type=instance_type
100+
)
101+
expectedErr = (
102+
f"Unsupported framework: {framework}. Supported framework(s) for Graviton instances: "
103+
f"{GRAVITON_ALLOWED_FRAMEWORKS}."
104+
)
105+
assert expectedErr in str(error)
106+
107+
108+
def test_graviton_unsupported_framework():
109+
_test_graviton_unsupported_framework("autogluon", "0.6.1")
110+
111+
93112
def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
94113
for xgboost_version in graviton_xgboost_versions:
95114
for instance_type in GRAVITON_INSTANCE_TYPES:

tests/unit/sagemaker/image_uris/test_trainium.py

+20
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from sagemaker import image_uris
1616
from tests.unit.sagemaker.image_uris import expected_uris
1717

18+
import pytest
19+
1820
ACCOUNTS = {
1921
"af-south-1": "626614931356",
2022
"ap-east-1": "871362719292",
@@ -45,6 +47,7 @@
4547
}
4648

4749
TRAINIUM_REGIONS = ACCOUNTS.keys()
50+
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
4851

4952

5053
def _expected_trainium_framework_uri(
@@ -73,3 +76,20 @@ def _test_trainium_framework_uris(framework, version):
7376

7477
def test_trainium_pytorch(pytorch_neuron_version):
7578
_test_trainium_framework_uris("pytorch", pytorch_neuron_version)
79+
80+
81+
def _test_trainium_unsupported_framework(framework, framework_version):
82+
for region in TRAINIUM_REGIONS:
83+
with pytest.raises(ValueError) as error:
84+
image_uris.retrieve(
85+
framework, region, version=framework_version, instance_type="ml.trn1.xlarge"
86+
)
87+
expectedErr = (
88+
f"Unsupported framework: {framework}. Supported framework(s) for Trainium instances: "
89+
f"{TRAINIUM_ALLOWED_FRAMEWORKS}."
90+
)
91+
assert expectedErr in str(error)
92+
93+
94+
def test_trainium_unsupported_framework():
95+
_test_trainium_unsupported_framework("autogluon", "0.6.1")

0 commit comments

Comments
 (0)