Skip to content

Commit 243ffa4

Browse files
authored
fix: gated models unsupported region (#4069)
1 parent 1f3754d commit 243ffa4

File tree

3 files changed

+83
-0
lines changed

3 files changed

+83
-0
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

+14
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ def _retrieve_model_package_arn(
8080

8181
regional_arn = model_specs.hosting_model_package_arns.get(region)
8282

83+
if regional_arn is None:
84+
raise ValueError(
85+
f"Model package arn for '{model_id}' not supported in {region}. "
86+
"Please try one of the following regions: "
87+
f"{', '.join(model_specs.hosting_model_package_arns.keys())}."
88+
)
89+
8390
return regional_arn
8491

8592
raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")
@@ -143,6 +150,13 @@ def _retrieve_model_package_model_artifact_s3_uri(
143150

144151
model_s3_uri = model_specs.training_model_package_artifact_uris.get(region)
145152

153+
if model_s3_uri is None:
154+
raise ValueError(
155+
f"Model package artifact s3 uri for '{model_id}' not supported in {region}. "
156+
"Please try one of the following regions: "
157+
f"{', '.join(model_specs.training_model_package_artifact_uris.keys())}."
158+
)
159+
146160
return model_s3_uri
147161

148162
raise NotImplementedError(f"Model Package Artifact URI not supported for scope: '{scope}'")

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

+43
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,49 @@ def test_gated_model_s3_uri(
353353
use_compiled_model=False,
354354
)
355355

356+
@mock.patch("sagemaker.utils.sagemaker_timestamp")
357+
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
358+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
359+
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
360+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
361+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
362+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
363+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
364+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
365+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
366+
def test_jumpstart_model_package_artifact_s3_uri_unsupported_region(
367+
self,
368+
mock_estimator_deploy: mock.Mock,
369+
mock_estimator_fit: mock.Mock,
370+
mock_estimator_init: mock.Mock,
371+
mock_get_model_specs: mock.Mock,
372+
mock_session_estimator: mock.Mock,
373+
mock_session_model: mock.Mock,
374+
mock_is_valid_model_id: mock.Mock,
375+
mock_timestamp: mock.Mock,
376+
):
377+
mock_estimator_deploy.return_value = default_predictor
378+
379+
mock_timestamp.return_value = "8675309"
380+
381+
mock_is_valid_model_id.return_value = True
382+
383+
model_id, _ = "js-gated-artifact-trainable-model", "*"
384+
385+
mock_get_model_specs.side_effect = get_special_model_spec
386+
387+
mock_session_estimator.return_value = sagemaker_session
388+
mock_session_model.return_value = sagemaker_session
389+
390+
with pytest.raises(ValueError) as e:
391+
JumpStartEstimator(model_id=model_id, region="eu-north-1")
392+
393+
assert (
394+
str(e.value) == "Model package artifact s3 uri for 'js-gated-artifact-trainable-model' "
395+
"not supported in eu-north-1. Please try one of the following regions: "
396+
"us-west-2, us-east-1, eu-west-1, ap-southeast-1."
397+
)
398+
356399
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
357400
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
358401
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")

tests/unit/sagemaker/jumpstart/model/test_model.py

+26
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,32 @@ def test_jumpstart_model_package_arn_override(
652652
},
653653
)
654654

655+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
656+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
657+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
658+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
659+
def test_jumpstart_model_package_arn_unsupported_region(
660+
self,
661+
mock_get_model_specs: mock.Mock,
662+
mock_session: mock.Mock,
663+
mock_is_valid_model_id: mock.Mock,
664+
):
665+
666+
mock_is_valid_model_id.return_value = True
667+
668+
model_id, _ = "js-model-package-arn", "*"
669+
670+
mock_get_model_specs.side_effect = get_special_model_spec
671+
672+
mock_session.return_value = MagicMock(sagemaker_config={})
673+
674+
with pytest.raises(ValueError) as e:
675+
JumpStartModel(model_id=model_id, region="us-east-2")
676+
assert (
677+
str(e.value) == "Model package arn for 'js-model-package-arn' not supported in "
678+
"us-east-2. Please try one of the following regions: us-west-2, us-east-1."
679+
)
680+
655681

656682
def test_jumpstart_model_requires_model_id():
657683
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)