Skip to content

add hub and hubcontent support in retrieval function for jumpstart model cache #4438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 72 additions & 27 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
from difflib import get_close_matches
import os
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import boto3
import botocore
Expand All @@ -29,6 +29,7 @@
JUMPSTART_LOGGER,
MODEL_ID_LIST_WEB_URL,
)
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
Expand All @@ -37,12 +38,13 @@
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
)
from sagemaker.jumpstart.types import (
JumpStartCachedS3ContentKey,
JumpStartCachedS3ContentValue,
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
HubDataType,
)
from sagemaker.jumpstart import utils
from sagemaker.utilities.cache import LRUCache
Expand Down Expand Up @@ -95,7 +97,7 @@ def __init__(
"""

self._region = region
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
max_cache_items=max_s3_cache_items,
expiration_horizon=s3_cache_expiration_horizon,
retrieval_function=self._retrieval_function,
Expand Down Expand Up @@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(

model_id, version = key.model_id, key.version

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content

sm_version = utils.get_sagemaker_version()
Expand Down Expand Up @@ -301,50 +303,71 @@ def _get_json_file_from_local_override(

def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Return JumpStart content for the data type and identifier in ``key``.

    ``key.id_info`` is an s3 key for ``JumpStartS3FileType`` entries and a hub or
    hub content arn for ``HubDataType`` entries.

    If a manifest file is being fetched, we only download the object if the md5 hash in
    ``head_object`` does not match the current md5 hash for the stored value. This prevents
    unnecessarily downloading the full manifest when it hasn't changed.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
        value (Optional[JumpStartCachedContentValue]): Current value of old cached
            content. This is used for the manifest file, so that it is only
            downloaded when its content changes.

    Returns:
        JumpStartCachedContentValue: the freshly retrieved content, or ``value`` itself
            when the manifest is unchanged.

    Raises:
        ValueError: if ``key.data_type`` is not one of the supported data types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type == JumpStartS3FileType.MANIFEST:
        # Only compare hashes when there is a previous value and we are not reading
        # from a local metadata override.
        if value is not None and not self._is_local_metadata_mode():
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type == JumpStartS3FileType.SPECS:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubDataType.MODEL:
        hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
            id_info
        )
        # NOTE(review): a reviewer on this change questioned instantiating CuratedHub
        # inside the cache module and suggested reaching describe_model through a
        # shared utility instead.
        hub = CuratedHub(hub_name=hub_name, region=region)
        hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
        # NOTE(review): describe_model is annotated as returning Dict[str, Any], yet
        # ``content_document`` is read as an attribute here -- confirm the return type
        # once HubContent parsing is implemented.
        model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
        # Log from the parsed specs object, exactly as the SPECS branch does, rather
        # than from the raw content document.
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubDataType.HUB:
        hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
        hub = CuratedHub(hub_name=hub_name, region=region)
        hub_info = hub.describe()
        return JumpStartCachedContentValue(formatted_content=hub_info)

    supported_types = [
        JumpStartS3FileType.MANIFEST,
        JumpStartS3FileType.SPECS,
        HubDataType.HUB,
        HubDataType.MODEL,
    ]
    # Adjacent string literals (no comma between them) so ValueError receives one
    # message instead of a two-element args tuple.
    raise ValueError(
        f"Bad value for key '{key}': must be in "
        f"{supported_types}"
    )

def get_manifest(self) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest."""

manifest_dict = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest_dict = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content
manifest = list(manifest_dict.values()) # type: ignore
return manifest
Expand Down Expand Up @@ -407,8 +430,8 @@ def _get_header_impl(
JumpStartVersionedModelId(model_id, semantic_version_str)
)[0]

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content
try:
header = manifest[versioned_model_id] # type: ignore
Expand All @@ -430,8 +453,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS

header = self.get_header(model_id, semantic_version_str)
spec_key = header.spec_key
specs, cache_hit = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
specs, cache_hit = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key)
)
if not cache_hit and "*" in semantic_version_str:
JUMPSTART_LOGGER.warning(
Expand All @@ -443,7 +466,29 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
)
return specs.formatted_content

def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Return JumpStart-compatible specs for a given Hub model.

    The lookup goes through the content cache, keyed by ``HubDataType.MODEL``
    and the model arn.

    Args:
        hub_model_arn (str): Arn for the Hub model to get specs for.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
    # The second element returned by the cache lookup is not needed here.
    cached_entry, _ = self._content_cache.get(cache_key)
    return cached_entry.formatted_content

def get_hub(self, hub_arn: str) -> Dict[str, Any]:
    """Return descriptive info for a given Hub.

    The lookup goes through the content cache, keyed by ``HubDataType.HUB``
    and the hub arn.

    Args:
        hub_arn (str): Arn for the Hub to get info for.
    """

    cache_key = JumpStartCachedContentKey(HubDataType.HUB, hub_arn)
    # The second element returned by the cache lookup is not needed here.
    cached_entry, _ = self._content_cache.get(cache_key)
    return cached_entry.formatted_content

def clear(self) -> None:
"""Clears the model ID/version and s3 cache."""
self._s3_cache.clear()
self._content_cache.clear()
self._model_id_semantic_version_manifest_key_cache.clear()
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

# Works cross-partition: the first capture group absorbs the partition.
# SageMaker spells the resource type "hub-content" (hyphen) in hub content ARNs;
# the underscore spelling is still accepted so existing inputs keep matching.
# Groups: 1=partition, 2=region, 3=account, 4=hub name, 5=model name, 6=model version.
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub[-_]content/(.*?)/Model/(.*?)/(.*?)$"
# Groups: 1=partition, 2=region, 3=account, 4=hub name.
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"

Expand Down
Empty file.
49 changes: 49 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import

from typing import Optional, Dict, Any

from sagemaker.session import Session


class CuratedHub:
    """Class for creating and managing a curated JumpStart hub"""

    def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
        """Store the hub identity and pick the session used for API calls.

        Args:
            hub_name (str): Name of the hub.
            region (str): Region of the hub. It is only stored on the instance;
                it is not used when the default ``Session`` is created.
            session (Optional[Session]): Session to use for SageMaker calls.
                A default ``Session()`` is created when none is given.
        """
        self.hub_name = hub_name
        self.region = region
        # ``session`` keeps the caller's argument as-is (possibly None);
        # ``_sm_session`` is the object actually used for requests.
        self.session = session
        self._sm_session = session or Session()

    def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
        """Returns descriptive information about the Hub Model"""

        # TODO: Parse HubContent
        # TODO: Parse HubContentDocument

        return self._sm_session.describe_hub_content(
            model_name, "Model", self.hub_name, model_version
        )

    def describe(self) -> Dict[str, Any]:
        """Returns descriptive information about the Hub"""

        # TODO: Validations?

        return self._sm_session.describe_hub(hub_name=self.hub_name)
51 changes: 37 additions & 14 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ class JumpStartS3FileType(str, Enum):
SPECS = "specs"


class HubDataType(str, Enum):
    """Enum for Hub data storage objects.

    Members also subclass ``str``, so each one compares equal to its plain
    string value.
    """

    # Descriptive info about a Hub itself.
    HUB = "hub"
    # A Model entry stored in a Hub.
    MODEL = "model"
    # NOTE(review): not referenced by the visible changes -- presumably a
    # Notebook entry stored in a Hub; confirm against callers.
    NOTEBOOK = "notebook"


JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType]


class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
"""Data class for launched region info."""

Expand Down Expand Up @@ -767,13 +778,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
"gated_bucket",
]

def __init__(self, spec: Dict[str, Any]):
def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False):
"""Initializes a JumpStartModelSpecs object from its json representation.

Args:
spec (Dict[str, Any]): Dictionary representation of spec.
"""
self.from_json(spec)
if is_hub_content:
self.from_hub_content_doc(spec)
else:
self.from_json(spec)

def from_json(self, json_obj: Dict[str, Any]) -> None:
"""Sets fields in object based on json of header.
Expand Down Expand Up @@ -895,6 +909,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
else None
)

def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None:
    """Sets fields in object based on values in HubContentDocument

    This is currently a placeholder: the body is not implemented, so calling
    it sets no fields on the object.

    Args:
        hub_content_doc (Dict[str, Any]): parsed HubContentDocument returned
            from SageMaker:DescribeHubContent
    """
    # TODO: Implement

def to_json(self) -> Dict[str, Any]:
"""Returns json representation of JumpStartModelSpecs object."""
json_obj = {}
Expand Down Expand Up @@ -958,27 +981,27 @@ def __init__(
self.version = version


class JumpStartCachedS3ContentKey(JumpStartDataHolderType):
"""Data class for the s3 cached content keys."""
class JumpStartCachedContentKey(JumpStartDataHolderType):
"""Data class for the cached content keys."""

__slots__ = ["file_type", "s3_key"]
__slots__ = ["data_type", "id_info"]

def __init__(
self,
file_type: JumpStartS3FileType,
s3_key: str,
data_type: JumpStartContentDataType,
id_info: str,
) -> None:
"""Instantiates JumpStartCachedS3ContentKey object.
"""Instantiates JumpStartCachedContentKey object.

Args:
file_type (JumpStartS3FileType): JumpStart file type.
s3_key (str): object key in s3.
data_type (JumpStartContentDataType): JumpStart content data type.
id_info (str): if S3Content, object key in s3. if HubContent, hub content arn.
"""
self.file_type = file_type
self.s3_key = s3_key
self.data_type = data_type
self.id_info = id_info


class JumpStartCachedS3ContentValue(JumpStartDataHolderType):
class JumpStartCachedContentValue(JumpStartDataHolderType):
"""Data class for the s3 cached content values."""

__slots__ = ["formatted_content", "md5_hash"]
Expand All @@ -991,7 +1014,7 @@ def __init__(
],
md5_hash: Optional[str] = None,
) -> None:
"""Instantiates JumpStartCachedS3ContentValue object.
"""Instantiates JumpStartCachedContentValue object.

Args:
formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader],
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
Expand Down Expand Up @@ -810,3 +811,26 @@ def get_jumpstart_model_id_version_from_resource_arn(
model_version = model_version_from_tag

return model_id, model_version


def extract_info_from_hub_content_arn(
    arn: str,
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
    """Extracts hub_name, hub_region, content_name, and content_version from an arn.

    The arn is first tried as a hub model arn, then as a plain hub arn.

    Args:
        arn (str): hub model arn or hub arn to parse.

    Returns:
        Tuple of ``(hub_name, hub_region, content_name, content_version)``.
        For a plain hub arn the last two entries are ``None``; when the arn
        matches neither pattern all four entries are ``None``.
    """

    model_arn_match = re.match(constants.HUB_MODEL_ARN_REGEX, arn)
    if model_arn_match is not None:
        # Capture groups: 4 = hub name, 2 = region, 5 = content name, 6 = content version.
        return model_arn_match.group(4, 2, 5, 6)

    hub_arn_match = re.match(constants.HUB_ARN_REGEX, arn)
    if hub_arn_match is not None:
        # A plain hub arn carries no content name or version.
        return hub_arn_match.group(4), hub_arn_match.group(2), None, None

    return None, None, None, None
Empty file.
Empty file.
Loading