Skip to content

Commit a44a755

Browse files
authored
feat: jumpstart model package arn instance type variants (#4186)
1 parent bfc63d2 commit a44a755

File tree

5 files changed

+190
-15
lines changed

5 files changed

+190
-15
lines changed

src/sagemaker/jumpstart/artifacts/model_packages.py

+14
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
def _retrieve_model_package_arn(
3030
model_id: str,
3131
model_version: str,
32+
instance_type: Optional[str],
3233
region: Optional[str],
3334
scope: Optional[str] = None,
3435
tolerate_vulnerable_model: bool = False,
@@ -42,6 +43,8 @@ def _retrieve_model_package_arn(
4243
retrieve the model package arn.
4344
model_version (str): Version of the JumpStart model for which to retrieve the
4445
model package arn.
46+
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
47+
specific for the instance type.
4548
region (Optional[str]): Region for which to retrieve the model package arn.
4649
scope (Optional[str]): Scope for which to retrieve the model package arn.
4750
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -75,6 +78,17 @@ def _retrieve_model_package_arn(
7578

7679
if scope == JumpStartScriptScope.INFERENCE:
7780

81+
instance_specific_arn: Optional[str] = (
82+
model_specs.hosting_instance_type_variants.get_model_package_arn(
83+
region=region, instance_type=instance_type
84+
)
85+
if getattr(model_specs, "hosting_instance_type_variants", None) is not None
86+
else None
87+
)
88+
89+
if instance_specific_arn is not None:
90+
return instance_specific_arn
91+
7892
if model_specs.hosting_model_package_arns is None:
7993
return None
8094

src/sagemaker/jumpstart/factory/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt
329329
model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn(
330330
model_id=kwargs.model_id,
331331
model_version=kwargs.model_version,
332+
instance_type=kwargs.instance_type,
332333
scope=JumpStartScriptScope.INFERENCE,
333334
region=kwargs.region,
334335
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,

src/sagemaker/jumpstart/types.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -436,42 +436,64 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
436436
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
437437
"""Returns image uri from instance type and region.
438438
439+
Returns None if no instance type is available or found.
440+
None is also returned if the metadata is improperly formatted.
441+
"""
442+
return self._get_regional_property(
443+
instance_type=instance_type, region=region, property_name="image_uri"
444+
)
445+
446+
def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str]:
447+
"""Returns model package arn from instance type and region.
448+
449+
Returns None if no instance type is available or found.
450+
None is also returned if the metadata is improperly formatted.
451+
"""
452+
return self._get_regional_property(
453+
instance_type=instance_type, region=region, property_name="model_package_arn"
454+
)
455+
456+
def _get_regional_property(
457+
self, instance_type: str, region: str, property_name: str
458+
) -> Optional[str]:
459+
"""Returns regional property from instance type and region.
460+
439461
Returns None if no instance type is available or found.
440462
None is also returned if the metadata is improperly formatted.
441463
"""
442464

443465
if None in [self.regional_aliases, self.variants]:
444466
return None
445467

446-
image_uri_alias: Optional[str] = (
447-
self.variants.get(instance_type, {}).get("regional_properties", {}).get("image_uri")
468+
regional_property_alias: Optional[str] = (
469+
self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name)
448470
)
449-
if image_uri_alias is None:
471+
if regional_property_alias is None:
450472
instance_type_family = get_instance_type_family(instance_type)
451473

452474
if instance_type_family in {"", None}:
453475
return None
454476

455-
image_uri_alias = (
477+
regional_property_alias = (
456478
self.variants.get(instance_type_family, {})
457479
.get("regional_properties", {})
458-
.get("image_uri")
480+
.get(property_name)
459481
)
460482

461-
if image_uri_alias is None or len(image_uri_alias) == 0:
483+
if regional_property_alias is None or len(regional_property_alias) == 0:
462484
return None
463485

464-
if not image_uri_alias.startswith("$"):
486+
if not regional_property_alias.startswith("$"):
465487
# No leading '$' indicates bad metadata.
466488
# There are tests to ensure this never happens.
467489
# However, to allow for fallback options in the unlikely event
468490
# of a regression, we do not raise an exception here.
469-
# We return None, indicating the image uri does not exist.
491+
# We return None, indicating the field does not exist.
470492
return None
471493

472494
if region not in self.regional_aliases:
473495
return None
474-
alias_value = self.regional_aliases[region].get(image_uri_alias[1:], None)
496+
alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None)
475497
return alias_value
476498

477499

tests/unit/sagemaker/jumpstart/constants.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@
181181
"min_sdk_version": "2.49.0",
182182
"training_supported": True,
183183
"incremental_training_supported": True,
184+
"hosting_model_package_arns": {
185+
"us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/ll"
186+
"ama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
187+
},
184188
"hosting_ecr_specs": {
185189
"framework": "pytorch",
186190
"framework_version": "1.5.0",
@@ -192,13 +196,35 @@
192196
"gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
193197
"huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04",
194198
"cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah",
199+
"inf_model_package_arn": "us-west-2/blah/blah/blah/inf",
200+
"gpu_model_package_arn": "us-west-2/blah/blah/blah/gpu",
195201
}
196202
},
197203
"variants": {
198-
"p2": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
199-
"p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
200-
"p4": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
201-
"g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}},
204+
"p2": {
205+
"regional_properties": {
206+
"image_uri": "$gpu_image_uri",
207+
"model_package_arn": "$gpu_model_package_arn",
208+
}
209+
},
210+
"p3": {
211+
"regional_properties": {
212+
"image_uri": "$gpu_image_uri",
213+
"model_package_arn": "$gpu_model_package_arn",
214+
}
215+
},
216+
"p4": {
217+
"regional_properties": {
218+
"image_uri": "$gpu_image_uri",
219+
"model_package_arn": "$gpu_model_package_arn",
220+
}
221+
},
222+
"g4dn": {
223+
"regional_properties": {
224+
"image_uri": "$gpu_image_uri",
225+
"model_package_arn": "$gpu_model_package_arn",
226+
}
227+
},
202228
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
203229
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
204230
"ml.g5.48xlarge": {
@@ -207,6 +233,8 @@
207233
"ml.g5.12xlarge": {
208234
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}}
209235
},
236+
"inf1": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
237+
"inf2": {"regional_properties": {"model_package_arn": "$inf_model_package_arn"}},
210238
},
211239
},
212240
"training_ecr_specs": {
@@ -224,7 +252,6 @@
224252
"training_model_package_artifact_uris": None,
225253
"deprecate_warn_message": None,
226254
"deprecated_message": None,
227-
"hosting_model_package_arns": None,
228255
"hosting_eula_key": None,
229256
"hyperparameters": [
230257
{

tests/unit/sagemaker/jumpstart/test_artifacts.py

+112-1
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import unittest
15+
from unittest.mock import Mock
1516

1617

1718
from mock.mock import patch
19+
import pytest
1820

1921
from sagemaker.jumpstart import artifacts
22+
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn
23+
from sagemaker.jumpstart.enums import JumpStartScriptScope
2024

21-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
25+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
26+
from tests.unit.sagemaker.workflow.conftest import mock_client
2227

2328

2429
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -129,3 +134,109 @@ def test_estimator_fit_kwargs(self, patched_get_model_specs):
129134
)
130135

131136
assert kwargs == {"some-estimator-fit-key": "some-estimator-fit-value"}
137+
138+
139+
class RetrieveModelPackageArnTest(unittest.TestCase):
140+
141+
mock_session = Mock(s3_client=mock_client)
142+
143+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
144+
def test_retrieve_model_package_arn(self, patched_get_model_specs):
145+
patched_get_model_specs.side_effect = get_special_model_spec
146+
147+
model_id = "variant-model"
148+
region = "us-west-2"
149+
150+
assert (
151+
_retrieve_model_package_arn(
152+
region=region,
153+
model_id=model_id,
154+
scope=JumpStartScriptScope.INFERENCE,
155+
model_version="*",
156+
sagemaker_session=self.mock_session,
157+
instance_type="ml.p2.48xlarge",
158+
)
159+
== "us-west-2/blah/blah/blah/gpu"
160+
)
161+
162+
assert (
163+
_retrieve_model_package_arn(
164+
region=region,
165+
model_id=model_id,
166+
scope=JumpStartScriptScope.INFERENCE,
167+
model_version="*",
168+
sagemaker_session=self.mock_session,
169+
instance_type="ml.p4.2xlarge",
170+
)
171+
== "us-west-2/blah/blah/blah/gpu"
172+
)
173+
174+
assert (
175+
_retrieve_model_package_arn(
176+
region=region,
177+
model_id=model_id,
178+
scope=JumpStartScriptScope.INFERENCE,
179+
model_version="*",
180+
sagemaker_session=self.mock_session,
181+
instance_type="ml.inf1.2xlarge",
182+
)
183+
== "us-west-2/blah/blah/blah/inf"
184+
)
185+
186+
assert (
187+
_retrieve_model_package_arn(
188+
region=region,
189+
model_id=model_id,
190+
scope=JumpStartScriptScope.INFERENCE,
191+
model_version="*",
192+
sagemaker_session=self.mock_session,
193+
instance_type="ml.inf2.12xlarge",
194+
)
195+
== "us-west-2/blah/blah/blah/inf"
196+
)
197+
198+
assert (
199+
_retrieve_model_package_arn(
200+
region=region,
201+
model_id=model_id,
202+
scope=JumpStartScriptScope.INFERENCE,
203+
model_version="*",
204+
sagemaker_session=self.mock_session,
205+
instance_type="ml.afasfasf.12xlarge",
206+
)
207+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
208+
)
209+
210+
assert (
211+
_retrieve_model_package_arn(
212+
region=region,
213+
model_id=model_id,
214+
scope=JumpStartScriptScope.INFERENCE,
215+
model_version="*",
216+
sagemaker_session=self.mock_session,
217+
instance_type="ml.m2.12xlarge",
218+
)
219+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
220+
)
221+
222+
assert (
223+
_retrieve_model_package_arn(
224+
region=region,
225+
model_id=model_id,
226+
scope=JumpStartScriptScope.INFERENCE,
227+
model_version="*",
228+
sagemaker_session=self.mock_session,
229+
instance_type="nobodycares",
230+
)
231+
== "arn:aws:sagemaker:us-west-2:594846645681:model-package/llama2-7b-v3-740347e540da35b4ab9f6fc0ab3fed2c"
232+
)
233+
234+
with pytest.raises(ValueError):
235+
_retrieve_model_package_arn(
236+
region="cn-north-1",
237+
model_id=model_id,
238+
scope=JumpStartScriptScope.INFERENCE,
239+
model_version="*",
240+
sagemaker_session=self.mock_session,
241+
instance_type="ml.p2.12xlarge",
242+
)

0 commit comments

Comments
 (0)