Skip to content

fix: gated models unsupported region #4069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def _retrieve_model_package_arn(

regional_arn = model_specs.hosting_model_package_arns.get(region)

if regional_arn is None:
raise ValueError(
f"Model package arn for '{model_id}' not supported in {region}. "
"Please try one of the following regions: "
f"{', '.join(model_specs.hosting_model_package_arns.keys())}."
)

return regional_arn

raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")
Expand Down Expand Up @@ -130,6 +137,13 @@ def _retrieve_model_package_model_artifact_s3_uri(

model_s3_uri = model_specs.training_model_package_artifact_uris.get(region)

if model_s3_uri is None:
raise ValueError(
f"Model package artifact s3 uri for '{model_id}' not supported in {region}. "
"Please try one of the following regions: "
f"{', '.join(model_specs.training_model_package_artifact_uris.keys())}."
)

return model_s3_uri

raise NotImplementedError(f"Model Package Artifact URI not supported for scope: '{scope}'")
43 changes: 43 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,49 @@ def test_gated_model_s3_uri(
use_compiled_model=False,
)

@mock.patch("sagemaker.utils.sagemaker_timestamp")
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_jumpstart_model_package_artifact_s3_uri_unsupported_region(
self,
mock_estimator_deploy: mock.Mock,
mock_estimator_fit: mock.Mock,
mock_estimator_init: mock.Mock,
mock_get_model_specs: mock.Mock,
mock_session_estimator: mock.Mock,
mock_session_model: mock.Mock,
mock_is_valid_model_id: mock.Mock,
mock_timestamp: mock.Mock,
):
mock_estimator_deploy.return_value = default_predictor

mock_timestamp.return_value = "8675309"

mock_is_valid_model_id.return_value = True

model_id, _ = "js-gated-artifact-trainable-model", "*"

mock_get_model_specs.side_effect = get_special_model_spec

mock_session_estimator.return_value = sagemaker_session
mock_session_model.return_value = sagemaker_session

with pytest.raises(ValueError) as e:
JumpStartEstimator(model_id=model_id, region="eu-north-1")

assert (
str(e.value) == "Model package artifact s3 uri for 'js-gated-artifact-trainable-model' "
"not supported in eu-north-1. Please try one of the following regions: "
"us-west-2, us-east-1, eu-west-1, ap-southeast-1."
)

@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,32 @@ def test_jumpstart_model_package_arn_override(
},
)

@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_jumpstart_model_package_arn_unsupported_region(
self,
mock_get_model_specs: mock.Mock,
mock_session: mock.Mock,
mock_is_valid_model_id: mock.Mock,
):

mock_is_valid_model_id.return_value = True

model_id, _ = "js-model-package-arn", "*"

mock_get_model_specs.side_effect = get_special_model_spec

mock_session.return_value = MagicMock(sagemaker_config={})

with pytest.raises(ValueError) as e:
JumpStartModel(model_id=model_id, region="us-east-2")
assert (
str(e.value) == "Model package arn for 'js-model-package-arn' not supported in "
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
)


def test_jumpstart_model_requires_model_id():
with pytest.raises(ValueError):
Expand Down