Skip to content

Commit 29718b4

Browse files
committed
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (aws#4438)
1 parent 2d638eb commit 29718b4

File tree

6 files changed

+80
-7
lines changed

6 files changed

+80
-7
lines changed

src/sagemaker/jumpstart/cache.py

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
DescribeHubContentsResponse,
5858
HubType,
5959
HubContentType,
60+
HubDataType,
6061
)
6162
from sagemaker.jumpstart.curated_hub import utils as hub_utils
6263
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -428,6 +429,7 @@ def _retrieval_function(
428429
"""
429430

430431
data_type, id_info = key.data_type, key.id_info
432+
431433
if data_type in {
432434
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
433435
JumpStartS3FileType.PROPRIETARY_MANIFEST,

src/sagemaker/jumpstart/constants.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,8 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175-
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
175+
# works cross-partition
176+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
176177
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
177178

178179
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"

src/sagemaker/jumpstart/types.py

+19
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,25 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
972972
"""
973973
# TODO: Implement
974974

975+
def to_json(self) -> Dict[str, Any]:
976+
"""Returns json representation of JumpStartModelSpecs object."""
977+
json_obj = {}
978+
for att in self.__slots__:
979+
if hasattr(self, att):
980+
cur_val = getattr(self, att)
981+
if issubclass(type(cur_val), JumpStartDataHolderType):
982+
json_obj[att] = cur_val.to_json()
983+
elif isinstance(cur_val, list):
984+
json_obj[att] = []
985+
for obj in cur_val:
986+
if issubclass(type(obj), JumpStartDataHolderType):
987+
json_obj[att].append(obj.to_json())
988+
else:
989+
json_obj[att].append(obj)
990+
else:
991+
json_obj[att] = cur_val
992+
return json_obj
993+
975994
def supports_prepacked_inference(self) -> bool:
976995
"""Returns True if the model has a prepacked inference artifact."""
977996
return getattr(self, "hosting_prepacked_artifact_key", None) is not None

src/sagemaker/jumpstart/utils.py

+23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import os
1717
from typing import Any, Dict, List, Set, Optional, Tuple, Union
18+
import re
1819
from urllib.parse import urlparse
1920
import boto3
2021
from packaging.version import Version
@@ -866,3 +867,25 @@ def get_jumpstart_model_id_version_from_resource_arn(
866867
def generate_studio_spec_file_prefix(model_id: str, model_version: str) -> str:
867868
"""Returns the Studio Spec file prefix given a model ID and version."""
868869
return f"studio_models/{model_id}/studio_specs_v{model_version}.json"
870+
871+
def extract_info_from_hub_content_arn(
872+
arn: str,
873+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
874+
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
875+
876+
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
877+
if match:
878+
hub_name = match.group(4)
879+
hub_region = match.group(2)
880+
content_name = match.group(5)
881+
content_version = match.group(6)
882+
883+
return hub_name, hub_region, content_name, content_version
884+
885+
match = re.match(constants.HUB_ARN_REGEX, arn)
886+
if match:
887+
hub_name = match.group(4)
888+
hub_region = match.group(2)
889+
return hub_name, hub_region, None, None
890+
891+
return None, None, None, None

tests/unit/sagemaker/jumpstart/test_utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,35 @@ def test_mime_type_enum_from_str():
12141214
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type
12151215

12161216

1217+
def test_extract_info_from_hub_content_arn():
1218+
model_arn = (
1219+
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
1220+
)
1221+
assert utils.extract_info_from_hub_content_arn(model_arn) == (
1222+
"MockHub",
1223+
"us-west-2",
1224+
"my-mock-model",
1225+
"1.0.2",
1226+
)
1227+
1228+
hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
1229+
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)
1230+
1231+
invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
1232+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1233+
1234+
invalid_arn = "nonsense-string"
1235+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1236+
1237+
invalid_arn = ""
1238+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1239+
1240+
invalid_arn = (
1241+
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
1242+
)
1243+
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)
1244+
1245+
12171246
class TestIsValidModelId(TestCase):
12181247
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
12191248
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")

tests/unit/sagemaker/jumpstart/utils.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
JUMPSTART_REGION_NAME_SET,
2323
)
2424
from sagemaker.jumpstart.types import (
25+
HubDataType,
2526
JumpStartCachedContentKey,
2627
JumpStartCachedContentValue,
2728
JumpStartModelSpecs,
@@ -218,12 +219,10 @@ def patched_retrieval_function(
218219
datatype, id_info = key.data_type, key.id_info
219220
if datatype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
220221

221-
return JumpStartCachedContentValue(
222-
formatted_content=get_formatted_manifest(BASE_MANIFEST)
223-
)
222+
return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST))
224223

225-
if datatype == JumpStartCachedContentValue.OPEN_WEIGHT_SPECS:
226-
_, model_id, specs_version = s3_key.split("/")
224+
if datatype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
225+
_, model_id, specs_version = id_info.split("/")
227226
version = specs_version.replace("specs_v", "").replace(".json", "")
228227
return JumpStartCachedContentValue(
229228
formatted_content=get_spec_from_base_spec(model_id=model_id, version=version)
@@ -245,7 +244,7 @@ def patched_retrieval_function(
245244
)
246245

247246
if datatype == JumpStartS3FileType.PROPRIETARY_SPECS:
248-
_, model_id, specs_version = s3_key.split("/")
247+
_, model_id, specs_version = id_info.split("/")
249248
version = specs_version.replace("proprietary_specs_", "").replace(".json", "")
250249
return JumpStartCachedContentValue(
251250
formatted_content=get_spec_from_base_spec(

0 commit comments

Comments
 (0)