diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 6a389f385f..0bbc2d5788 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -105,7 +105,7 @@ class JumpStartS3FileType(str, Enum): OPEN_WEIGHT_MANIFEST = "manifest" OPEN_WEIGHT_SPECS = "specs" - PROPRIETARY_MANIFEST = "proptietary_manifest" + PROPRIETARY_MANIFEST = "proprietary_manifest" PROPRIETARY_SPECS = "proprietary_specs" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 5f51173b24..000a9eca65 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -765,20 +765,6 @@ def validate_model_id_and_get_type( ValueError: If the script is not supported by JumpStart. """ - def _get_model_type( - model_id: str, - open_weights_model_ids: Set[str], - proprietary_model_ids: Set[str], - script: enums.JumpStartScriptScope, - ) -> Optional[enums.JumpStartModelType]: - if model_id in open_weights_model_ids: - return enums.JumpStartModelType.OPEN_WEIGHTS - if model_id in proprietary_model_ids: - if script == enums.JumpStartScriptScope.INFERENCE: - return enums.JumpStartModelType.PROPRIETARY - raise ValueError(f"Unsupported script for Marketplace models: {script}") - return None - if model_id in {None, ""}: return None if not isinstance(model_id, str): @@ -792,12 +778,19 @@ def _get_model_type( ) open_weight_model_id_set = {model.model_id for model in models_manifest_list} + if model_id in open_weight_model_id_set: + return enums.JumpStartModelType.OPEN_WEIGHTS + proprietary_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( region=region, s3_client=s3_client, model_type=enums.JumpStartModelType.PROPRIETARY ) proprietary_model_id_set = {model.model_id for model in proprietary_manifest_list} - return _get_model_type(model_id, open_weight_model_id_set, proprietary_model_id_set, script) + if model_id in proprietary_model_id_set: + if script == enums.JumpStartScriptScope.INFERENCE: + return enums.JumpStartModelType.PROPRIETARY + raise ValueError(f"Unsupported script for Proprietary models: {script}") + return None def get_jumpstart_model_id_version_from_resource_arn( diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 1ea08724b9..53e92c66a8 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6350,6 +6350,12 @@ "py_version": "py3", }, "training_artifact_key": "pytorch-training/train-pytorch-eqa-bert-base-cased.tar.gz", + "predictor_specs": { + "supported_content_types": ["application/x-image"], + "supported_accept_types": ["application/json;verbose", "application/json"], + "default_content_type": "application/x-image", + "default_accept_type": "application/json", + }, "inference_environment_variables": [ { "name": "SAGEMAKER_PROGRAM", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index c4a96d4120..8b00eb5bcd 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -25,6 +25,10 @@ DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) +from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, +) from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag, JumpStartModelType from sagemaker.jumpstart.model import JumpStartModel @@ -41,6 +45,7 @@ overwrite_dictionary, get_special_model_spec_for_inference_component_based_endpoint, get_prototype_manifest, + get_prototype_model_spec, ) import boto3 @@ -1365,6 +1370,50 @@ def test_jumpstart_model_session( assert len(s3_clients) == 1 assert list(s3_clients)[0] == session.s3_client + @mock.patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.model.Model.deploy") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_model_local_mode( + self, + mock_model_deploy: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_get_manifest: mock.Mock, + ): + mock_get_model_specs.side_effect = get_prototype_model_spec + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) + mock_model_deploy.return_value = default_predictor + + model_id, _ = "pytorch-eqa-bert-base-cased", "*" + + mock_session.return_value = sagemaker_session + + model = JumpStartModel(model_id=model_id, instance_type="ml.p2.xlarge") + + model.deploy() + + mock_model_deploy.assert_called_once_with( + initial_instance_count=1, + instance_type="ml.p2.xlarge", + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"}, + ], + wait=True, + endpoint_logging=False, + ) + def test_jumpstart_model_requires_model_id(): with pytest.raises(ValueError): diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c81d5639e5..ded42e2dc4 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -40,7 +40,7 @@ VulnerableJumpStartModelError, ) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId -from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_prototype_manifest from mock import MagicMock @@ -1178,7 +1178,7 @@ def test_mime_type_enum_from_str(): class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_validate_model_id_and_get_type_true( + def test_validate_model_id_and_get_type_open_weights( self, mock_get_model_specs: Mock, mock_get_manifest: Mock, @@ -1197,11 +1197,11 @@ def test_validate_model_id_and_get_type_true( ) with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertTrue(utils.validate_model_id_and_get_type("bee")) + assert utils.validate_model_id_and_get_type("bee") == JumpStartModelType.OPEN_WEIGHTS mock_get_manifest.assert_called_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value, - model_type=JumpStartModelType.PROPRIETARY, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) mock_get_model_specs.assert_not_called() @@ -1215,25 +1215,30 @@ def test_validate_model_id_and_get_type_true( ] mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue( + self.assertIsNone( + utils.validate_model_id_and_get_type( + "invalid", script=JumpStartScriptScope.TRAINING + ) + ) + assert ( utils.validate_model_id_and_get_type("bee", script=JumpStartScriptScope.TRAINING) + == JumpStartModelType.OPEN_WEIGHTS ) + mock_get_manifest.assert_called_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value, - model_type=JumpStartModelType.PROPRIETARY, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") - def test_validate_model_id_and_get_type_false( + def test_validate_model_id_and_get_type_invalid( self, mock_get_model_specs: Mock, mock_get_manifest: Mock ): - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] + mock_get_manifest.side_effect = ( + lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) + ) mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client @@ -1244,10 +1249,10 @@ def test_validate_model_id_and_get_type_false( with patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type", patched): - self.assertFalse(utils.validate_model_id_and_get_type("dee")) - self.assertFalse(utils.validate_model_id_and_get_type("")) - self.assertFalse(utils.validate_model_id_and_get_type(None)) - self.assertFalse(utils.validate_model_id_and_get_type(set())) + self.assertIsNone(utils.validate_model_id_and_get_type("dee")) + self.assertIsNone(utils.validate_model_id_and_get_type("")) + self.assertIsNone(utils.validate_model_id_and_get_type(None)) + self.assertIsNone(utils.validate_model_id_and_get_type(set())) mock_get_manifest.assert_called() @@ -1256,53 +1261,44 @@ def test_validate_model_id_and_get_type_false( mock_get_manifest.reset_mock() mock_get_model_specs.reset_mock() - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] - self.assertFalse( - utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) + assert ( + utils.validate_model_id_and_get_type("ai21-summarization") + == JumpStartModelType.PROPRIETARY ) + self.assertIsNone(utils.validate_model_id_and_get_type("ai21-summarization-2")) + mock_get_manifest.assert_called_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value, model_type=JumpStartModelType.PROPRIETARY, ) - mock_get_manifest.reset_mock() - - self.assertFalse( + self.assertIsNone( utils.validate_model_id_and_get_type("dee", script=JumpStartScriptScope.TRAINING) ) - self.assertFalse( + self.assertIsNone( utils.validate_model_id_and_get_type("", script=JumpStartScriptScope.TRAINING) ) - self.assertFalse( + self.assertIsNone( utils.validate_model_id_and_get_type(None, script=JumpStartScriptScope.TRAINING) ) - self.assertFalse( + self.assertIsNone( utils.validate_model_id_and_get_type(set(), script=JumpStartScriptScope.TRAINING) ) - mock_get_model_specs.assert_not_called() + assert ( + utils.validate_model_id_and_get_type("pytorch-eqa-bert-base-cased") + == JumpStartModelType.OPEN_WEIGHTS + ) mock_get_manifest.assert_called_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value, - model_type=JumpStartModelType.PROPRIETARY, + model_type=JumpStartModelType.OPEN_WEIGHTS, ) - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() - - mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertTrue( - utils.validate_model_id_and_get_type("ay", script=JumpStartScriptScope.TRAINING) - ) - mock_get_manifest.assert_called_with( - region=JUMPSTART_DEFAULT_REGION_NAME, - s3_client=mock_s3_client_value, - model_type=JumpStartModelType.PROPRIETARY, + with pytest.raises(ValueError): + utils.validate_model_id_and_get_type( + "ai21-summarization", script=JumpStartScriptScope.TRAINING )