From 032cb8073743b8a3684ef29ccfeb1b73ec056f1b Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:31:18 +0000 Subject: [PATCH 1/6] add hub and hubcontent support in retrieval function for jumpstart model cache --- src/sagemaker/jumpstart/cache.py | 91 +++++++++++++------ src/sagemaker/jumpstart/constants.py | 4 + .../jumpstart/curated_hub/__init__.py | 0 .../jumpstart/curated_hub/constants.py | 15 +++ .../jumpstart/curated_hub/curated_hub.py | 55 +++++++++++ src/sagemaker/jumpstart/types.py | 47 +++++++--- src/sagemaker/jumpstart/utils.py | 24 +++++ .../jumpstart/curated_hub/__init__.py | 0 .../jumpstart/curated_hub/test_curated_hub.py | 0 tests/unit/sagemaker/jumpstart/test_cache.py | 23 ++++- tests/unit/sagemaker/jumpstart/test_utils.py | 29 ++++++ tests/unit/sagemaker/jumpstart/utils.py | 37 +++++--- 12 files changed, 270 insertions(+), 55 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/__init__.py create mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py create mode 100644 src/sagemaker/jumpstart/curated_hub/curated_hub.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/__init__.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..a327b4a87e 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,7 +15,7 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore @@ -29,6 +29,7 @@ JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, ) +from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -37,12 +38,13 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubDataType, ) from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -95,7 +97,7 @@ def __init__( """ self._region = region - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version( model_id, version = key.model_id, key.version - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content sm_version = utils.get_sagemaker_version() @@ -301,50 +303,65 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch JumpStart content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. """ - file_type, s3_key = key.file_type, key.s3_key + data_type, id_info = key.data_type, key.id_info - if file_type == JumpStartS3FileType.MANIFEST: + if data_type == JumpStartS3FileType.MANIFEST: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type == JumpStartS3FileType.SPECS: - formatted_body, _ = self._get_json_file(s3_key, file_type) + if data_type == JumpStartS3FileType.SPECS: + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=model_specs ) + if data_type == HubDataType.MODEL: + hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub_content = hub.describe_model(model_id=model_id, model_version=model_version) + utils.emit_logs_based_on_model_specs(hub_content.content_document, self.get_region(), self._s3_client) + model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) + return JumpStartCachedContentValue( + formatted_content=model_specs + ) + if data_type == HubDataType.HUB: + hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub_info = hub.describe() + return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( - f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + f"Bad value for key '{key}': must be in", + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest @@ -407,8 +424,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content try: header = manifest[versioned_model_id] # type: ignore @@ -430,8 +447,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS header = self.get_header(model_id, semantic_version_str) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key) ) if not cache_hit and "*" in semantic_version_str: JUMPSTART_LOGGER.warning( @@ -442,8 +459,28 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content + + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return specs.formatted_content + + def get_hub(self, hub_arn: str) -> Dict[str, Any]: + """Return descriptive info for a given Hub + + Args: + hub_arn (str): Arn for the Hub to get info for + """ + + manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return manifest.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._model_id_semantic_version_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2e655ac285..0e552aaac6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,6 +170,10 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +# works cross-partition +HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" diff --git a/src/sagemaker/jumpstart/curated_hub/__init__.py b/src/sagemaker/jumpstart/curated_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..478c7e801b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from botocore.config import Config + +DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py new file mode 100644 index 0000000000..5ef0579609 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from typing import Optional, Dict, Any + +import boto3 + +from sagemaker.session import Session + +from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG + + +class CuratedHub: + """Class for creating and managing a curated JumpStart hub""" + + def __init__(self, hub_name: str, region: str, session: Optional[Session]): + self.hub_name = hub_name + self.region = region + self.session = session + self._s3_client = self._get_s3_client() + self._sm_session = session or Session() + + def _get_s3_client(self) -> Any: + """Returns an S3 client.""" + return boto3.client("s3", region_name=self._region, config=DEFAULT_CLIENT_CONFIG) + + def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: + """Returns descriptive information about the Hub Model""" + + hub_content = self._sm_session.describe_hub_content( + model_name, "Model", self.hub_name, model_version + ) + + # TODO: Parse HubContent + # TODO: Parse HubContentDocument + + return hub_content + + def describe(self) -> Dict[str, Any]: + """Returns descriptive information about the Hub""" + + hub_info = self._sm_session.describe_hub(hub_name=self.hub_name) + + # TODO: Validations? + + return hub_info diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 49d3e295c5..c811475efd 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -106,6 +106,17 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" +class HubDataType(str, Enum): + """Enum for Hub data storage objects.""" + + HUB = "hub" + MODEL = "model" + NOTEBOOK = "notebook" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType] + + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -767,13 +778,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "gated_bucket", ] - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ - self.from_json(spec) + if is_hub_content: + self.from_hub_content_doc(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: """Sets fields in object based on json of header. @@ -895,6 +909,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) + def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: + """Sets fields in object based on values in HubContentDocument + + Args: + hub_content_doc (Dict[str, any]): parsed HubContentDocument returned + from SageMaker:DescribeHubContent + """ + # TODO: Implement + def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" json_obj = {} @@ -958,27 +981,27 @@ def __init__( self.version = version -class JumpStartCachedS3ContentKey(JumpStartDataHolderType): - """Data class for the s3 cached content keys.""" +class JumpStartCachedContentKey(JumpStartDataHolderType): + """Data class for the cached content keys.""" - __slots__ = ["file_type", "s3_key"] + __slots__ = ["data_type", "id_info"] def __init__( self, - file_type: JumpStartS3FileType, - s3_key: str, + data_type: JumpStartContentDataType, + id_info: str, ) -> None: """Instantiates JumpStartCachedS3ContentKey object. Args: - file_type (JumpStartS3FileType): JumpStart file type. - s3_key (str): object key in s3. + data_type (JumpStartContentDataType): JumpStart content data type. + id_info (str): if S3Content, object key in s3. if HubContent, hub content arn. """ - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..6de2761d20 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging import os +import re from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import boto3 @@ -810,3 +811,26 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def extract_info_from_hub_content_arn( + arn: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Extracts hub_name, content_name, and content_version from a HubContentArn""" + + match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + content_name = match.group(5) + content_version = match.group(6) + + return hub_name, hub_region, content_name, content_version + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + return hub_name, hub_region, None, None + + return None, None, None, None diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/__init__.py b/tests/unit/sagemaker/jumpstart/curated_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..59d040e92f 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -420,8 +420,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._model_id_semantic_version_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -854,3 +854,22 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +def test_jumpstart_cache_get_hub_model(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" + ) + assert get_spec_from_base_spec( + model_id="huggingface-mock-model-123", version="1.2.3" + ) == cache.get_hub_model(hub_model_arn=model_arn) + + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + ) + assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( + hub_model_arn=model_arn + ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..f1d6ccba43 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1177,6 +1177,35 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type +def test_extract_info_from_hub_content_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.extract_info_from_hub_content_arn(model_arn) == ( + "MockHub", + "us-west-2", + "my-mock-model", + "1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "nonsense-string" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" + ) + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 146c6fd1f7..76682c0f9e 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,8 +22,9 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + HubDataType, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, @@ -180,25 +181,33 @@ def get_spec_from_base_spec( def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: + datatype, id_info = key.data_type, key.id_info + if datatype == JumpStartS3FileType.MANIFEST: - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if filetype == JumpStartS3FileType.SPECS: - _, model_id, specs_version = s3_key.split("/") + if datatype == JumpStartS3FileType.SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - raise ValueError(f"Bad value for filetype: {filetype}") + if datatype == HubDataType.MODEL: + _, _, _, model_name, model_version = id_info.split("/") + return JumpStartCachedContentValue( + formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) + ) + + # TODO: Implement + if datatype == HubDataType.HUB: + return None + + raise ValueError(f"Bad value for filetype: {datatype}") def overwrite_dictionary( From ef042d9010192c160e1f1bd06a05a96123b4417a Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:56:02 +0000 Subject: [PATCH 2/6] update types and var names --- src/sagemaker/jumpstart/cache.py | 8 ++++---- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 7 ------- src/sagemaker/jumpstart/types.py | 4 ++-- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a327b4a87e..d7cfb850f9 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -467,8 +467,8 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) - return specs.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub @@ -477,8 +477,8 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return manifest.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return details.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 5ef0579609..796651dc16 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -12,8 +12,6 @@ # language governing permissions and limitations under the License. from typing import Optional, Dict, Any -import boto3 - from sagemaker.session import Session from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG @@ -26,13 +24,8 @@ def __init__(self, hub_name: str, region: str, session: Optional[Session]): self.hub_name = hub_name self.region = region self.session = session - self._s3_client = self._get_s3_client() self._sm_session = session or Session() - def _get_s3_client(self) -> Any: - """Returns an S3 client.""" - return boto3.client("s3", region_name=self._region, config=DEFAULT_CLIENT_CONFIG) - def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c811475efd..5a4e91d092 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -991,7 +991,7 @@ def __init__( data_type: JumpStartContentDataType, id_info: str, ) -> None: - """Instantiates JumpStartCachedS3ContentKey object. + """Instantiates JumpStartCachedContentKey object. Args: data_type (JumpStartContentDataType): JumpStart content data type. @@ -1014,7 +1014,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object. Args: formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], From 49ae11bf9a043b2d1dbf645c1fbab950d22efa30 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:39:28 +0000 Subject: [PATCH 3/6] update linter --- src/sagemaker/jumpstart/cache.py | 20 +++++++++++++------ .../jumpstart/curated_hub/curated_hub.py | 2 +- tests/unit/sagemaker/jumpstart/test_cache.py | 14 +++++-------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d7cfb850f9..3c7105a87c 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,17 +339,23 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) - hub_content = hub.describe_model(model_id=model_id, model_version=model_version) - utils.emit_logs_based_on_model_specs(hub_content.content_document, self.get_region(), self._s3_client) + hub_name, hub_region, model_name, model_version = utils.extract_info_from_hub_content_arn( + id_info + ) + hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub_content = hub.describe_model(model_name=model_name, model_version=model_version) + utils.emit_logs_based_on_model_specs( + hub_content.content_document, + self.get_region(), + self._s3_client + ) model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) return JumpStartCachedContentValue( formatted_content=model_specs ) if data_type == HubDataType.HUB: hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub = CuratedHub(hub_name=hub_name, region=hub_region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( @@ -467,7 +473,9 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + details, _ = self._content_cache.get( + JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + ) return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 796651dc16..58d899cad9 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -20,7 +20,7 @@ class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session]): + def __init__(self, hub_name: str, region: str, session: Optional[Session] = None): self.hub_name = hub_name self.region = region self.session = session diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 59d040e92f..69c8659148 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -860,16 +860,12 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( def test_jumpstart_cache_get_hub_model(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" assert get_spec_from_base_spec( model_id="huggingface-mock-model-123", version="1.2.3" ) == cache.get_hub_model(hub_model_arn=model_arn) - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" - ) - assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( - hub_model_arn=model_arn - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + assert get_spec_from_base_spec( + model_id="pytorch-mock-model-123", version="*" + ) == cache.get_hub_model(hub_model_arn=model_arn) From 6175087f94eb5a966b12d784ee7adb2e491463bb Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:49:41 +0000 Subject: [PATCH 4/6] linter --- src/sagemaker/jumpstart/cache.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 3c7105a87c..d733f39864 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,10 +339,10 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, hub_region, model_name, model_version = utils.extract_info_from_hub_content_arn( + hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( id_info ) - hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub = CuratedHub(hub_name=hub_name, region=region) hub_content = hub.describe_model(model_name=model_name, model_version=model_version) utils.emit_logs_based_on_model_specs( hub_content.content_document, @@ -354,8 +354,8 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.HUB: - hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, region=region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( @@ -465,10 +465,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content - + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model - + Args: hub_model_arn (str): Arn for the Hub model to get specs for """ @@ -477,14 +477,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) ) return details.formatted_content - + def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub - + Args: hub_arn (str): Arn for the Hub to get info for """ - + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content From 4c9b2d036efc351c530dd1bc21f2bace55f73cd5 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:58:42 +0000 Subject: [PATCH 5/6] linter --- src/sagemaker/jumpstart/curated_hub/constants.py | 15 --------------- .../jumpstart/curated_hub/curated_hub.py | 4 +--- 2 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py deleted file mode 100644 index 478c7e801b..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from botocore.config import Config - -DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 58d899cad9..c39bd30d52 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -10,13 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module provides the JumpStart Curated Hub class.""" from typing import Optional, Dict, Any from sagemaker.session import Session -from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG - - class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From 63345ea385c3f4cf3c6908278057b815032fb65f Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 22:04:23 +0000 Subject: [PATCH 6/6] flake8 check --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index c39bd30d52..273deb097b 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -11,10 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """This module provides the JumpStart Curated Hub class.""" +from __future__ import absolute_import + from typing import Optional, Dict, Any from sagemaker.session import Session + class CuratedHub: """Class for creating and managing a curated JumpStart hub"""