From 4c0453883b72276f73e9bb63b9088ee6c52c04cd Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 12 Oct 2023 15:40:07 +0000 Subject: [PATCH] feat: jumpstart model package arn instance type variants --- .../jumpstart/artifacts/model_packages.py | 14 +++ src/sagemaker/jumpstart/factory/model.py | 1 + src/sagemaker/jumpstart/types.py | 40 +++++-- tests/unit/sagemaker/jumpstart/constants.py | 37 +++++- .../sagemaker/jumpstart/test_artifacts.py | 113 +++++++++++++++++- 5 files changed, 190 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 540b16bb51..bd0ae365d9 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -29,6 +29,7 @@ def _retrieve_model_package_arn( model_id: str, model_version: str, + instance_type: Optional[str], region: Optional[str], scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -42,6 +43,8 @@ def _retrieve_model_package_arn( retrieve the model package arn. model_version (str): Version of the JumpStart model for which to retrieve the model package arn. + instance_type (Optional[str]): An instance type to optionally supply in order to get an arn + specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -75,6 +78,17 @@ def _retrieve_model_package_arn( if scope == JumpStartScriptScope.INFERENCE: + instance_specific_arn: Optional[str] = ( + model_specs.hosting_instance_type_variants.get_model_package_arn( + region=region, instance_type=instance_type + ) + if getattr(model_specs, "hosting_instance_type_variants", None) is not None + else None + ) + + if instance_specific_arn is not None: + return instance_specific_arn + if model_specs.hosting_model_package_arns is None: return None diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 8b28059f7c..bfb051d15a 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -329,6 +329,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( model_id=kwargs.model_id, model_version=kwargs.model_version, + instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index a4a7617aac..96e2cbb5f5 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -436,6 +436,28 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: """Returns image uri from instance type and region. + Returns None if no instance type is available or found. + None is also returned if the metadata is improperly formatted. + """ + return self._get_regional_property( + instance_type=instance_type, region=region, property_name="image_uri" + ) + + def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]: + """Returns model package arn from instance type and region. + + Returns None if no instance type is available or found. + None is also returned if the metadata is improperly formatted. + """ + return self._get_regional_property( + instance_type=instance_type, region=region, property_name="model_package_arn" + ) + + def _get_regional_property( + self, instance_type: str, region: str, property_name: str + ) -> Optional[str]: + """Returns regional property from instance type and region. + Returns None if no instance type is available or found. None is also returned if the metadata is improperly formatted. """ @@ -443,35 +465,35 @@ def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: if None in [self.regional_aliases, self.variants]: return None - image_uri_alias: Optional[str] = ( - self.variants.get(instance_type, {}).get("regional_properties", {}).get("image_uri") + regional_property_alias: Optional[str] = ( + self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name) ) - if image_uri_alias is None: + if regional_property_alias is None: instance_type_family = get_instance_type_family(instance_type) if instance_type_family in {"", None}: return None - image_uri_alias = ( + regional_property_alias = ( self.variants.get(instance_type_family, {}) .get("regional_properties", {}) - .get("image_uri") + .get(property_name) ) - if image_uri_alias is None or len(image_uri_alias) == 0: + if regional_property_alias is None or len(regional_property_alias) == 0: return None - if not image_uri_alias.startswith("$"): + if not regional_property_alias.startswith("$"): # No leading '$' indicates bad metadata. # There are tests to ensure this never happens. # However, to allow for fallback options in the unlikely event # of a regression, we do not raise an exception here. - # We return None, indicating the image uri does not exist. + # We return None, indicating the field does not exist. return None if region not in self.regional_aliases: return None - alias_value = self.regional_aliases[region].get(image_uri_alias[1:], None) + alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) return alias_value diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index b65167165c..a321b4b9f8 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -181,6 +181,10 @@ "min_sdk_version": "2.49.0", "training_supported": True, "incremental_training_supported": True, + "hosting_model_package_arns": { + "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll" + "ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + }, "hosting_ecr_specs": { "framework": "pytorch", "framework_version": "1.5.0", @@ -192,13 +196,35 @@ "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + "inf_model_package_arn": "us-west-2/blah/blah/blah/inf", + "gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu", } }, "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "p2": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p3": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "p4": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, + "g4dn": { + "regional_properties": { + "image_uri": "$gpu_image_uri", + "model_package_arn": "$gpu_model_package_arn", + } + }, "m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, "ml.g5.48xlarge": { @@ -207,6 +233,8 @@ "ml.g5.12xlarge": { "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}} }, + "inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, + "inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}}, }, }, "training_ecr_specs": { @@ -224,7 +252,6 @@ "training_model_package_artifact_uris": None, "deprecate_warn_message": None, "deprecated_message": None, - "hosting_model_package_arns": None, "hosting_eula_key": None, "hyperparameters": [ { diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 815db88f98..497691993c 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -12,13 +12,18 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import unittest +from unittest.mock import Mock from mock.mock import patch +import pytest from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn +from sagemaker.jumpstart.enums import JumpStartScriptScope -from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec +from tests.unit.sagemaker.workflow.conftest import mock_client @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -129,3 +134,109 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs): ) assert kwargs == {"some-estimator-fit-key": "some-estimator-fit-value"} + + +class RetrieveModelPackageArnTest(unittest.TestCase): + + mock_session = Mock(s3_client=mock_client) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + def test_retrieve_model_package_arn(self, patched_get_model_specs): + patched_get_model_specs.side_effect = get_special_model_spec + + model_id = "variant-model" + region = "us-west-2" + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.p2.48xlarge", + ) + == "us-west-2/blah/blah/blah/gpu" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.p4.2xlarge", + ) + == "us-west-2/blah/blah/blah/gpu" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.inf1.2xlarge", + ) + == "us-west-2/blah/blah/blah/inf" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.inf2.12xlarge", + ) + == "us-west-2/blah/blah/blah/inf" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.afasfasf.12xlarge", + ) + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.m2.12xlarge", + ) + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + ) + + assert ( + _retrieve_model_package_arn( + region=region, + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="nobodycares", + ) + == "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c" + ) + + with pytest.raises(ValueError): + _retrieve_model_package_arn( + region="cn-north-1", + model_id=model_id, + scope=JumpStartScriptScope.INFERENCE, + model_version="*", + sagemaker_session=self.mock_session, + instance_type="ml.p2.12xlarge", + )