Skip to content

Feat/jsch jumpstart estimator support #4439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
032cb80
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
ef042d9
update types and var names
bencrabtree Feb 19, 2024
49ae11b
update linter
bencrabtree Feb 19, 2024
6175087
linter
bencrabtree Feb 19, 2024
4c9b2d0
linter
bencrabtree Feb 19, 2024
63345ea
flake8 check
bencrabtree Feb 19, 2024
6efc206
add hub name support for jumpstart estimator
bencrabtree Feb 20, 2024
2e9f76f
linter
bencrabtree Feb 20, 2024
ac8dd60
linter2
bencrabtree Feb 20, 2024
d4f7a00
fix param
bencrabtree Feb 20, 2024
5492474
move to utils and test
bencrabtree Feb 21, 2024
1e26760
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
4ae201c
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4a19a33
update types and var names
bencrabtree Feb 19, 2024
4d35379
update linter
bencrabtree Feb 19, 2024
174c4fd
linter
bencrabtree Feb 19, 2024
b7a8835
linter
bencrabtree Feb 19, 2024
bb7a9fb
flake8 check
bencrabtree Feb 19, 2024
8ba576a
pass hub_arn into all estimator utils/artifacts
bencrabtree Feb 21, 2024
ecd1f97
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
8df4478
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4870c0b
update types and var names
bencrabtree Feb 19, 2024
4dc21f5
update linter
bencrabtree Feb 19, 2024
dd51314
remove duplicate
bencrabtree Feb 21, 2024
195b84b
Merge branch 'master-jumpstart-curated-hub' of https://github.com/aws…
bencrabtree Feb 21, 2024
a39ae5f
linter
bencrabtree Feb 21, 2024
354b33e
add important unit test
bencrabtree Feb 21, 2024
424254e
update tests
bencrabtree Feb 22, 2024
8a3160a
black styles
bencrabtree Feb 22, 2024
151350c
finish tests
bencrabtree Feb 22, 2024
dd087da
create curated hub utils and types
bencrabtree Feb 23, 2024
a61dfb4
fix linter
bencrabtree Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand Down Expand Up @@ -80,6 +81,7 @@ def retrieve_default(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
Expand All @@ -92,6 +94,7 @@ def retrieve(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand Down Expand Up @@ -142,6 +145,7 @@ def retrieve(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
Expand Down
15 changes: 13 additions & 2 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart import cache, utils
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


Expand Down Expand Up @@ -239,7 +239,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel

@staticmethod
def get_model_specs(
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
region: str,
model_id: str,
version: str,
hub_arn: Optional[str] = None,
s3_client: Optional[boto3.client] = None,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.

Expand All @@ -259,6 +263,13 @@ def get_model_specs(
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

if hub_arn:
hub_model_arn = utils.construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn)

return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _retrieve_default_instance_type(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand Down Expand Up @@ -80,6 +81,7 @@ def _retrieve_default_instance_type(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down Expand Up @@ -119,6 +121,7 @@ def _retrieve_instance_types(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand Down Expand Up @@ -166,6 +169,7 @@ def _retrieve_instance_types(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/artifacts/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def _retrieve_estimator_init_kwargs(
def _retrieve_estimator_fit_kwargs(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand Down Expand Up @@ -234,6 +235,7 @@ def _retrieve_estimator_fit_kwargs(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
99 changes: 72 additions & 27 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
from difflib import get_close_matches
import os
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import boto3
import botocore
Expand All @@ -29,6 +29,7 @@
JUMPSTART_LOGGER,
MODEL_ID_LIST_WEB_URL,
)
from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
Expand All @@ -37,12 +38,13 @@
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
)
from sagemaker.jumpstart.types import (
JumpStartCachedS3ContentKey,
JumpStartCachedS3ContentValue,
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
HubDataType,
)
from sagemaker.jumpstart import utils
from sagemaker.utilities.cache import LRUCache
Expand Down Expand Up @@ -95,7 +97,7 @@ def __init__(
"""

self._region = region
self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
max_cache_items=max_s3_cache_items,
expiration_horizon=s3_cache_expiration_horizon,
retrieval_function=self._retrieval_function,
Expand Down Expand Up @@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(

model_id, version = key.model_id, key.version

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content

sm_version = utils.get_sagemaker_version()
Expand Down Expand Up @@ -301,50 +303,71 @@ def _get_json_file_from_local_override(

def _retrieval_function(
self,
key: JumpStartCachedS3ContentKey,
value: Optional[JumpStartCachedS3ContentValue],
) -> JumpStartCachedS3ContentValue:
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
key: JumpStartCachedContentKey,
value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
"""Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey``.

If a manifest file is being fetched, we only download the object if the md5 hash in
``head_object`` does not match the current md5 hash for the stored value. This prevents
unnecessarily downloading the full manifest when it hasn't changed.

Args:
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
key (JumpStartCachedContentKey): key for which to fetch JumpStart content.
value (Optional[JumpStartVersionedModelId]): Current value of old cached
s3 content. This is used for the manifest file, so that it is only
downloaded when its content changes.
"""

file_type, s3_key = key.file_type, key.s3_key
data_type, id_info = key.data_type, key.id_info

if file_type == JumpStartS3FileType.MANIFEST:
if data_type == JumpStartS3FileType.MANIFEST:
if value is not None and not self._is_local_metadata_mode():
etag = self._get_json_md5_hash(s3_key)
etag = self._get_json_md5_hash(id_info)
if etag == value.md5_hash:
return value
formatted_body, etag = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_body, etag = self._get_json_file(id_info, data_type)
return JumpStartCachedContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type == JumpStartS3FileType.SPECS:
formatted_body, _ = self._get_json_file(s3_key, file_type)
if data_type == JumpStartS3FileType.SPECS:
formatted_body, _ = self._get_json_file(id_info, data_type)
model_specs = JumpStartModelSpecs(formatted_body)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedS3ContentValue(
return JumpStartCachedContentValue(
formatted_content=model_specs
)
if data_type == HubDataType.MODEL:
hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn(
id_info
)
hub = CuratedHub(hub_name=hub_name, region=region)
hub_content = hub.describe_model(model_name=model_name, model_version=model_version)
utils.emit_logs_based_on_model_specs(
hub_content.content_document,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumed parsed HubContent. Suggest either explicitly commenting that or waiting for that implementation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a feature branch so I feel comfortable merging with out any of the implementation done, but yes this will be parsed. The HubContentDocument should be parsed at this point too

self.get_region(),
self._s3_client
)
model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True)
return JumpStartCachedContentValue(
formatted_content=model_specs
)
if data_type == HubDataType.HUB:
hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info)
hub = CuratedHub(hub_name=hub_name, region=region)
hub_info = hub.describe()
return JumpStartCachedContentValue(formatted_content=hub_info)
raise ValueError(
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
f"Bad value for key '{key}': must be in",
f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}"
)

def get_manifest(self) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest."""

manifest_dict = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest_dict = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content
manifest = list(manifest_dict.values()) # type: ignore
return manifest
Expand Down Expand Up @@ -407,8 +430,8 @@ def _get_header_impl(
JumpStartVersionedModelId(model_id, semantic_version_str)
)[0]

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
manifest = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key)
)[0].formatted_content
try:
header = manifest[versioned_model_id] # type: ignore
Expand All @@ -430,8 +453,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS

header = self.get_header(model_id, semantic_version_str)
spec_key = header.spec_key
specs, cache_hit = self._s3_cache.get(
JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key)
specs, cache_hit = self._content_cache.get(
JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key)
)
if not cache_hit and "*" in semantic_version_str:
JUMPSTART_LOGGER.warning(
Expand All @@ -443,7 +466,29 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
)
return specs.formatted_content

def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
"""Return JumpStart-compatible specs for a given Hub model

Args:
hub_model_arn (str): Arn for the Hub model to get specs for
"""

details, _ = self._content_cache.get(
JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)
)
return details.formatted_content

def get_hub(self, hub_arn: str) -> Dict[str, Any]:
"""Return descriptive info for a given Hub

Args:
hub_arn (str): Arn for the Hub to get info for
"""

details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn))
return details.formatted_content

def clear(self) -> None:
"""Clears the model ID/version and s3 cache."""
self._s3_cache.clear()
self._content_cache.clear()
self._model_id_semantic_version_manifest_key_cache.clear()
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@

JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"

# works for cross-partition
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$"
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"

Expand Down
Empty file.
50 changes: 50 additions & 0 deletions src/sagemaker/jumpstart/curated_hub/curated_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module provides the JumpStart Curated Hub class."""
from __future__ import absolute_import

from typing import Optional, Dict, Any
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION

from sagemaker.session import Session


class CuratedHub:
"""Class for creating and managing a curated JumpStart hub"""

def __init__(self, hub_name: str, region: str, session: Optional[Session] = None):
self.hub_name = hub_name
self.region = region
self.session = session
self._sm_session = session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION

def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]:
"""Returns descriptive information about the Hub Model"""

hub_content = self._sm_session.describe_hub_content(
model_name, "Model", self.hub_name, model_version
)

# TODO: Parse HubContent
# TODO: Parse HubContentDocument

return hub_content

def describe(self) -> Dict[str, Any]:
"""Returns descriptive information about the Hub"""

hub_info = self._sm_session.describe_hub(hub_name=self.hub_name)

# TODO: Validations?

return hub_info
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class JumpStartTag(str, Enum):
MODEL_ID = "sagemaker-sdk:jumpstart-model-id"
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"

HUB_ARN = "sagemaker-hub:hub-arn"


class SerializerType(str, Enum):
"""Enum class for serializers associated with JumpStart models."""
Expand Down
Loading