Skip to content

fix: Fix image_uris.retrieve() function to return ValueError when framework is not allowed for an instance_type #3716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,23 @@ def _validate_instance_deprecation(framework, instance_type, version):
)


def _validate_for_suppported_frameworks_and_instance_type(framework, instace_type):
def _validate_for_suppported_frameworks_and_instance_type(framework, instance_type):
"""Validate if framework is supported for the instance_type"""
# Validate for Trainium allowed frameworks
if (
instace_type is not None
and "trn" in instace_type
instance_type is not None
and "trn" in instance_type
and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
):
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework")
_validate_framework(framework, TRAINIUM_ALLOWED_FRAMEWORKS, "framework", "Trainium")

# Validate for Graviton allowed frameowrks
if (
instance_type is not None
and _get_instance_type_family(instance_type) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
and framework not in GRAVITON_ALLOWED_FRAMEWORKS
):
_validate_framework(framework, GRAVITON_ALLOWED_FRAMEWORKS, "framework", "Graviton")


def config_for_framework(framework):
Expand Down Expand Up @@ -572,12 +581,12 @@ def _validate_arg(arg, available_options, arg_name):
)


def _validate_framework(framework, allowed_frameworks, arg_name):
def _validate_framework(framework, allowed_frameworks, arg_name, hardware_name):
"""Checks if the framework is in the allowed frameworks, and raises a ``ValueError`` if not."""
if framework not in allowed_frameworks:
raise ValueError(
f"Unsupported {arg_name}: {framework}. "
f"Supported {arg_name}(s) for trainium instances: {allowed_frameworks}."
f"Supported {arg_name}(s) for {hardware_name} instances: {allowed_frameworks}."
)


Expand Down
19 changes: 19 additions & 0 deletions tests/unit/sagemaker/image_uris/test_graviton.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris
from sagemaker.fw_utils import GRAVITON_ALLOWED_FRAMEWORKS

import pytest

Expand Down Expand Up @@ -90,6 +91,24 @@ def test_graviton_pytorch(graviton_pytorch_version):
_test_graviton_framework_uris("pytorch", graviton_pytorch_version)


def _test_graviton_unsupported_framework(framework, framework_version):
for region in GRAVITON_REGIONS:
for instance_type in GRAVITON_INSTANCE_TYPES:
with pytest.raises(ValueError) as error:
image_uris.retrieve(
framework, region, version=framework_version, instance_type=instance_type
)
expectedErr = (
f"Unsupported framework: {framework}. Supported framework(s) for Graviton instances: "
f"{GRAVITON_ALLOWED_FRAMEWORKS}."
)
assert expectedErr in str(error)


def test_graviton_unsupported_framework():
_test_graviton_unsupported_framework("autogluon", "0.6.1")


def test_graviton_xgboost_instance_type_specified(graviton_xgboost_versions):
for xgboost_version in graviton_xgboost_versions:
for instance_type in GRAVITON_INSTANCE_TYPES:
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/sagemaker/image_uris/test_trainium.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris

import pytest

ACCOUNTS = {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
Expand Down Expand Up @@ -45,6 +47,7 @@
}

TRAINIUM_REGIONS = ACCOUNTS.keys()
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"


def _expected_trainium_framework_uri(
Expand Down Expand Up @@ -73,3 +76,20 @@ def _test_trainium_framework_uris(framework, version):

def test_trainium_pytorch(pytorch_neuron_version):
_test_trainium_framework_uris("pytorch", pytorch_neuron_version)


def _test_trainium_unsupported_framework(framework, framework_version):
for region in TRAINIUM_REGIONS:
with pytest.raises(ValueError) as error:
image_uris.retrieve(
framework, region, version=framework_version, instance_type="ml.trn1.xlarge"
)
expectedErr = (
f"Unsupported framework: {framework}. Supported framework(s) for Trainium instances: "
f"{TRAINIUM_ALLOWED_FRAMEWORKS}."
)
assert expectedErr in str(error)


def test_trainium_unsupported_framework():
_test_trainium_unsupported_framework("autogluon", "0.6.1")