Skip to content

feat: jumpstart model package arn instance type variants #4186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/sagemaker/jumpstart/artifacts/model_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
def _retrieve_model_package_arn(
model_id: str,
model_version: str,
instance_type: Optional[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to do the same for _retrieve_model_package_model_artifact_s3_uri()?

region: Optional[str],
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
Expand All @@ -42,6 +43,8 @@ def _retrieve_model_package_arn(
retrieve the model package arn.
model_version (str): Version of the JumpStart model for which to retrieve the
model package arn.
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
specific for the instance type.
region (Optional[str]): Region for which to retrieve the model package arn.
scope (Optional[str]): Scope for which to retrieve the model package arn.
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -75,6 +78,17 @@ def _retrieve_model_package_arn(

if scope == JumpStartScriptScope.INFERENCE:

instance_specific_arn: Optional[str] = (
model_specs.hosting_instance_type_variants.get_model_package_arn(
region=region, instance_type=instance_type
)
if getattr(model_specs, "hosting_instance_type_variants", None) is not None
else None
)

if instance_specific_arn is not None:
return instance_specific_arn

if model_specs.hosting_model_package_arns is None:
return None

Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
model_id=kwargs.model_id,
model_version=kwargs.model_version,
instance_type=kwargs.instance_type,
scope=JumpStartScriptScope.INFERENCE,
region=kwargs.region,
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
Expand Down
40 changes: 31 additions & 9 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,42 +436,64 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
"""Returns image uri from instance type and region.

Returns None if no instance type is available or found.
None is also returned if the metadata is improperly formatted.
"""
return self._get_regional_property(
instance_type=instance_type, region=region, property_name="image_uri"
)

def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]:
"""Returns model package arn from instance type and region.

Returns None if no instance type is available or found.
None is also returned if the metadata is improperly formatted.
"""
return self._get_regional_property(
instance_type=instance_type, region=region, property_name="model_package_arn"
)

def _get_regional_property(
self, instance_type: str, region: str, property_name: str
) -> Optional[str]:
"""Returns regional property from instance type and region.

Returns None if no instance type is available or found.
None is also returned if the metadata is improperly formatted.
"""

if None in [self.regional_aliases, self.variants]:
return None

image_uri_alias: Optional[str] = (
self.variants.get(instance_type, {}).get("regional_properties", {}).get("image_uri")
regional_property_alias: Optional[str] = (
self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name)
)
if image_uri_alias is None:
if regional_property_alias is None:
instance_type_family = get_instance_type_family(instance_type)

if instance_type_family in {"", None}:
return None

image_uri_alias = (
regional_property_alias = (
self.variants.get(instance_type_family, {})
.get("regional_properties", {})
.get("image_uri")
.get(property_name)
)

if image_uri_alias is None or len(image_uri_alias) == 0:
if regional_property_alias is None or len(regional_property_alias) == 0:
return None

if not image_uri_alias.startswith("$"):
if not regional_property_alias.startswith("$"):
# No leading '$' indicates bad metadata.
# There are tests to ensure this never happens.
# However, to allow for fallback options in the unlikely event
# of a regression, we do not raise an exception here.
# We return None, indicating the image uri does not exist.
# We return None, indicating the field does not exist.
return None

if region not in self.regional_aliases:
return None
alias_value = self.regional_aliases[region].get(image_uri_alias[1:], None)
alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None)
return alias_value


Expand Down
37 changes: 32 additions & 5 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@
"min_sdk_version": "2.49.0",
"training_supported": True,
"incremental_training_supported": True,
"hosting_model_package_arns": {
"us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll"
"ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
},
"hosting_ecr_specs": {
"framework": "pytorch",
"framework_version": "1.5.0",
Expand All @@ -192,13 +196,35 @@
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
}
},
"variants": {
"p2": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
"p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
"p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
"g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
"p2": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"p3": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"p4": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"g4dn": {
"regional_properties": {
"image_uri": "$gpu_image_uri",
"model_package_arn": "$gpu_model_package_arn",
}
},
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"ml.g5.48xlarge": {
Expand All @@ -207,6 +233,8 @@
"ml.g5.12xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
},
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
},
},
"training_ecr_specs": {
Expand All @@ -224,7 +252,6 @@
"training_model_package_artifact_uris": None,
"deprecate_warn_message": None,
"deprecated_message": None,
"hosting_model_package_arns": None,
"hosting_eula_key": None,
"hyperparameters": [
{
Expand Down
113 changes: 112 additions & 1 deletion tests/unit/sagemaker/jumpstart/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import unittest
from unittest.mock import Mock


from mock.mock import patch
import pytest

from sagemaker.jumpstart import artifacts
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn
from sagemaker.jumpstart.enums import JumpStartScriptScope

from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
from tests.unit.sagemaker.workflow.conftest import mock_client


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
Expand Down Expand Up @@ -129,3 +134,109 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs):
)

assert kwargs == {"some-estimator-fit-key": "some-estimator-fit-value"}


class RetrieveModelPackageArnTest(unittest.TestCase):

mock_session = Mock(s3_client=mock_client)

@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_retrieve_model_package_arn(self, patched_get_model_specs):
patched_get_model_specs.side_effect = get_special_model_spec

model_id = "variant-model"
region = "us-west-2"

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.p2.48xlarge",
)
== "us-west-2/blah/blah/blah/gpu"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.p4.2xlarge",
)
== "us-west-2/blah/blah/blah/gpu"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.inf1.2xlarge",
)
== "us-west-2/blah/blah/blah/inf"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.inf2.12xlarge",
)
== "us-west-2/blah/blah/blah/inf"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.afasfasf.12xlarge",
)
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.m2.12xlarge",
)
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
)

assert (
_retrieve_model_package_arn(
region=region,
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="nobodycares",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's tough :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait till you receive the bill :)

)
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
)

with pytest.raises(ValueError):
_retrieve_model_package_arn(
region="cn-north-1",
model_id=model_id,
scope=JumpStartScriptScope.INFERENCE,
model_version="*",
sagemaker_session=self.mock_session,
instance_type="ml.p2.12xlarge",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just confirming: we are going to release this in the SDK and bump the SE version before 10/18 to make the last Studio cutoff before re:Invent, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so.