Skip to content

Commit ea5bb5a

Browse files
committed
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (aws#4438)
1 parent 3eb2dc4 commit ea5bb5a

File tree

6 files changed

+54
-2
lines changed

6 files changed

+54
-2
lines changed

src/sagemaker/jumpstart/cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
DescribeHubContentResponse,
5858
HubType,
5959
HubContentType,
60+
HubDataType,
6061
)
6162
from sagemaker.jumpstart.curated_hub import utils as hub_utils
6263
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -411,6 +412,7 @@ def _retrieval_function(
411412
"""
412413

413414
data_type, id_info = key.data_type, key.id_info
415+
414416
if data_type in {
415417
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
416418
JumpStartS3FileType.PROPRIETARY_MANIFEST,

src/sagemaker/jumpstart/constants.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
175+
# works cross-partition
176+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
176177
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
177178

178179
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/types.py

+19
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,25 @@ def from_describe_hub_content_response(self, response: DescribeHubContentRespons
15941594
else None
15951595
)
15961596

1597+
def to_json(self) -> Dict[str, Any]:
1598+
"""Returns json representation of JumpStartModelSpecs object."""
1599+
json_obj = {}
1600+
for att in self.__slots__:
1601+
if hasattr(self, att):
1602+
cur_val = getattr(self, att)
1603+
if issubclass(type(cur_val), JumpStartDataHolderType):
1604+
json_obj[att] = cur_val.to_json()
1605+
elif isinstance(cur_val, list):
1606+
json_obj[att] = []
1607+
for obj in cur_val:
1608+
if issubclass(type(obj), JumpStartDataHolderType):
1609+
json_obj[att].append(obj.to_json())
1610+
else:
1611+
json_obj[att].append(obj)
1612+
else:
1613+
json_obj[att] = cur_val
1614+
return json_obj
1615+
15971616
def supports_prepacked_inference(self) -> bool:
15981617
"""Returns True if the model has a prepacked inference artifact."""
15991618
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

src/sagemaker/jumpstart/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import re
1818
from typing import Any, Dict, List, Set, Optional, Tuple, Union
19+
import re
1920
from urllib.parse import urlparse
2021
import boto3
2122
from packaging.version import Version
@@ -863,7 +864,6 @@ def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
863864
"""Returns the Studio Spec file prefix given a model ID and version."""
864865
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"
865866

866-
867867
def extract_info_from_hub_content_arn(
868868
arn: str,
869869
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:

tests/unit/sagemaker/jumpstart/test_utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,35 @@ def test_mime_type_enum_from_str():
12061206
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12071207

12081208

1209+
def test_extract_info_from_hub_content_arn():
1210+
model_arn = (
1211+
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1212+
)
1213+
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1214+
"MockHub",
1215+
"us-west-2",
1216+
"my-mock-model",
1217+
"1.0.2",
1218+
)
1219+
1220+
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1221+
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1222+
1223+
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1224+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1225+
1226+
invalid_arn = "nonsense-string"
1227+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1228+
1229+
invalid_arn = ""
1230+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1231+
1232+
invalid_arn = (
1233+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1234+
)
1235+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236+
1237+
12091238
class TestIsValidModelId(TestCase):
12101239
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12111240
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25+
HubDataType,
2526
JumpStartCachedContentKey,
2627
JumpStartCachedContentValue,
2728
JumpStartModelSpecs,

0 commit comments

Comments
 (0)