diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 0fbf1d74b3..da3513cc1d 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -72,6 +72,13 @@ def _retrieve_model_package_arn( regional_arn = model_specs.hosting_model_package_arns.get(region) + if regional_arn is None: + raise ValueError( + f"Model package arn for '{model_id}' not supported in {region}. " + "Please try one of the following regions: " + f"{', '.join(model_specs.hosting_model_package_arns.keys())}." + ) + return regional_arn raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'") @@ -130,6 +137,13 @@ def _retrieve_model_package_model_artifact_s3_uri( model_s3_uri = model_specs.training_model_package_artifact_uris.get(region) + if model_s3_uri is None: + raise ValueError( + f"Model package artifact s3 uri for '{model_id}' not supported in {region}. " + "Please try one of the following regions: " + f"{', '.join(model_specs.training_model_package_artifact_uris.keys())}." + ) + return model_s3_uri raise NotImplementedError(f"Model Package Artifact URI not supported for scope: '{scope}'") diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 198839475b..60420be9c2 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -353,6 +353,49 @@ def test_gated_model_s3_uri( use_compiled_model=False, ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_model_package_artifact_s3_uri_unsupported_region( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_timestamp: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_timestamp.return_value = "8675309" + + mock_is_valid_model_id.return_value = True + + model_id, _ = "js-gated-artifact-trainable-model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + with pytest.raises(ValueError) as e: + JumpStartEstimator(model_id=model_id, region="eu-north-1") + + assert ( + str(e.value) == "Model package artifact s3 uri for 'js-gated-artifact-trainable-model' " + "not supported in eu-north-1. Please try one of the following regions: " + "us-west-2, us-east-1, eu-west-1, ap-southeast-1." + ) + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index fb7698741e..5194a18e16 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -646,6 +646,32 @@ def test_jumpstart_model_package_arn_override( }, ) + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_model_package_arn_unsupported_region( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + ): + + mock_is_valid_model_id.return_value = True + + model_id, _ = "js-model-package-arn", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = MagicMock(sagemaker_config={}) + + with pytest.raises(ValueError) as e: + JumpStartModel(model_id=model_id, region="us-east-2") + assert ( + str(e.value) == "Model package arn for 'js-model-package-arn' not supported in " + "us-east-2. Please try one of the following regions: us-west-2, us-east-1." + ) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError):