Skip to content

Commit 700c16d

Browse files
committed
feat: add hub and hubcontent support in retrieval function for jumpstart model cache (aws#4438)
1 parent 9e8c622 commit 700c16d

File tree

11 files changed

+247
-58
lines changed

11 files changed

+247
-58
lines changed

src/sagemaker/jumpstart/cache.py

+69-23
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import datetime
1616
from difflib import get_close_matches
1717
import os
18-
from typing import List, Optional, Tuple, Union
18+
from typing import Any, Dict, List, Optional, Tuple, Union
1919
import json
2020
import boto3
2121
import botocore
@@ -43,12 +43,13 @@
4343
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
4444
)
4545
from sagemaker.jumpstart.types import (
46-
JumpStartCachedS3ContentKey,
47-
JumpStartCachedS3ContentValue,
46+
JumpStartCachedContentKey,
47+
JumpStartCachedContentValue,
4848
JumpStartModelHeader,
4949
JumpStartModelSpecs,
5050
JumpStartS3FileType,
5151
JumpStartVersionedModelId,
52+
HubDataType,
5253
)
5354
from sagemaker.jumpstart.enums import JumpStartModelType
5455
from sagemaker.jumpstart import utils
@@ -103,7 +104,7 @@ def __init__(
103104
"""
104105

105106
self._region = region
106-
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
107+
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
107108
max_cache_items=max_s3_cache_items,
108109
expiration_horizon=s3_cache_expiration_horizon,
109110
retrieval_function=self._retrieval_function,
@@ -234,7 +235,7 @@ def _model_id_retrieval_function(
234235
model_id, version = key.model_id, key.version
235236
sm_version = utils.get_sagemaker_version()
236237
manifest = self._s3_cache.get(
237-
JumpStartCachedS3ContentKey(
238+
JumpStartCachedContentKey(
238239
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
239240
)[0].formatted_content
240241

@@ -399,46 +400,69 @@ def _get_json_file_from_local_override(
399400

400401
def _retrieval_function(
401402
self,
402-
key: JumpStartCachedS3ContentKey,
403-
value: Optional[JumpStartCachedS3ContentValue],
404-
) -> JumpStartCachedS3ContentValue:
405-
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
403+
key: JumpStartCachedContentKey,
404+
value: Optional[JumpStartCachedContentValue],
405+
) -> JumpStartCachedContentValue:
406+
"""Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey``.
406407
407408
If a manifest file is being fetched, we only download the object if the md5 hash in
408409
``head_object`` does not match the current md5 hash for the stored value. This prevents
409410
unnecessarily downloading the full manifest when it hasn't changed.
410411
411412
Args:
412-
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
413+
key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
413414
value (Optional[JumpStartVersionedModelId]): Current value of old cached
414415
s3 content. This is used for the manifest file, so that it is only
415416
downloaded when its content changes.
416417
"""
417418

418-
file_type, s3_key = key.file_type, key.s3_key
419-
if file_type in {
419+
data_type, id_info = key.data_type, key.id_info
420+
421+
if data_type in {
420422
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
421423
JumpStartS3FileType.PROPRIETARY_MANIFEST,
422424
}:
423425
if value is not None and not self._is_local_metadata_mode():
424-
etag = self._get_json_md5_hash(s3_key)
426+
etag = self._get_json_md5_hash(id_info)
425427
if etag == value.md5_hash:
426428
return value
427-
formatted_body, etag = self._get_json_file(s3_key, file_type)
428-
return JumpStartCachedS3ContentValue(
429+
formatted_body, etag = self._get_json_file(id_info, data_type)
430+
return JumpStartCachedContentValue(
429431
formatted_content=utils.get_formatted_manifest(formatted_body),
430432
md5_hash=etag,
431433
)
432-
if file_type in {
434+
if data_type in {
433435
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
434436
JumpStartS3FileType.PROPRIETARY_SPECS,
435437
}:
436-
formatted_body, _ = self._get_json_file(s3_key, file_type)
438+
formatted_body, _ = self._get_json_file(id_info, data_type)
437439
model_specs = JumpStartModelSpecs(formatted_body)
438440
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
439-
return JumpStartCachedS3ContentValue(formatted_content=model_specs)
441+
return JumpStartCachedContentValue(formatted_content=model_specs)
442+
443+
if data_type == HubDataType.MODEL:
444+
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
445+
id_info
446+
)
447+
hub = CuratedHub(hub_name=hub_name, region=region)
448+
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
449+
utils.emit_logs_based_on_model_specs(
450+
hub_content.content_document,
451+
self.get_region(),
452+
self._s3_client
453+
)
454+
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
455+
return JumpStartCachedContentValue(
456+
formatted_content=model_specs
457+
)
458+
if data_type == HubDataType.HUB:
459+
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
460+
hub = CuratedHub(hub_name=hub_name, region=region)
461+
hub_info = hub.describe()
462+
return JumpStartCachedContentValue(formatted_content=hub_info)
463+
440464
raise ValueError(
441-
self._file_type_error_msg(file_type)
465+
self._file_type_error_msg(data_type)
442466
)
443467

444468
def get_manifest(
@@ -447,7 +471,7 @@ def get_manifest(
447471
) -> List[JumpStartModelHeader]:
448472
"""Return entire JumpStart models manifest."""
449473
manifest_dict = self._s3_cache.get(
450-
JumpStartCachedS3ContentKey(
474+
JumpStartCachedContentKey(
451475
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
452476
)[0].formatted_content
453477
manifest = list(manifest_dict.values()) # type: ignore
@@ -536,7 +560,7 @@ def _get_header_impl(
536560
)[0]
537561

538562
manifest = self._s3_cache.get(
539-
JumpStartCachedS3ContentKey(
563+
JumpStartCachedContentKey(
540564
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
541565
)[0].formatted_content
542566

@@ -566,7 +590,7 @@ def get_specs(
566590
header = self.get_header(model_id, version_str, model_type)
567591
spec_key = header.spec_key
568592
specs, cache_hit = self._s3_cache.get(
569-
JumpStartCachedS3ContentKey(
593+
JumpStartCachedContentKey(
570594
MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key
571595
)
572596
)
@@ -579,8 +603,30 @@ def get_specs(
579603
)
580604
return specs.formatted_content
581605

606+
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
607+
"""Return JumpStart-compatible specs for a given Hub model
608+
609+
Args:
610+
hub_model_arn (str): Arn for the Hub model to get specs for
611+
"""
612+
613+
details, _ = self._content_cache.get(
614+
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
615+
)
616+
return details.formatted_content
617+
618+
def get_hub(self, hub_arn: str) -> Dict[str, Any]:
619+
"""Return descriptive info for a given Hub
620+
621+
Args:
622+
hub_arn (str): Arn for the Hub to get info for
623+
"""
624+
625+
details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
626+
return details.formatted_content
627+
582628
def clear(self) -> None:
583629
"""Clears the model ID/version and s3 cache."""
584-
self._s3_cache.clear()
630+
self._content_cache.clear()
585631
self._open_weight_model_id_manifest_key_cache.clear()
586632
self._proprietary_model_id_manifest_key_cache.clear()

src/sagemaker/jumpstart/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@
172172
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
173173
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
174174

175+
# works cross-partition
176+
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
177+
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"
178+
175179
INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
176180
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"
177181

src/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module provides the JumpStart Curated Hub class."""
14+
from __future__ import absolute_import
15+
16+
from typing import Optional, Dict, Any
17+
18+
from sagemaker.session import Session
19+
20+
21+
class CuratedHub:
22+
"""Class for creating and managing a curated JumpStart hub"""
23+
24+
def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
25+
self.hub_name = hub_name
26+
self.region = region
27+
self.session = session
28+
self._sm_session = session or Session()
29+
30+
def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
31+
"""Returns descriptive information about the Hub Model"""
32+
33+
hub_content = self._sm_session.describe_hub_content(
34+
model_name, "Model", self.hub_name, model_version
35+
)
36+
37+
# TODO: Parse HubContent
38+
# TODO: Parse HubContentDocument
39+
40+
return hub_content
41+
42+
def describe(self) -> Dict[str, Any]:
43+
"""Returns descriptive information about the Hub"""
44+
45+
hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)
46+
47+
# TODO: Validations?
48+
49+
return hub_info

src/sagemaker/jumpstart/types.py

+37-14
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,17 @@ class JumpStartS3FileType(str, Enum):
109109
PROPRIETARY_SPECS = "proprietary_specs"
110110

111111

112+
class HubDataType(str, Enum):
113+
"""Enum for Hub data storage objects."""
114+
115+
HUB = "hub"
116+
MODEL = "model"
117+
NOTEBOOK = "notebook"
118+
119+
120+
JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType]
121+
122+
112123
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
113124
"""Data class for launched region info."""
114125

@@ -794,13 +805,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
794805
"model_subscription_link",
795806
]
796807

797-
def __init__(self, spec: Dict[str, Any]):
808+
def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False):
798809
"""Initializes a JumpStartModelSpecs object from its json representation.
799810
800811
Args:
801812
spec (Dict[str, Any]): Dictionary representation of spec.
802813
"""
803-
self.from_json(spec)
814+
if is_hub_content:
815+
self.from_hub_content_doc(spec)
816+
else:
817+
self.from_json(spec)
804818

805819
def from_json(self, json_obj: Dict[str, Any]) -> None:
806820
"""Sets fields in object based on json of header.
@@ -925,6 +939,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
925939
)
926940
self.model_subscription_link = json_obj.get("model_subscription_link")
927941

942+
def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
943+
"""Sets fields in object based on values in HubContentDocument
944+
945+
Args:
946+
hub_content_doc (Dict[str, any]): parsed HubContentDocument returned
947+
from SageMaker:DescribeHubContent
948+
"""
949+
# TODO: Implement
950+
928951
def to_json(self) -> Dict[str, Any]:
929952
"""Returns json representation of JumpStartModelSpecs object."""
930953
json_obj = {}
@@ -992,27 +1015,27 @@ def __init__(
9921015
self.version = version
9931016

9941017

995-
class JumpStartCachedS3ContentKey(JumpStartDataHolderType):
996-
"""Data class for the s3 cached content keys."""
1018+
class JumpStartCachedContentKey(JumpStartDataHolderType):
1019+
"""Data class for the cached content keys."""
9971020

998-
__slots__ = ["file_type", "s3_key"]
1021+
__slots__ = ["data_type", "id_info"]
9991022

10001023
def __init__(
10011024
self,
1002-
file_type: JumpStartS3FileType,
1003-
s3_key: str,
1025+
data_type: JumpStartContentDataType,
1026+
id_info: str,
10041027
) -> None:
1005-
"""Instantiates JumpStartCachedS3ContentKey object.
1028+
"""Instantiates JumpStartCachedContentKey object.
10061029
10071030
Args:
1008-
file_type (JumpStartS3FileType): JumpStart file type.
1009-
s3_key (str): object key in s3.
1031+
data_type (JumpStartContentDataType): JumpStart content data type.
1032+
id_info (str): if S3Content, object key in s3. if HubContent, hub content arn.
10101033
"""
1011-
self.file_type = file_type
1012-
self.s3_key = s3_key
1034+
self.data_type = data_type
1035+
self.id_info = id_info
10131036

10141037

1015-
class JumpStartCachedS3ContentValue(JumpStartDataHolderType):
1038+
class JumpStartCachedContentValue(JumpStartDataHolderType):
10161039
"""Data class for the s3 cached content values."""
10171040

10181041
__slots__ = ["formatted_content", "md5_hash"]
@@ -1025,7 +1048,7 @@ def __init__(
10251048
],
10261049
md5_hash: Optional[str] = None,
10271050
) -> None:
1028-
"""Instantiates JumpStartCachedS3ContentValue object.
1051+
"""Instantiates JumpStartCachedContentValue object.
10291052
10301053
Args:
10311054
formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],

src/sagemaker/jumpstart/utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import os
1717
from typing import Any, Dict, List, Set, Optional, Tuple, Union
18+
import re
1819
from urllib.parse import urlparse
1920
import boto3
2021
from packaging.version import Version
@@ -842,3 +843,26 @@ def get_jumpstart_model_id_version_from_resource_arn(
842843
model_version = model_version_from_tag
843844

844845
return model_id, model_version
846+
847+
848+
def extract_info_from_hub_content_arn(
849+
arn: str,
850+
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
851+
"""Extracts hub_name, content_name, and content_version from a HubContentArn"""
852+
853+
match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
854+
if match:
855+
hub_name = match.group(4)
856+
hub_region = match.group(2)
857+
content_name = match.group(5)
858+
content_version = match.group(6)
859+
860+
return hub_name, hub_region, content_name, content_version
861+
862+
match = re.match(constants.HUB_ARN_REGEX, arn)
863+
if match:
864+
hub_name = match.group(4)
865+
hub_region = match.group(2)
866+
return hub_name, hub_region, None, None
867+
868+
return None, None, None, None

tests/unit/sagemaker/jumpstart/curated_hub/__init__.py

Whitespace-only changes.

tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py

Whitespace-only changes.

0 commit comments

Comments
 (0)