From 032cb8073743b8a3684ef29ccfeb1b73ec056f1b Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:31:18 +0000 Subject: [PATCH 01/31] add hub and hubcontent support in retrieval function for jumpstart model cache --- src/sagemaker/jumpstart/cache.py | 91 +++++++++++++------ src/sagemaker/jumpstart/constants.py | 4 + .../jumpstart/curated_hub/__init__.py | 0 .../jumpstart/curated_hub/constants.py | 15 +++ .../jumpstart/curated_hub/curated_hub.py | 55 +++++++++++ src/sagemaker/jumpstart/types.py | 47 +++++++--- src/sagemaker/jumpstart/utils.py | 24 +++++ .../jumpstart/curated_hub/__init__.py | 0 .../jumpstart/curated_hub/test_curated_hub.py | 0 tests/unit/sagemaker/jumpstart/test_cache.py | 23 ++++- tests/unit/sagemaker/jumpstart/test_utils.py | 29 ++++++ tests/unit/sagemaker/jumpstart/utils.py | 37 +++++--- 12 files changed, 270 insertions(+), 55 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/__init__.py create mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py create mode 100644 src/sagemaker/jumpstart/curated_hub/curated_hub.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/__init__.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e26d588167..a327b4a87e 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,7 +15,7 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore @@ -29,6 +29,7 @@ JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, ) +from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -37,12 +38,13 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubDataType, ) from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -95,7 +97,7 @@ def __init__( """ self._region = region - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version( model_id, version = key.model_id, key.version - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content sm_version = utils.get_sagemaker_version() @@ -301,50 +303,65 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch JumpStart content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. """ - file_type, s3_key = key.file_type, key.s3_key + data_type, id_info = key.data_type, key.id_info - if file_type == JumpStartS3FileType.MANIFEST: + if data_type == JumpStartS3FileType.MANIFEST: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type == JumpStartS3FileType.SPECS: - formatted_body, _ = self._get_json_file(s3_key, file_type) + if data_type == JumpStartS3FileType.SPECS: + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=model_specs ) + if data_type == HubDataType.MODEL: + hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub_content = hub.describe_model(model_id=model_id, model_version=model_version) + utils.emit_logs_based_on_model_specs(hub_content.content_document, self.get_region(), self._s3_client) + model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) + return JumpStartCachedContentValue( + formatted_content=model_specs + ) + if data_type == HubDataType.HUB: + hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub_info = hub.describe() + return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( - f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" + f"Bad value for key '{key}': must be in", + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content manifest = list(manifest_dict.values()) # type: ignore return manifest @@ -407,8 +424,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) + manifest = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.MANIFEST, self._manifest_file_s3_key) )[0].formatted_content try: header = manifest[versioned_model_id] # type: ignore @@ -430,8 +447,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS header = self.get_header(model_id, semantic_version_str) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(JumpStartS3FileType.SPECS, spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(JumpStartS3FileType.SPECS, spec_key) ) if not cache_hit and "*" in semantic_version_str: JUMPSTART_LOGGER.warning( @@ -442,8 +459,28 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content + + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return specs.formatted_content + + def get_hub(self, hub_arn: str) -> Dict[str, Any]: + """Return descriptive info for a given Hub + + Args: + hub_arn (str): Arn for the Hub to get info for + """ + + manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return manifest.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._model_id_semantic_version_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2e655ac285..0e552aaac6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,6 +170,10 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" +# works cross-partition +HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" diff --git a/src/sagemaker/jumpstart/curated_hub/__init__.py b/src/sagemaker/jumpstart/curated_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..478c7e801b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from botocore.config import Config + +DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py new file mode 100644 index 0000000000..5ef0579609 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from typing import Optional, Dict, Any + +import boto3 + +from sagemaker.session import Session + +from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG + + +class CuratedHub: + """Class for creating and managing a curated JumpStart hub""" + + def __init__(self, hub_name: str, region: str, session: Optional[Session]): + self.hub_name = hub_name + self.region = region + self.session = session + self._s3_client = self._get_s3_client() + self._sm_session = session or Session() + + def _get_s3_client(self) -> Any: + """Returns an S3 client.""" + return boto3.client("s3", region_name=self._region, config=DEFAULT_CLIENT_CONFIG) + + def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: + """Returns descriptive information about the Hub Model""" + + hub_content = self._sm_session.describe_hub_content( + model_name, "Model", self.hub_name, model_version + ) + + # TODO: Parse HubContent + # TODO: Parse HubContentDocument + + return hub_content + + def describe(self) -> Dict[str, Any]: + """Returns descriptive information about the Hub""" + + hub_info = self._sm_session.describe_hub(hub_name=self.hub_name) + + # TODO: Validations? + + return hub_info diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 49d3e295c5..c811475efd 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -106,6 +106,17 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" +class HubDataType(str, Enum): + """Enum for Hub data storage objects.""" + + HUB = "hub" + MODEL = "model" + NOTEBOOK = "notebook" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType] + + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -767,13 +778,16 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "gated_bucket", ] - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: bool = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ - self.from_json(spec) + if is_hub_content: + self.from_hub_content_doc(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: """Sets fields in object based on json of header. @@ -895,6 +909,15 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: else None ) + def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: + """Sets fields in object based on values in HubContentDocument + + Args: + hub_content_doc (Dict[str, any]): parsed HubContentDocument returned + from SageMaker:DescribeHubContent + """ + # TODO: Implement + def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartModelSpecs object.""" json_obj = {} @@ -958,27 +981,27 @@ def __init__( self.version = version -class JumpStartCachedS3ContentKey(JumpStartDataHolderType): - """Data class for the s3 cached content keys.""" +class JumpStartCachedContentKey(JumpStartDataHolderType): + """Data class for the cached content keys.""" - __slots__ = ["file_type", "s3_key"] + __slots__ = ["data_type", "id_info"] def __init__( self, - file_type: JumpStartS3FileType, - s3_key: str, + data_type: JumpStartContentDataType, + id_info: str, ) -> None: """Instantiates JumpStartCachedS3ContentKey object. Args: - file_type (JumpStartS3FileType): JumpStart file type. - s3_key (str): object key in s3. + data_type (JumpStartContentDataType): JumpStart content data type. + id_info (str): if S3Content, object key in s3. if HubContent, hub content arn. """ - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 2621422811..6de2761d20 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging import os +import re from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import boto3 @@ -810,3 +811,26 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version + + +def extract_info_from_hub_content_arn( + arn: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Extracts hub_name, content_name, and content_version from a HubContentArn""" + + match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + content_name = match.group(5) + content_version = match.group(6) + + return hub_name, hub_region, content_name, content_version + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + hub_name = match.group(4) + hub_region = match.group(2) + return hub_name, hub_region, None, None + + return None, None, None, None diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/__init__.py b/tests/unit/sagemaker/jumpstart/curated_hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 6633ecdc23..59d040e92f 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -420,8 +420,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._model_id_semantic_version_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -854,3 +854,22 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( ), ] ) + + +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) +def test_jumpstart_cache_get_hub_model(): + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" + ) + assert get_spec_from_base_spec( + model_id="huggingface-mock-model-123", version="1.2.3" + ) == cache.get_hub_model(hub_model_arn=model_arn) + + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + ) + assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( + hub_model_arn=model_arn + ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 556b99bc9c..f1d6ccba43 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1177,6 +1177,35 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type +def test_extract_info_from_hub_content_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.extract_info_from_hub_content_arn(model_arn) == ( + "MockHub", + "us-west-2", + "my-mock-model", + "1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "nonsense-string" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = "" + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + invalid_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" + ) + assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + + class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 146c6fd1f7..76682c0f9e 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,8 +22,9 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + HubDataType, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, @@ -180,25 +181,33 @@ def get_spec_from_base_spec( def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.MANIFEST: + datatype, id_info = key.data_type, key.id_info + if datatype == JumpStartS3FileType.MANIFEST: - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if filetype == JumpStartS3FileType.SPECS: - _, model_id, specs_version = s3_key.split("/") + if datatype == JumpStartS3FileType.SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - raise ValueError(f"Bad value for filetype: {filetype}") + if datatype == HubDataType.MODEL: + _, _, _, model_name, model_version = id_info.split("/") + return JumpStartCachedContentValue( + formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) + ) + + # TODO: Implement + if datatype == HubDataType.HUB: + return None + + raise ValueError(f"Bad value for filetype: {datatype}") def overwrite_dictionary( From ef042d9010192c160e1f1bd06a05a96123b4417a Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:56:02 +0000 Subject: [PATCH 02/31] update types and var names --- src/sagemaker/jumpstart/cache.py | 8 ++++---- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 7 ------- src/sagemaker/jumpstart/types.py | 4 ++-- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a327b4a87e..d7cfb850f9 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -467,8 +467,8 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) - return specs.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub @@ -477,8 +477,8 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return manifest.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return details.formatted_content def clear(self) -> None: """Clears the model ID/version and s3 cache.""" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 5ef0579609..796651dc16 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -12,8 +12,6 @@ # language governing permissions and limitations under the License. from typing import Optional, Dict, Any -import boto3 - from sagemaker.session import Session from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG @@ -26,13 +24,8 @@ def __init__(self, hub_name: str, region: str, session: Optional[Session]): self.hub_name = hub_name self.region = region self.session = session - self._s3_client = self._get_s3_client() self._sm_session = session or Session() - def _get_s3_client(self) -> Any: - """Returns an S3 client.""" - return boto3.client("s3", region_name=self._region, config=DEFAULT_CLIENT_CONFIG) - def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index c811475efd..5a4e91d092 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -991,7 +991,7 @@ def __init__( data_type: JumpStartContentDataType, id_info: str, ) -> None: - """Instantiates JumpStartCachedS3ContentKey object. + """Instantiates JumpStartCachedContentKey object. Args: data_type (JumpStartContentDataType): JumpStart content data type. @@ -1014,7 +1014,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object. Args: formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], From 49ae11bf9a043b2d1dbf645c1fbab950d22efa30 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:39:28 +0000 Subject: [PATCH 03/31] update linter --- src/sagemaker/jumpstart/cache.py | 20 +++++++++++++------ .../jumpstart/curated_hub/curated_hub.py | 2 +- tests/unit/sagemaker/jumpstart/test_cache.py | 14 +++++-------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d7cfb850f9..3c7105a87c 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,17 +339,23 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, hub_region, model_id, model_version = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) - hub_content = hub.describe_model(model_id=model_id, model_version=model_version) - utils.emit_logs_based_on_model_specs(hub_content.content_document, self.get_region(), self._s3_client) + hub_name, hub_region, model_name, model_version = utils.extract_info_from_hub_content_arn( + id_info + ) + hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub_content = hub.describe_model(model_name=model_name, model_version=model_version) + utils.emit_logs_based_on_model_specs( + hub_content.content_document, + self.get_region(), + self._s3_client + ) model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) return JumpStartCachedContentValue( formatted_content=model_specs ) if data_type == HubDataType.HUB: hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, hub_region=hub_region) + hub = CuratedHub(hub_name=hub_name, region=hub_region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( @@ -467,7 +473,9 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + details, _ = self._content_cache.get( + JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + ) return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 796651dc16..58d899cad9 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -20,7 +20,7 @@ class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session]): + def __init__(self, hub_name: str, region: str, session: Optional[Session] = None): self.hub_name = hub_name self.region = region self.session = session diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 59d040e92f..69c8659148 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -860,16 +860,12 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( def test_jumpstart_cache_get_hub_model(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" assert get_spec_from_base_spec( model_id="huggingface-mock-model-123", version="1.2.3" ) == cache.get_hub_model(hub_model_arn=model_arn) - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" - ) - assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( - hub_model_arn=model_arn - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + assert get_spec_from_base_spec( + model_id="pytorch-mock-model-123", version="*" + ) == cache.get_hub_model(hub_model_arn=model_arn) From 6175087f94eb5a966b12d784ee7adb2e491463bb Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:49:41 +0000 Subject: [PATCH 04/31] linter --- src/sagemaker/jumpstart/cache.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 3c7105a87c..d733f39864 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,10 +339,10 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, hub_region, model_name, model_version = utils.extract_info_from_hub_content_arn( + hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( id_info ) - hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub = CuratedHub(hub_name=hub_name, region=region) hub_content = hub.describe_model(model_name=model_name, model_version=model_version) utils.emit_logs_based_on_model_specs( hub_content.content_document, @@ -354,8 +354,8 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.HUB: - hub_name, hub_region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, region=hub_region) + hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) + hub = CuratedHub(hub_name=hub_name, region=region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( @@ -465,10 +465,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content - + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model - + Args: hub_model_arn (str): Arn for the Hub model to get specs for """ @@ -477,14 +477,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) ) return details.formatted_content - + def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub - + Args: hub_arn (str): Arn for the Hub to get info for """ - + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content From 4c9b2d036efc351c530dd1bc21f2bace55f73cd5 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:58:42 +0000 Subject: [PATCH 05/31] linter --- src/sagemaker/jumpstart/curated_hub/constants.py | 15 --------------- .../jumpstart/curated_hub/curated_hub.py | 4 +--- 2 files changed, 1 insertion(+), 18 deletions(-) delete mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py deleted file mode 100644 index 478c7e801b..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from botocore.config import Config - -DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 58d899cad9..c39bd30d52 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -10,13 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""This module provides the JumpStart Curated Hub class.""" from typing import Optional, Dict, Any from sagemaker.session import Session -from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG - - class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From 63345ea385c3f4cf3c6908278057b815032fb65f Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 22:04:23 +0000 Subject: [PATCH 06/31] flake8 check --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index c39bd30d52..273deb097b 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -11,10 +11,13 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """This module provides the JumpStart Curated Hub class.""" +from __future__ import absolute_import + from typing import Optional, Dict, Any from sagemaker.session import Session + class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From 6efc2065e3f1f06ad580106c0cd228409011a551 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 20 Feb 2024 22:07:36 +0000 Subject: [PATCH 07/31] add hub name support for jumpstart estimator --- src/sagemaker/instance_types.py | 4 + src/sagemaker/jumpstart/accessors.py | 15 +- .../jumpstart/artifacts/instance_types.py | 4 + src/sagemaker/jumpstart/artifacts/kwargs.py | 2 + src/sagemaker/jumpstart/constants.py | 4 +- .../jumpstart/curated_hub/curated_hub.py | 2 +- src/sagemaker/jumpstart/enums.py | 2 + src/sagemaker/jumpstart/estimator.py | 19 ++- src/sagemaker/jumpstart/factory/estimator.py | 13 ++ src/sagemaker/jumpstart/factory/model.py | 6 + src/sagemaker/jumpstart/types.py | 51 +++++++ src/sagemaker/jumpstart/utils.py | 87 +++++++++-- .../jumpstart/test_instance_types.py | 119 ++++++++++++++- tests/unit/sagemaker/jumpstart/constants.py | 7 +- .../jumpstart/estimator/test_estimator.py | 138 ++++++++++++++++++ .../sagemaker/jumpstart/test_accessors.py | 16 +- tests/unit/sagemaker/jumpstart/test_utils.py | 118 +++++++++++++-- tests/unit/sagemaker/jumpstart/utils.py | 34 +++-- 18 files changed, 591 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 0471f374ae..4e79e2b400 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -80,6 +81,7 @@ def retrieve_default( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, @@ -92,6 +94,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -142,6 +145,7 @@ def retrieve( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e03a13a7a3..3434c25fae 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,7 +18,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs -from sagemaker.jumpstart import cache +from sagemaker.jumpstart import cache, utils from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -239,7 +239,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel @staticmethod def get_model_specs( - region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + region: str, + model_id: str, + version: str, + hub_arn: Optional[str] = None, + s3_client: Optional[boto3.client] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -259,6 +263,13 @@ def get_model_specs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + if hub_arn: + hub_model_arn = utils.construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn) + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..176159c3c6 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -33,6 +33,7 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -80,6 +81,7 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -119,6 +121,7 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -166,6 +169,7 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..7a8c153a88 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -198,6 +198,7 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -234,6 +235,7 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 0e552aaac6..ce11f63a1b 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,8 +170,8 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -# works cross-partition -HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +# works for cross-partition +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 273deb097b..9b0c8b68fb 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -25,7 +25,7 @@ def __init__(self, hub_name: str, region: str, session: Optional[Session] = None self.hub_name = hub_name self.region = region self.session = session - self._sm_session = session or Session() + self._sm_session = session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index e33daca046..f962fdca80 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -79,6 +79,8 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + HUB_ARN = "sagemaker-hub:hub-arn" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..d68bc1077e 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module stores JumpStart implementation of Estimator class.""" from __future__ import absolute_import +import re from typing import Dict, List, Optional, Union @@ -27,7 +28,7 @@ from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, HUB_ARN_REGEX from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -35,6 +36,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( + construct_hub_arn_from_name, is_valid_model_id, resolve_model_sagemaker_config_field, ) @@ -57,6 +59,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -122,6 +125,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -518,9 +522,19 @@ def _is_valid_model_id_hook(): if not _is_valid_model_id_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + # TODO: Update to handle SageMakerJumpStart hub_name + hub_arn = None + if hub_name: + match = re.match(HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, sagemaker_session=sagemaker_session) + estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, @@ -576,6 +590,7 @@ def _is_valid_model_id_hook(): enable_remote_debug=enable_remote_debug, ) + self.hub_arn = estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -652,6 +667,7 @@ def fit( estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, inputs=inputs, wait=wait, @@ -1018,6 +1034,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..47d8a71ebe 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -61,6 +61,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, @@ -77,6 +78,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -134,6 +136,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, @@ -209,6 +212,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -224,6 +228,7 @@ def get_fit_kwargs( estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, inputs=inputs, wait=wait, @@ -246,6 +251,7 @@ def get_fit_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -290,6 +296,7 @@ def get_deploy_kwargs( model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -432,6 +439,7 @@ def _add_instance_type_and_count_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -465,6 +473,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags = add_jumpstart_model_id_version_tags( kwargs.tags, kwargs.model_id, full_model_version ) + + if kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -728,6 +740,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 64e4727116..1dfe9ef5e2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,6 +44,7 @@ JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, @@ -447,6 +448,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: kwargs.tags, kwargs.model_id, full_model_version ) + if kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -489,6 +493,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -521,6 +526,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5a4e91d092..cc34a257be 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1151,6 +1151,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1182,6 +1183,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1193,6 +1195,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1224,6 +1227,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1258,6 +1262,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "instance_count", "region", @@ -1317,12 +1322,14 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1379,6 +1386,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1440,6 +1448,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "region", "inputs", "wait", @@ -1454,6 +1463,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1464,6 +1474,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1478,6 +1489,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait @@ -1495,6 +1507,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "initial_instance_count", "region", @@ -1540,6 +1553,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -1547,6 +1561,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1589,6 +1604,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.initial_instance_count = initial_instance_count self.region = region @@ -1730,3 +1746,38 @@ def __init__( self.nearest_model_name = nearest_model_name self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation + + + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 6de2761d20..a4d9b9233b 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -36,13 +36,14 @@ get_old_model_version_msg, ) from sagemaker.jumpstart.types import ( + HubArnExtractedInfo, JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config, TagsDict +from sagemaker.utils import aws_partition, resolve_value_from_config, TagsDict from sagemaker.workflow import is_pipeline_variable @@ -368,6 +369,21 @@ def add_jumpstart_model_id_version_tags( return tags +def add_hub_arn_tags( + tags: Optional[List[TagsDict]], + hub_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + hub_arn, + enums.JumpStartTag.HUB_ARN, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, @@ -528,6 +544,7 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: str, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -577,6 +594,7 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, ) @@ -815,22 +833,67 @@ def get_jumpstart_model_id_version_from_resource_arn( def extract_info_from_hub_content_arn( arn: str, -) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: - """Extracts hub_name, content_name, and content_version from a HubContentArn""" +) -> HubArnExtractedInfo: + """Extracts descriptive information from a Hub or HubContent Arn""" - match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) if match: - hub_name = match.group(4) + partition = match.group(1) hub_region = match.group(2) - content_name = match.group(5) - content_version = match.group(6) - - return hub_name, hub_region, content_name, content_version + account_id = match.group(3) + hub_name = match.group(4) + hub_content_type = match.group(5) + hub_content_name = match.group(6) + hub_content_version = match.group(7) + + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + hub_content_version=hub_content_version, + ) match = re.match(constants.HUB_ARN_REGEX, arn) if match: - hub_name = match.group(4) + partition = match.group(1) hub_region = match.group(2) - return hub_name, hub_region, None, None + account_id = match.group(3) + hub_name = match.group(4) + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + ) + + return None + + +def construct_hub_arn_from_name( + hub_name: str, + region: Optional[str] = None, + session: Optional[Session] = None, +) -> str: + """Constructs a Hub arn from the Hub name using default Session values""" + print('being called') + + if not session: + session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION + + account_id = session.account_id() + region = region or session.boto_region_name + partition = aws_partition(region) + + return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + + +def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version""" + + info = extract_info_from_hub_content_arn(hub_arn) + arn = f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/{info.hub_name}/Model/{model_name}/{version}" - return None, None, None, None + return arn diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..50f35cb872 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -46,6 +46,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -64,6 +65,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -88,6 +90,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -111,7 +114,11 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + hub_arn=None, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -163,6 +170,116 @@ def test_jumpstart_instance_types(patched_get_model_specs): instance_types.retrieve(model_id=model_id, scope="training") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_instance_types_from_hub(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + model_id, model_version = "huggingface-eqa-bert-base-cased", "*" + region = "us-west-2" + hub_arn = f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub" + + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + + default_training_instance_types = instance_types.retrieve_default( + region=region, + model_id=model_id, + hub_arn=hub_arn, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, + ) + + assert default_training_instance_types == "ml.p3.2xlarge" + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_inference_instance_types = instance_types.retrieve_default( + region=region, + model_id=model_id, + hub_arn=hub_arn, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + ) + + assert default_inference_instance_types == "ml.p2.xlarge" + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_training_instance_types = instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope="training", + sagemaker_session=mock_session, + ) + assert default_training_instance_types == [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ] + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_inference_instance_types = instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope="inference", + sagemaker_session=mock_session, + ) + + assert default_inference_instance_types == [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ] + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a3c4c747f7..4b8a49764d 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1831,7 +1831,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], @@ -2058,7 +2058,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], @@ -2284,7 +2284,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], @@ -6239,7 +6239,6 @@ "training_volume_size": 456, "inference_enable_network_isolation": True, "training_enable_network_isolation": False, - "resource_name_base": "dfsdfsds", "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, "dynamic_container_deployment_supported": True, } diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..9046d48a1c 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -34,6 +34,7 @@ from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag from sagemaker.jumpstart.estimator import JumpStartEstimator +from sagemaker.jumpstart.types import JumpStartEstimatorFitKwargs from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.session_settings import SessionSettings @@ -41,6 +42,7 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, ) @@ -281,6 +283,141 @@ def test_prepacked( ], ) + @mock.patch("sagemaker.session.Session.account_id") + @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") + @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs") + @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs") + @mock.patch("sagemaker.jumpstart.estimator.construct_hub_arn_from_name") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_hub_model( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_construct_hub_arn_from_name: mock.Mock, + mock_retrieve_estimator_fit_kwargs: mock.Mock, + mock_retrieve_model_deploy_kwargs: mock.Mock, + mock_retrieve_model_init_kwargs: mock.Mock, + mock_get_caller_identity: mock.Mock, + ): + mock_get_caller_identity.return_value = "123456789123" + mock_estimator_deploy.return_value = default_predictor + + mock_is_valid_model_id.return_value = True + + model_id, _ = "pytorch-hub-model-1", "*" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + mock_retrieve_estimator_fit_kwargs.return_value = {} + mock_retrieve_model_deploy_kwargs.return_value = {} + mock_retrieve_model_init_kwargs.return_value = {} + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=hub_arn, + ) + + mock_construct_hub_arn_from_name.assert_not_called() + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p3.2xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/pytorch-training/" + "train-pytorch-ic-mobilenet-v2.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "epochs": "3", + "adam-learning-rate": "0.05", + "batch-size": "4", + }, + metric_definitions=[ + {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} + ], + role=execution_role, + volume_size=456, + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-hub-model-1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + { + "Key": JumpStartTag.HUB_ARN, + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ], + encrypt_inter_container_traffic=True, + enable_network_isolation=False, + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels) + + mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True) + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference" + ":1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-hub-model-1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + { + "Key": JumpStartTag.HUB_ARN, + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ], + role=execution_role, + wait=True, + use_compiled_model=False, + ) + + # Test Hub arn util + estimator = JumpStartEstimator( + model_id=model_id, + hub_name="my-mock-hub", + ) + + mock_construct_hub_arn_from_name.assert_called_once_with( + hub_name="my-mock-hub", region=None, sagemaker_session=None + ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1082,6 +1219,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "model_id", "model_version", "region", + "hub_name", "tolerate_vulnerable_model", "tolerate_deprecated_model", } diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 97427be1ae..0923a7a43b 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -45,6 +45,7 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) mock_cache.get_header = Mock(side_effect=get_header_from_base_header) mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + mock_cache.get_hub_model = Mock(side_effect=get_spec_from_base_spec) assert get_header_from_base_header( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" @@ -56,6 +57,14 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): ) == accessors.JumpStartModelsAccessor.get_model_specs( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + assert get_spec_from_base_spec( + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + ) assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0 @@ -85,9 +94,14 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): mock_model_cache.assert_called_once() mock_model_cache.reset_mock() + # shouldn't matter if hub_arn is passed through accessors.JumpStartModelsAccessor.get_model_specs( - region="us-west-1", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-1", + model_id="pytorch-ic-mobilenet-v2", + version="*", + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", ) + mock_model_cache.assert_called_once() mock_model_cache.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f1d6ccba43..9ec17c78b5 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -37,7 +37,11 @@ DeprecatedJumpStartModelError, VulnerableJumpStartModelError, ) -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.jumpstart.types import ( + HubArnExtractedInfo, + JumpStartModelHeader, + JumpStartVersionedModelId, +) from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from mock import MagicMock @@ -259,6 +263,44 @@ def test_add_jumpstart_model_id_version_tags(): ) +def test_add_hub_arn_tags(): + tags = None + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + tags = [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] + # If tags are already present, don't modify existing tags + assert [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + assert [ + {"Key": "random key", "Value": "random_value"}, + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + def test_add_jumpstart_uri_tags_inference(): tags = None inference_model_uri = "dfsdfsd" @@ -1181,29 +1223,83 @@ def test_extract_info_from_hub_content_arn(): model_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" ) - assert utils.extract_info_from_hub_content_arn(model_arn) == ( - "MockHub", - "us-west-2", - "my-mock-model", - "1.0.2", + assert utils.extract_info_from_hub_content_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.extract_info_from_hub_content_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", ) hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) + assert utils.extract_info_from_hub_content_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + assert None is utils.extract_info_from_hub_content_arn(invalid_arn) invalid_arn = "nonsense-string" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + assert None is utils.extract_info_from_hub_content_arn(invalid_arn) invalid_arn = "" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + assert None is utils.extract_info_from_hub_content_arn(invalid_arn) invalid_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" ) - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) + assert None is utils.extract_info_from_hub_content_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) class TestIsValidModelId(TestCase): diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 76682c0f9e..fe34fea801 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Optional import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache @@ -107,6 +107,7 @@ def get_special_model_spec( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -122,6 +123,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -145,25 +147,27 @@ def get_spec_from_base_spec( model_id: str = None, semantic_version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") - - if all( - [ - "pytorch" not in model_id, - "tensorflow" not in model_id, - "huggingface" not in model_id, - "mxnet" not in model_id, - "xgboost" not in model_id, - "catboost" not in model_id, - "lightgbm" not in model_id, - "sklearn" not in model_id, - ] - ): - raise KeyError("Bad model ID") + + if model_id is not None: + if all( + [ + "pytorch" not in model_id, + "tensorflow" not in model_id, + "huggingface" not in model_id, + "mxnet" not in model_id, + "xgboost" not in model_id, + "catboost" not in model_id, + "lightgbm" not in model_id, + "sklearn" not in model_id, + ] + ): + raise KeyError("Bad model ID") if region is not None and region not in JUMPSTART_REGION_NAME_SET: raise ValueError( From 2e9f76f5352faa4cfac64bcb45252bb18f33a60e Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 20 Feb 2024 22:15:26 +0000 Subject: [PATCH 08/31] linter --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 1 + src/sagemaker/jumpstart/estimator.py | 4 +++- src/sagemaker/jumpstart/types.py | 1 - src/sagemaker/jumpstart/utils.py | 5 +++-- tests/unit/sagemaker/jumpstart/estimator/test_estimator.py | 1 - tests/unit/sagemaker/jumpstart/utils.py | 2 +- 6 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 9b0c8b68fb..4fb65ff91b 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional, Dict, Any +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index d68bc1077e..9833906d8c 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -529,7 +529,9 @@ def _is_valid_model_id_hook(): if match: hub_arn = hub_name else: - hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, sagemaker_session=sagemaker_session) + hub_arn = construct_hub_arn_from_name( + hub_name=hub_name, region=region, sagemaker_session=sagemaker_session + ) estimator_init_kwargs = get_init_kwargs( model_id=model_id, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cc34a257be..12111d3861 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1748,7 +1748,6 @@ def __init__( self.skip_model_validation = skip_model_validation - class HubArnExtractedInfo(JumpStartDataHolderType): """Data class for info extracted from Hub arn.""" diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index a4d9b9233b..d73afd6252 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -878,7 +878,7 @@ def construct_hub_arn_from_name( session: Optional[Session] = None, ) -> str: """Constructs a Hub arn from the Hub name using default Session values""" - print('being called') + print("being called") if not session: session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION @@ -894,6 +894,7 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: """Constructs a HubContent model arn from the Hub name, model name, and model version""" info = extract_info_from_hub_content_arn(hub_arn) - arn = f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/{info.hub_name}/Model/{model_name}/{version}" + arn = f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" \ + f"{info.hub_name}/Model/{model_name}/{version}" return arn diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 9046d48a1c..b48429dde1 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -34,7 +34,6 @@ from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartTag from sagemaker.jumpstart.estimator import JumpStartEstimator -from sagemaker.jumpstart.types import JumpStartEstimatorFitKwargs from sagemaker.jumpstart.utils import get_jumpstart_content_bucket from sagemaker.session_settings import SessionSettings diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index fe34fea801..bafed5ae1c 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -153,7 +153,7 @@ def get_spec_from_base_spec( if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") - + if model_id is not None: if all( [ From ac8dd60b04d88188e442c1d8bddd00af46691712 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 20 Feb 2024 22:23:36 +0000 Subject: [PATCH 09/31] linter2 --- src/sagemaker/jumpstart/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d73afd6252..24d6797169 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -894,7 +894,9 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: """Constructs a HubContent model arn from the Hub name, model name, and model version""" info = extract_info_from_hub_content_arn(hub_arn) - arn = f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" \ + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" f"{info.hub_name}/Model/{model_name}/{version}" + ) return arn From d4f7a00549077839d4e870e1d706b7fa31aa858d Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Tue, 20 Feb 2024 22:32:13 +0000 Subject: [PATCH 10/31] fix param --- src/sagemaker/jumpstart/estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 9833906d8c..6cb11cb124 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -530,7 +530,7 @@ def _is_valid_model_id_hook(): hub_arn = hub_name else: hub_arn = construct_hub_arn_from_name( - hub_name=hub_name, region=region, sagemaker_session=sagemaker_session + hub_name=hub_name, region=region, session=sagemaker_session ) estimator_init_kwargs = get_init_kwargs( From 5492474f0fcb1fa25850b02146364bf43637745e Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 21 Feb 2024 15:39:32 +0000 Subject: [PATCH 11/31] move to utils and test --- src/sagemaker/jumpstart/cache.py | 4 +- src/sagemaker/jumpstart/constants.py | 1 - src/sagemaker/jumpstart/estimator.py | 20 ++----- src/sagemaker/jumpstart/utils.py | 38 ++++++++---- tests/unit/sagemaker/jumpstart/test_utils.py | 62 +++++++++++++++++--- 5 files changed, 89 insertions(+), 36 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d733f39864..3b3889fd20 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,7 +339,7 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( + hub_name, region, model_name, model_version = utils.extract_info_from_hub_resource_arn( id_info ) hub = CuratedHub(hub_name=hub_name, region=region) @@ -354,7 +354,7 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.HUB: - hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) + hub_name, region, _, _ = utils.extract_info_from_hub_resource_arn(id_info) hub = CuratedHub(hub_name=hub_name, region=region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index ce11f63a1b..6ee5d8208c 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,7 +170,6 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -# works for cross-partition HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 6cb11cb124..2a8273004b 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. """This module stores JumpStart implementation of Estimator class.""" from __future__ import absolute_import -import re from typing import Dict, List, Optional, Union @@ -28,7 +27,7 @@ from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, HUB_ARN_REGEX +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -36,7 +35,7 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - construct_hub_arn_from_name, + generate_hub_arn_for_estimator, is_valid_model_id, resolve_model_sagemaker_config_field, ) @@ -522,21 +521,12 @@ def _is_valid_model_id_hook(): if not _is_valid_model_id_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) - # TODO: Update to handle SageMakerJumpStart hub_name - hub_arn = None - if hub_name: - match = re.match(HUB_ARN_REGEX, hub_name) - if match: - hub_arn = hub_name - else: - hub_arn = construct_hub_arn_from_name( - hub_name=hub_name, region=region, session=sagemaker_session - ) - estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, - hub_arn=hub_arn, + hub_arn=generate_hub_arn_for_estimator( + hub_name=hub_name, region=region, session=sagemaker_session + ), tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 24d6797169..b0aa9e31be 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -831,10 +831,10 @@ def get_jumpstart_model_id_version_from_resource_arn( return model_id, model_version -def extract_info_from_hub_content_arn( +def extract_info_from_hub_resource_arn( arn: str, ) -> HubArnExtractedInfo: - """Extracts descriptive information from a Hub or HubContent Arn""" + """Extracts descriptive information from a Hub or HubContent Arn.""" match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) if match: @@ -875,13 +875,9 @@ def extract_info_from_hub_content_arn( def construct_hub_arn_from_name( hub_name: str, region: Optional[str] = None, - session: Optional[Session] = None, + session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: - """Constructs a Hub arn from the Hub name using default Session values""" - print("being called") - - if not session: - session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION + """Constructs a Hub arn from the Hub name using default Session values.""" account_id = session.account_id() region = region or session.boto_region_name @@ -891,12 +887,34 @@ def construct_hub_arn_from_name( def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: - """Constructs a HubContent model arn from the Hub name, model name, and model version""" + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" - info = extract_info_from_hub_content_arn(hub_arn) + info = extract_info_from_hub_resource_arn(hub_arn) arn = ( f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" f"{info.hub_name}/Model/{model_name}/{version}" ) return arn + + +# TODO: Update to recognize JumpStartHub hub_name +def generate_hub_arn_for_estimator( + hub_name: Optional[str] = None, region: Optional[str] = None, session: Optional[Session] = None +): + """Generates the Hub Arn for JumpStartEstimator from a HubName or Arn. + + Args: + hub_name (str): HubName or HubArn from JumpStartEstimator args + region (str): Region from JumpStartEstimator args + session (Session): Custom SageMaker Session from JumpStartEstimator args + """ + + hub_arn = None + if hub_name: + match = re.match(constants.HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) + return hub_arn diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 9ec17c78b5..7c9ab5931f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1219,11 +1219,11 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_content_arn(): +def test_extract_info_from_hub_resource_arn(): model_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" ) - assert utils.extract_info_from_hub_content_arn(model_arn) == HubArnExtractedInfo( + assert utils.extract_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1234,7 +1234,7 @@ def test_extract_info_from_hub_content_arn(): ) notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Notebook/my-mock-notebook/1.0.2" - assert utils.extract_info_from_hub_content_arn(notebook_arn) == HubArnExtractedInfo( + assert utils.extract_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1245,7 +1245,7 @@ def test_extract_info_from_hub_content_arn(): ) hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == HubArnExtractedInfo( + assert utils.extract_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1253,18 +1253,18 @@ def test_extract_info_from_hub_content_arn(): ) invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert None is utils.extract_info_from_hub_content_arn(invalid_arn) + assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) invalid_arn = "nonsense-string" - assert None is utils.extract_info_from_hub_content_arn(invalid_arn) + assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) invalid_arn = "" - assert None is utils.extract_info_from_hub_content_arn(invalid_arn) + assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) invalid_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" ) - assert None is utils.extract_info_from_hub_content_arn(invalid_arn) + assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) def test_construct_hub_arn_from_name(): @@ -1302,6 +1302,52 @@ def test_construct_hub_model_arn_from_inputs(): ) +def test_generate_hub_arn_for_estimator(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_estimator(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_estimator(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_estimator(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_estimator(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_estimator(hub_arn, None, mock_custom_session) == hub_arn + + class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") From 1e26760eea3b60dd7dcae5cf3d5ed20083b9e761 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 21 Feb 2024 10:58:01 -0500 Subject: [PATCH 12/31] feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438) --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 4fb65ff91b..6261cdb1fe 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -22,11 +22,11 @@ class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session] = None): + def __init__(self, hub_name: str, region: str, session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION): self.hub_name = hub_name self.region = region self.session = session - self._sm_session = session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION + self._sm_session = session def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" From 4ae201c6d80f82dc24d3ebc9c8754c57a919c7d4 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:31:18 +0000 Subject: [PATCH 13/31] add hub and hubcontent support in retrieval function for jumpstart model cache --- src/sagemaker/jumpstart/cache.py | 20 +++++++++++++++++++ .../jumpstart/curated_hub/constants.py | 15 ++++++++++++++ .../jumpstart/curated_hub/curated_hub.py | 2 ++ tests/unit/sagemaker/jumpstart/test_cache.py | 14 ++++++++----- 4 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 3b3889fd20..0216d3876b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -465,6 +465,26 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content + + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return specs.formatted_content + + def get_hub(self, hub_arn: str) -> Dict[str, Any]: + """Return descriptive info for a given Hub + + Args: + hub_arn (str): Arn for the Hub to get info for + """ + + manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return manifest.formatted_content def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..478c7e801b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from botocore.config import Config + +DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 6261cdb1fe..42cce5aeed 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -18,6 +18,8 @@ from sagemaker.session import Session +from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG + class CuratedHub: """Class for creating and managing a curated JumpStart hub""" diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 69c8659148..59d040e92f 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -860,12 +860,16 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( def test_jumpstart_cache_get_hub_model(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") - model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" + ) assert get_spec_from_base_spec( model_id="huggingface-mock-model-123", version="1.2.3" ) == cache.get_hub_model(hub_model_arn=model_arn) - model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" - assert get_spec_from_base_spec( - model_id="pytorch-mock-model-123", version="*" - ) == cache.get_hub_model(hub_model_arn=model_arn) + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + ) + assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( + hub_model_arn=model_arn + ) From 4a19a331021d479bd8dc108e37c8c46e5164e754 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:56:02 +0000 Subject: [PATCH 14/31] update types and var names --- src/sagemaker/jumpstart/cache.py | 8 ++++---- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 0216d3876b..c936cb1384 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -473,8 +473,8 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) - return specs.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub @@ -483,8 +483,8 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return manifest.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return details.formatted_content def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 42cce5aeed..6261cdb1fe 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -18,8 +18,6 @@ from sagemaker.session import Session -from sagemaker.jumpstart.curated_hub.constants import DEFAULT_CLIENT_CONFIG - class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From 4d3537941bfe20bad29b72699a86aa8161fe8eb1 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:39:28 +0000 Subject: [PATCH 15/31] update linter --- src/sagemaker/jumpstart/cache.py | 4 +++- tests/unit/sagemaker/jumpstart/test_cache.py | 14 +++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index c936cb1384..9bf72a044d 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -473,7 +473,9 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + details, _ = self._content_cache.get( + JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + ) return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 59d040e92f..69c8659148 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -860,16 +860,12 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( def test_jumpstart_cache_get_hub_model(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/huggingface-mock-model-123/1.2.3" assert get_spec_from_base_spec( model_id="huggingface-mock-model-123", version="1.2.3" ) == cache.get_hub_model(hub_model_arn=model_arn) - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" - ) - assert get_spec_from_base_spec(model_id="pytorch-mock-model-123", version="*") == cache.get_hub_model( - hub_model_arn=model_arn - ) + model_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/HubName/Model/pytorch-mock-model-123/*" + assert get_spec_from_base_spec( + model_id="pytorch-mock-model-123", version="*" + ) == cache.get_hub_model(hub_model_arn=model_arn) From 174c4fde955f60b143cc78891ad11c52853a9690 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:49:41 +0000 Subject: [PATCH 16/31] linter --- src/sagemaker/jumpstart/cache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 9bf72a044d..bad599bdfb 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -465,10 +465,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content - + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model - + Args: hub_model_arn (str): Arn for the Hub model to get specs for """ @@ -477,14 +477,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) ) return details.formatted_content - + def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub - + Args: hub_arn (str): Arn for the Hub to get info for """ - + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content From b7a8835919e95716da9dbeb1111163a79452dcb0 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:58:42 +0000 Subject: [PATCH 17/31] linter --- src/sagemaker/jumpstart/curated_hub/constants.py | 15 --------------- .../jumpstart/curated_hub/curated_hub.py | 1 - 2 files changed, 16 deletions(-) delete mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py deleted file mode 100644 index 478c7e801b..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from botocore.config import Config - -DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 6261cdb1fe..a153c9bb1a 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -18,7 +18,6 @@ from sagemaker.session import Session - class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From bb7a9fb78665c720d02252807834d44b9580cb77 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 22:04:23 +0000 Subject: [PATCH 18/31] flake8 check --- src/sagemaker/jumpstart/curated_hub/curated_hub.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index a153c9bb1a..6261cdb1fe 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -18,6 +18,7 @@ from sagemaker.session import Session + class CuratedHub: """Class for creating and managing a curated JumpStart hub""" From 8ba576a46b7144dad0a72133420b6045ba345221 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 21 Feb 2024 16:33:33 +0000 Subject: [PATCH 19/31] pass hub_arn into all estimator utils/artifacts --- src/sagemaker/environment_variables.py | 4 ++++ src/sagemaker/hyperparameters.py | 8 ++++++++ src/sagemaker/image_uris.py | 4 ++++ src/sagemaker/instance_types.py | 4 ++++ .../jumpstart/artifacts/environment_variables.py | 9 +++++++++ .../jumpstart/artifacts/hyperparameters.py | 4 ++++ src/sagemaker/jumpstart/artifacts/image_uris.py | 4 ++++ .../jumpstart/artifacts/incremental_training.py | 4 ++++ .../jumpstart/artifacts/instance_types.py | 4 ++++ src/sagemaker/jumpstart/artifacts/kwargs.py | 4 ++++ .../jumpstart/artifacts/metric_definitions.py | 4 ++++ .../jumpstart/artifacts/model_packages.py | 4 ++++ src/sagemaker/jumpstart/artifacts/model_uris.py | 4 ++++ src/sagemaker/jumpstart/artifacts/payloads.py | 4 ++++ .../jumpstart/artifacts/resource_names.py | 4 ++++ src/sagemaker/jumpstart/artifacts/script_uris.py | 4 ++++ src/sagemaker/jumpstart/cache.py | 12 +++++++----- src/sagemaker/jumpstart/factory/estimator.py | 12 ++++++++++++ src/sagemaker/jumpstart/utils.py | 6 ++++-- src/sagemaker/jumpstart/validators.py | 4 ++++ src/sagemaker/metric_definitions.py | 4 ++++ src/sagemaker/model_uris.py | 4 ++++ src/sagemaker/script_uris.py | 4 ++++ tests/unit/sagemaker/jumpstart/test_utils.py | 16 ++++++++-------- 24 files changed, 120 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..53da690a7a 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_environment_variables( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..55ee28e073 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default hyperparameters. (Default: None). model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, @@ -93,6 +97,7 @@ def validate( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, tolerate_vulnerable_model: bool = False, @@ -107,6 +112,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -148,6 +155,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 252bf3c504..ee33e2d8dc 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -61,6 +61,7 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, @@ -101,6 +102,8 @@ def retrieve( (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -146,6 +149,7 @@ def retrieve( model_id, model_version, image_scope, + hub_arn, framework, region, version, diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 4e79e2b400..5770c08f5b 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -45,6 +45,8 @@ def retrieve_default( retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -110,6 +112,8 @@ def retrieve( retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 0e666e4c14..0f0522b8a9 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -31,6 +31,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -77,6 +80,7 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -113,6 +117,7 @@ def _retrieve_default_environment_variables( gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -131,6 +136,7 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -144,6 +150,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -174,6 +182,7 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e9e6f613f8..f131d4a203 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -30,6 +30,7 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). include_container_hyperparameters (bool): True if container hyperparameters @@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 6ea1ca84a1..7952048050 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -33,6 +33,7 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, + hub_arn: Optional[str] = None, framework: Optional[str] = None, region: Optional[str] = None, version: Optional[str] = None, @@ -57,6 +58,8 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. @@ -110,6 +113,7 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 753a911422..1c392199cb 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -30,6 +30,7 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -64,6 +67,7 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 176159c3c6..cf283c0b9e 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -49,6 +49,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -137,6 +139,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7a8c153a88..9adb838549 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -140,6 +140,7 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -154,6 +155,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -177,6 +180,7 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index b6f6019641..fce09681f0 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_default_training_metric_definitions( default training metric definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -68,6 +71,7 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..e013ac584e 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -110,6 +110,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -124,6 +125,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -152,6 +155,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c41f0a75b7..c01715b616 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -178,6 +178,7 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -191,6 +192,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -212,6 +215,7 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..eaf740a8dc 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -32,6 +32,7 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -67,6 +70,7 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..3f558219e7 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -30,6 +30,7 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -64,6 +67,7 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c1b037ce61..ab4bd40d2c 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -32,6 +32,7 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -77,6 +80,7 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index bad599bdfb..92f050461d 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -339,11 +339,11 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.MODEL: - hub_name, region, model_name, model_version = utils.extract_info_from_hub_resource_arn( + info = utils.get_info_from_hub_resource_arn( id_info ) - hub = CuratedHub(hub_name=hub_name, region=region) - hub_content = hub.describe_model(model_name=model_name, model_version=model_version) + hub = CuratedHub(hub_name=info.hub_name, region=info.region) + hub_content = hub.describe_model(model_name=info.hub_content_name, model_version=info.hub_content_version) utils.emit_logs_based_on_model_specs( hub_content.content_document, self.get_region(), @@ -354,8 +354,10 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubDataType.HUB: - hub_name, region, _, _ = utils.extract_info_from_hub_resource_arn(id_info) - hub = CuratedHub(hub_name=hub_name, region=region) + info = utils.get_info_from_hub_resource_arn( + id_info + ) + hub = CuratedHub(hub_name=info.hub_name, region=info.region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 47d8a71ebe..71a6419f82 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -462,6 +462,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -489,6 +490,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE image_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -504,6 +506,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE if _model_supports_training_model_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -513,6 +516,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE model_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -525,6 +529,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE and not _model_supports_incremental_training( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -560,6 +565,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart script_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -576,6 +582,7 @@ def _add_env_to_kwargs( extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -588,6 +595,7 @@ def _add_env_to_kwargs( model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -628,6 +636,7 @@ def _add_training_job_name_to_kwargs( default_training_job_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -654,6 +663,7 @@ def _add_hyperparameters_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -687,6 +697,7 @@ def _add_metric_definitions_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -715,6 +726,7 @@ def _add_estimator_extra_kwargs( estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index b0aa9e31be..3cea9ae7d9 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -559,6 +559,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -831,7 +833,7 @@ def get_jumpstart_model_id_version_from_resource_arn( return model_id, model_version -def extract_info_from_hub_resource_arn( +def get_info_from_hub_resource_arn( arn: str, ) -> HubArnExtractedInfo: """Extracts descriptive information from a Hub or HubContent Arn.""" @@ -889,7 +891,7 @@ def construct_hub_arn_from_name( def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: """Constructs a HubContent model arn from the Hub name, model name, and model version.""" - info = extract_info_from_hub_resource_arn(hub_arn) + info = get_info_from_hub_resource_arn(hub_arn) arn = ( f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" f"{info.hub_name}/Model/{model_name}/{version}" diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 3199e5fc2e..c06a3d1bbd 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,6 +167,7 @@ def validate_hyperparameters( model_id: str, model_version: str, hyperparameters: Dict[str, Any], + hub_arn: Optional[str] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, sagemaker_session: Optional[session.Session] = None, @@ -179,6 +180,8 @@ def validate_hyperparameters( model_id (str): Model ID of the model for which to validate hyperparameters. model_version (str): Version of the model for which to validate hyperparameters. hyperparameters (dict): Hyperparameters to validate. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided to this function will be validated, the missing hyperparameters will be ignored. @@ -211,6 +214,7 @@ def validate_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..fff72a7d65 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -71,6 +74,7 @@ def retrieve_default( return artifacts._retrieve_default_training_metric_definitions( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..49895bbd49 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -43,6 +44,8 @@ def retrieve( the model artifact S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -75,6 +78,7 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..44327d6056 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -42,6 +43,8 @@ def retrieve( retrieve the script S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). script_scope (str): The script type. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model @@ -73,6 +76,7 @@ def retrieve( return artifacts._retrieve_script_uri( model_id, model_version, + hub_arn, script_scope, region, tolerate_vulnerable_model, diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 7c9ab5931f..3bf63ed9e9 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1219,11 +1219,11 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_resource_arn(): +def test_get_info_from_hub_resource_arn(): model_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" ) - assert utils.extract_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1234,7 +1234,7 @@ def test_extract_info_from_hub_resource_arn(): ) notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Notebook/my-mock-notebook/1.0.2" - assert utils.extract_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1245,7 +1245,7 @@ def test_extract_info_from_hub_resource_arn(): ) hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", account_id="000000000000", @@ -1253,18 +1253,18 @@ def test_extract_info_from_hub_resource_arn(): ) invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) invalid_arn = "nonsense-string" - assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) invalid_arn = "" - assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) invalid_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" ) - assert None is utils.extract_info_from_hub_resource_arn(invalid_arn) + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) def test_construct_hub_arn_from_name(): From ecd1f97d1de4a3bc4d640cb3b0921dbf969687be Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 21 Feb 2024 10:58:01 -0500 Subject: [PATCH 20/31] feat: add hub and hubcontent support in retrieval function for jumpstart model cache (#4438) --- src/sagemaker/jumpstart/cache.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 92f050461d..254a9042c6 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -490,28 +490,6 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content - def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: - """Return JumpStart-compatible specs for a given Hub model - - Args: - hub_model_arn (str): Arn for the Hub model to get specs for - """ - - details, _ = self._content_cache.get( - JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) - ) - return details.formatted_content - - def get_hub(self, hub_arn: str) -> Dict[str, Any]: - """Return descriptive info for a given Hub - - Args: - hub_arn (str): Arn for the Hub to get info for - """ - - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return details.formatted_content - def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._content_cache.clear() From 8df4478529898da43c7ffef94a646162562e2305 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:31:18 +0000 Subject: [PATCH 21/31] add hub and hubcontent support in retrieval function for jumpstart model cache --- src/sagemaker/jumpstart/cache.py | 20 +++++++++++++++++++ .../jumpstart/curated_hub/constants.py | 15 ++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 254a9042c6..81c4326799 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -467,6 +467,26 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content + + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return specs.formatted_content + + def get_hub(self, hub_arn: str) -> Dict[str, Any]: + """Return descriptive info for a given Hub + + Args: + hub_arn (str): Arn for the Hub to get info for + """ + + manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return manifest.formatted_content def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py new file mode 100644 index 0000000000..478c7e801b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/constants.py @@ -0,0 +1,15 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from botocore.config import Config + +DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) From 4870c0b208205f0c035d3c4cff4a5f16b1924f16 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 19:56:02 +0000 Subject: [PATCH 22/31] update types and var names --- src/sagemaker/jumpstart/cache.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 81c4326799..2e6559d343 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -475,8 +475,8 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - specs, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) - return specs.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub @@ -485,8 +485,8 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - manifest, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return manifest.formatted_content + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + return details.formatted_content def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model From 4dc21f53e7ea29f7e8dc0bd82b27f1248731e54c Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Mon, 19 Feb 2024 20:39:28 +0000 Subject: [PATCH 23/31] update linter --- src/sagemaker/jumpstart/cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 2e6559d343..b35c2d9edb 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -475,7 +475,9 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: hub_model_arn (str): Arn for the Hub model to get specs for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn)) + details, _ = self._content_cache.get( + JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + ) return details.formatted_content def get_hub(self, hub_arn: str) -> Dict[str, Any]: From dd51314f388b6f4899873209857ae9271647f75d Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 21 Feb 2024 16:41:06 +0000 Subject: [PATCH 24/31] remove duplicate --- src/sagemaker/jumpstart/cache.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index b35c2d9edb..ee03f3e397 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -490,28 +490,6 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content - def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: - """Return JumpStart-compatible specs for a given Hub model - - Args: - hub_model_arn (str): Arn for the Hub model to get specs for - """ - - details, _ = self._content_cache.get( - JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) - ) - return details.formatted_content - - def get_hub(self, hub_arn: str) -> Dict[str, Any]: - """Return descriptive info for a given Hub - - Args: - hub_arn (str): Arn for the Hub to get info for - """ - - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) - return details.formatted_content - def clear(self) -> None: """Clears the model ID/version and s3 cache.""" self._content_cache.clear() From a39ae5f76a5fe9e1f38eed79442b01b8113d3114 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 21 Feb 2024 17:11:21 +0000 Subject: [PATCH 25/31] linter --- src/sagemaker/jumpstart/artifacts/model_uris.py | 4 ++++ src/sagemaker/jumpstart/cache.py | 14 ++++++++------ src/sagemaker/jumpstart/curated_hub/constants.py | 15 --------------- .../jumpstart/curated_hub/curated_hub.py | 7 ++++++- 4 files changed, 18 insertions(+), 22 deletions(-) delete mode 100644 src/sagemaker/jumpstart/curated_hub/constants.py diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c01715b616..30bf1e9521 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, @@ -105,6 +106,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from (default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -135,6 +138,7 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index ee03f3e397..905178ca6f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -343,7 +343,9 @@ def _retrieval_function( id_info ) hub = CuratedHub(hub_name=info.hub_name, region=info.region) - hub_content = hub.describe_model(model_name=info.hub_content_name, model_version=info.hub_content_version) + hub_content = hub.describe_model( + model_name=info.hub_content_name, model_version=info.hub_content_version + ) utils.emit_logs_based_on_model_specs( hub_content.content_document, self.get_region(), @@ -467,10 +469,10 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS ) ) return specs.formatted_content - + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """Return JumpStart-compatible specs for a given Hub model - + Args: hub_model_arn (str): Arn for the Hub model to get specs for """ @@ -479,14 +481,14 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) ) return details.formatted_content - + def get_hub(self, hub_arn: str) -> Dict[str, Any]: """Return descriptive info for a given Hub - + Args: hub_arn (str): Arn for the Hub to get info for """ - + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) return details.formatted_content diff --git a/src/sagemaker/jumpstart/curated_hub/constants.py b/src/sagemaker/jumpstart/curated_hub/constants.py deleted file mode 100644 index 478c7e801b..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/constants.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from botocore.config import Config - -DEFAULT_CLIENT_CONFIG = Config(retries={"max_attempts": 10, "mode": "standard"}) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 18d572c9e4..b8885ff250 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -22,7 +22,12 @@ class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION): + def __init__( + self, + hub_name: str, + region: str, + session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ): self.hub_name = hub_name self.region = region self._sm_session = session From 354b33e09b4b18545b283656e1309af86c538d81 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 21 Feb 2024 18:12:28 +0000 Subject: [PATCH 26/31] add important unit test --- src/sagemaker/jumpstart/accessors.py | 2 +- .../sagemaker/jumpstart/test_accessors.py | 29 +++++++++++++++++++ tests/unit/sagemaker/jumpstart/utils.py | 1 + 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 3434c25fae..34a1764110 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -268,7 +268,7 @@ def get_model_specs( hub_model_arn = utils.construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) - return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn) + return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 0923a7a43b..460494e116 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -72,6 +72,35 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): reload(accessors) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_model_specs(mock_cache): + mock_cache.get_specs = Mock() + mock_cache.get_hub_model = Mock() + model_id, version = "pytorch-ic-mobilenet-v2", "*" + region = "us-west-2" + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version + ) + mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_hub_model.assert_not_called() + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=version, + hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" + ) + ) + + # necessary because accessors is a static module + reload(accessors) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index bafed5ae1c..83aab2ff0f 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -148,6 +148,7 @@ def get_spec_from_base_spec( semantic_version_str: str = None, version: str = None, hub_arn: Optional[str] = None, + hub_model_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: From 424254e0c9c0b04ba09be47ac41d6cb507321e0d Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 22 Feb 2024 16:05:00 +0000 Subject: [PATCH 27/31] update tests --- src/sagemaker/jumpstart/constants.py | 2 +- src/sagemaker/jumpstart/types.py | 2 +- src/sagemaker/jumpstart/utils.py | 2 +- .../jumpstart/test_accept_types.py | 4 +- .../jumpstart/test_content_types.py | 4 +- .../jumpstart/test_deserializers.py | 4 +- .../jumpstart/test_default.py | 8 +- .../hyperparameters/jumpstart/test_default.py | 5 +- .../jumpstart/test_validate.py | 6 +- .../image_uris/jumpstart/test_common.py | 8 +- tests/unit/sagemaker/jumpstart/constants.py | 1 + .../jumpstart/estimator/test_estimator.py | 17 +- .../sagemaker/jumpstart/test_artifacts.py | 199 ++++++++++++++++++ .../jumpstart/test_notebook_utils.py | 2 +- tests/unit/sagemaker/jumpstart/test_utils.py | 9 +- tests/unit/sagemaker/jumpstart/utils.py | 1 + .../jumpstart/test_default.py | 4 +- .../model_uris/jumpstart/test_common.py | 8 +- .../jumpstart/test_resource_requirements.py | 4 +- .../script_uris/jumpstart/test_common.py | 8 +- .../serializers/jumpstart/test_serializers.py | 4 +- 21 files changed, 253 insertions(+), 49 deletions(-) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 6ee5d8208c..1fec16b9b3 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,7 +170,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/(.*?)/(.*?)/(.*?)$" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 12111d3861..947250fd66 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -867,7 +867,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.inference_enable_network_isolation: bool = json_obj.get( "inference_enable_network_isolation", False ) - self.resource_name_base: bool = json_obj.get("resource_name_base") + self.resource_name_base: Optional[str] = json_obj.get("resource_name_base") self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key") diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3cea9ae7d9..f120b42396 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -870,7 +870,7 @@ def get_info_from_hub_resource_arn( account_id=account_id, hub_name=hub_name, ) - + return None diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..4284ac4d84 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -45,7 +45,7 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -73,5 +73,5 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..c924417946 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -45,7 +45,7 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -72,5 +72,5 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..4807fc7933 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -47,7 +47,7 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -79,5 +79,5 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..d89687b3f4 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -48,7 +48,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -68,7 +68,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -122,7 +122,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -143,7 +143,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..346d1ceab8 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -47,6 +47,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, version="*", s3_client=mock_client, + hub_arn=None ) patched_get_model_specs.reset_mock() @@ -63,7 +64,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -88,7 +89,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..46df5cb87d 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -139,7 +139,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -437,7 +437,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -491,7 +491,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..330a5dbcd3 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -48,7 +48,7 @@ def test_jumpstart_common_image_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -68,7 +68,7 @@ def test_jumpstart_common_image_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -88,7 +88,7 @@ def test_jumpstart_common_image_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +108,7 @@ def test_jumpstart_common_image_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 4b8a49764d..a60f8f9315 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -6239,6 +6239,7 @@ "training_volume_size": 456, "inference_enable_network_isolation": True, "training_enable_network_isolation": False, + "resource_name_base": "dfsdfsds", "hosting_resource_requirements": {"num_accelerators": 1, "min_memory_mb": 34360}, "dynamic_container_deployment_supported": True, } diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index b48429dde1..96c87cdf4d 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -282,11 +282,12 @@ def test_prepacked( ], ) + @mock.patch("sagemaker.jumpstart.artifacts.resource_names._retrieve_resource_name_base") @mock.patch("sagemaker.session.Session.account_id") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs") - @mock.patch("sagemaker.jumpstart.estimator.construct_hub_arn_from_name") + @mock.patch("sagemaker.jumpstart.utils.construct_hub_arn_from_name") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @@ -310,7 +311,9 @@ def test_hub_model( mock_retrieve_model_deploy_kwargs: mock.Mock, mock_retrieve_model_init_kwargs: mock.Mock, mock_get_caller_identity: mock.Mock, + mock_retrieve_resource_name_base: mock.Mock, ): + mock_retrieve_resource_name_base.return_value = "go-blue" mock_get_caller_identity.return_value = "123456789123" mock_estimator_deploy.return_value = default_predictor @@ -372,11 +375,11 @@ def test_hub_model( f"some-training-dataset-doesn't-matter", } - estimator.fit(channels) + estimator.fit(channels, job_name="go-blue") - mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True) + mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True, job_name="go-blue") - estimator.deploy() + estimator.deploy(endpoint_name="go-blue", model_name="go-blue") mock_estimator_deploy.assert_called_once_with( instance_type="ml.p2.xlarge", @@ -386,6 +389,8 @@ def test_hub_model( source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" "pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", entry_point="inference.py", + endpoint_name="go-blue", + model_name="go-blue", env={ "SAGEMAKER_PROGRAM": "inference.py", "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -414,7 +419,7 @@ def test_hub_model( ) mock_construct_hub_arn_from_name.assert_called_once_with( - hub_name="my-mock-hub", region=None, sagemaker_session=None + hub_name="my-mock-hub", region=None, session=None ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1475,6 +1480,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, @@ -1526,6 +1532,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..3d26bd5739 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from importlib import reload import unittest from unittest.mock import Mock @@ -20,11 +21,18 @@ import copy from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.artifacts.environment_variables import _retrieve_default_environment_variables +from sagemaker.jumpstart.artifacts.hyperparameters import _retrieve_default_hyperparameters +from sagemaker.jumpstart.artifacts.image_uris import _retrieve_image_uri +from sagemaker.jumpstart.artifacts.incremental_training import _model_supports_incremental_training +from sagemaker.jumpstart.artifacts.instance_types import _retrieve_default_instance_type +from sagemaker.jumpstart.artifacts.metric_definitions import _retrieve_default_training_metric_definitions from sagemaker.jumpstart.artifacts.model_uris import ( _retrieve_hosting_prepacked_artifact_key, _retrieve_hosting_artifact_key, _retrieve_training_artifact_key, ) +from sagemaker.jumpstart.artifacts.script_uris import _retrieve_script_uri from sagemaker.jumpstart.types import JumpStartModelSpecs from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, @@ -456,3 +464,194 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): ), "s3://jumpstart-private-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", ) + + +class HubModelTest(unittest.TestCase): + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_default_environment_variables(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_environment_variables( + model_id=model_id, model_version=version, hub_arn=hub_arn, script=JumpStartScriptScope.INFERENCE + ), + { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1" + } + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_image_uri(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_image_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, instance_type="ml.p3.2xlarge", image_scope=JumpStartScriptScope.TRAINING + ), + "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_default_hyperparameters(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_hyperparameters( + model_id=model_id, model_version=version, hub_arn=hub_arn + ), + { + "epochs": "3", + "adam-learning-rate": "0.05", + "batch-size": "4", + } + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_model_supports_incremental_training(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _model_supports_incremental_training( + model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" + ), + True + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_default_instance_type(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_instance_type( + model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING + ), + "ml.p3.2xlarge" + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_default_instance_type( + model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE + ), + "ml.p2.xlarge" + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_default_training_metric_definitions(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_training_metric_definitions( + model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" + ), + [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}] + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_model_uri(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_model_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="training" + ), + "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_model_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="inference" + ), + "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") + def test_retrieve_script_uri(self, mock_cache): + mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_script_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.TRAINING + ), + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_script_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.INFERENCE + ), + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 8aae4c36a8..0ded6b0e0d 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -696,5 +696,5 @@ def test_get_model_url( model_id=model_id, version=version, region=region, - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 3bf63ed9e9..df2dd445f7 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1221,7 +1221,7 @@ def test_mime_type_enum_from_str(): def test_get_info_from_hub_resource_arn(): model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" ) assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( partition="aws", @@ -1233,7 +1233,7 @@ def test_get_info_from_hub_resource_arn(): hub_content_version="1.0.2", ) - notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Notebook/my-mock-notebook/1.0.2" + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( partition="aws", region="us-west-2", @@ -1261,11 +1261,6 @@ def test_get_info_from_hub_resource_arn(): invalid_arn = "" assert None is utils.get_info_from_hub_resource_arn(invalid_arn) - invalid_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" - ) - assert None is utils.get_info_from_hub_resource_arn(invalid_arn) - def test_construct_hub_arn_from_name(): mock_sagemaker_session = Mock() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 83aab2ff0f..af3d075943 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -92,6 +92,7 @@ def get_prototype_model_spec( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..afba23b5a4 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -47,7 +47,7 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -63,7 +63,7 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..f9642ef3c5 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -46,7 +46,7 @@ def test_jumpstart_common_model_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -63,7 +63,7 @@ def test_jumpstart_common_model_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -81,7 +81,7 @@ def test_jumpstart_common_model_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -99,7 +99,7 @@ def test_jumpstart_common_model_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 28b53270f8..6f262562e7 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -45,7 +45,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -72,7 +72,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..6e2f8b37a3 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -46,7 +46,7 @@ def test_jumpstart_common_script_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -63,7 +63,7 @@ def test_jumpstart_common_script_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +78,7 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +96,7 @@ def test_jumpstart_common_script_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..80708c00c1 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -49,7 +49,7 @@ def test_jumpstart_default_serializers( region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -88,5 +88,5 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=mock_client, hub_arn=None ) From 8a3160a06da23d4c10ae906c2f283b762b6cc8b4 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 22 Feb 2024 16:15:37 +0000 Subject: [PATCH 28/31] black styles --- src/sagemaker/jumpstart/utils.py | 2 +- .../hyperparameters/jumpstart/test_default.py | 16 +---- .../jumpstart/test_validate.py | 5 +- .../image_uris/jumpstart/test_common.py | 12 ++-- .../sagemaker/jumpstart/test_artifacts.py | 69 ++++++++++++------- .../jumpstart/test_notebook_utils.py | 3 +- .../model_uris/jumpstart/test_common.py | 12 ++-- .../jumpstart/test_resource_requirements.py | 10 +-- .../script_uris/jumpstart/test_common.py | 15 ++-- .../serializers/jumpstart/test_serializers.py | 10 +-- 10 files changed, 83 insertions(+), 71 deletions(-) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index f120b42396..3cea9ae7d9 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -870,7 +870,7 @@ def get_info_from_hub_resource_arn( account_id=account_id, hub_name=hub_name, ) - + return None diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 346d1ceab8..ba08aa0825 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -43,11 +43,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="*", - s3_client=mock_client, - hub_arn=None + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -61,10 +57,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="1.*", - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -86,10 +79,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="1.*", - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 46df5cb87d..ae8138b7c5 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -136,10 +136,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 330a5dbcd3..b37b770875 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -48,7 +48,8 @@ def test_jumpstart_common_image_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -68,7 +69,8 @@ def test_jumpstart_common_image_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -88,7 +90,8 @@ def test_jumpstart_common_image_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -108,7 +111,8 @@ def test_jumpstart_common_image_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 3d26bd5739..400f286d97 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -11,7 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from importlib import reload import unittest from unittest.mock import Mock @@ -21,12 +20,16 @@ import copy from sagemaker.jumpstart import artifacts -from sagemaker.jumpstart.artifacts.environment_variables import _retrieve_default_environment_variables +from sagemaker.jumpstart.artifacts.environment_variables import ( + _retrieve_default_environment_variables, +) from sagemaker.jumpstart.artifacts.hyperparameters import _retrieve_default_hyperparameters from sagemaker.jumpstart.artifacts.image_uris import _retrieve_image_uri from sagemaker.jumpstart.artifacts.incremental_training import _model_supports_incremental_training from sagemaker.jumpstart.artifacts.instance_types import _retrieve_default_instance_type -from sagemaker.jumpstart.artifacts.metric_definitions import _retrieve_default_training_metric_definitions +from sagemaker.jumpstart.artifacts.metric_definitions import ( + _retrieve_default_training_metric_definitions, +) from sagemaker.jumpstart.artifacts.model_uris import ( _retrieve_hosting_prepacked_artifact_key, _retrieve_hosting_artifact_key, @@ -467,7 +470,6 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): class HubModelTest(unittest.TestCase): - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_retrieve_default_environment_variables(self, mock_cache): mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) @@ -477,7 +479,10 @@ def test_retrieve_default_environment_variables(self, mock_cache): self.assertEqual( _retrieve_default_environment_variables( - model_id=model_id, model_version=version, hub_arn=hub_arn, script=JumpStartScriptScope.INFERENCE + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script=JumpStartScriptScope.INFERENCE, ), { "SAGEMAKER_PROGRAM": "inference.py", @@ -487,8 +492,8 @@ def test_retrieve_default_environment_variables(self, mock_cache): "ENDPOINT_SERVER_TIMEOUT": "3600", "MODEL_CACHE_ROOT": "/opt/ml/model", "SAGEMAKER_ENV": "1", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1" - } + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -496,7 +501,6 @@ def test_retrieve_default_environment_variables(self, mock_cache): ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_retrieve_image_uri(self, mock_cache): mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) @@ -506,9 +510,13 @@ def test_retrieve_image_uri(self, mock_cache): self.assertEqual( _retrieve_image_uri( - model_id=model_id, model_version=version, hub_arn=hub_arn, instance_type="ml.p3.2xlarge", image_scope=JumpStartScriptScope.TRAINING + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + instance_type="ml.p3.2xlarge", + image_scope=JumpStartScriptScope.TRAINING, ), - "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3" + "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -531,7 +539,7 @@ def test_retrieve_default_hyperparameters(self, mock_cache): "epochs": "3", "adam-learning-rate": "0.05", "batch-size": "4", - } + }, ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -550,7 +558,7 @@ def test_model_supports_incremental_training(self, mock_cache): _model_supports_incremental_training( model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" ), - True + True, ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -567,9 +575,12 @@ def test_retrieve_default_instance_type(self, mock_cache): self.assertEqual( _retrieve_default_instance_type( - model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + scope=JumpStartScriptScope.TRAINING, ), - "ml.p3.2xlarge" + "ml.p3.2xlarge", ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -579,9 +590,12 @@ def test_retrieve_default_instance_type(self, mock_cache): self.assertEqual( _retrieve_default_instance_type( - model_id=model_id, model_version=version, hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + scope=JumpStartScriptScope.INFERENCE, ), - "ml.p2.xlarge" + "ml.p2.xlarge", ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") @@ -595,7 +609,7 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache): _retrieve_default_training_metric_definitions( model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" ), - [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}] + [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -603,7 +617,6 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache): ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") def test_retrieve_model_uri(self, mock_cache): mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) @@ -615,7 +628,7 @@ def test_retrieve_model_uri(self, mock_cache): _retrieve_model_uri( model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="training" ), - "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" + "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -627,7 +640,7 @@ def test_retrieve_model_uri(self, mock_cache): _retrieve_model_uri( model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="inference" ), - "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz" + "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") @@ -639,9 +652,13 @@ def test_retrieve_script_uri(self, mock_cache): self.assertEqual( _retrieve_script_uri( - model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.TRAINING + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script_scope=JumpStartScriptScope.TRAINING, ), - "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", ) mock_cache.get_hub_model.assert_called_once_with( hub_model_arn=( @@ -651,7 +668,11 @@ def test_retrieve_script_uri(self, mock_cache): self.assertEqual( _retrieve_script_uri( - model_id=model_id, model_version=version, hub_arn=hub_arn, script_scope=JumpStartScriptScope.INFERENCE + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script_scope=JumpStartScriptScope.INFERENCE, ), - "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz" + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "inference/ic/v1.0.0/sourcedir.tar.gz", ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 0ded6b0e0d..2fac16cc72 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -696,5 +696,6 @@ def test_get_model_url( model_id=model_id, version=version, region=region, - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None + s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index f9642ef3c5..cde3258133 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -46,7 +46,8 @@ def test_jumpstart_common_model_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -63,7 +64,8 @@ def test_jumpstart_common_model_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -81,7 +83,8 @@ def test_jumpstart_common_model_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -99,7 +102,8 @@ def test_jumpstart_common_model_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 6f262562e7..86031fbd57 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -42,10 +42,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements.requests["memory"] == 34360 patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -69,10 +66,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements is None patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 6e2f8b37a3..5811a9d822 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -46,7 +46,8 @@ def test_jumpstart_common_script_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -63,7 +64,8 @@ def test_jumpstart_common_script_uri( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +80,11 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, hub_arn=None + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -96,7 +102,8 @@ def test_jumpstart_common_script_uri( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*", - s3_client=mock_client, hub_arn=None + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 80708c00c1..94354e782e 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -46,10 +46,7 @@ def test_jumpstart_default_serializers( assert isinstance(default_serializer, base_serializers.IdentitySerializer) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -85,8 +82,5 @@ def test_jumpstart_serializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, hub_arn=None + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) From 151350cf7e8f29059de734a17b04a98175cfef13 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Thu, 22 Feb 2024 19:28:49 +0000 Subject: [PATCH 29/31] finish tests --- .../sagemaker/jumpstart/test_artifacts.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 400f286d97..8acd04f1f6 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -470,9 +470,9 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): class HubModelTest(unittest.TestCase): - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_default_environment_variables(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_environment_variables(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -495,15 +495,15 @@ def test_retrieve_default_environment_variables(self, mock_cache): "SAGEMAKER_MODEL_SERVER_WORKERS": "1", }, ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_image_uri(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_image_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -518,15 +518,15 @@ def test_retrieve_image_uri(self, mock_cache): ), "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_default_hyperparameters(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_hyperparameters(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -541,15 +541,15 @@ def test_retrieve_default_hyperparameters(self, mock_cache): "batch-size": "4", }, ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_model_supports_incremental_training(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_model_supports_incremental_training(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -560,15 +560,15 @@ def test_model_supports_incremental_training(self, mock_cache): ), True, ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_default_instance_type(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_instance_type(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -582,7 +582,7 @@ def test_retrieve_default_instance_type(self, mock_cache): ), "ml.p3.2xlarge", ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) @@ -598,9 +598,9 @@ def test_retrieve_default_instance_type(self, mock_cache): "ml.p2.xlarge", ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_default_training_metric_definitions(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_training_metric_definitions(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -611,15 +611,15 @@ def test_retrieve_default_training_metric_definitions(self, mock_cache): ), [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_model_uri(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_model_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -630,7 +630,7 @@ def test_retrieve_model_uri(self, mock_cache): ), "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) @@ -643,9 +643,9 @@ def test_retrieve_model_uri(self, mock_cache): "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") - def test_retrieve_script_uri(self, mock_cache): - mock_cache.get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_script_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" @@ -660,7 +660,7 @@ def test_retrieve_script_uri(self, mock_cache): "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", ) - mock_cache.get_hub_model.assert_called_once_with( + mock_get_hub_model.assert_called_once_with( hub_model_arn=( f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" ) From dd087da40194b31b80c5135e81e2299ba8373781 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Fri, 23 Feb 2024 20:31:14 +0000 Subject: [PATCH 30/31] create curated hub utils and types --- src/sagemaker/environment_variables.py | 2 +- src/sagemaker/hyperparameters.py | 4 +- src/sagemaker/image_uris.py | 2 +- src/sagemaker/instance_types.py | 4 +- src/sagemaker/jumpstart/accessors.py | 5 +- .../artifacts/environment_variables.py | 4 +- .../jumpstart/artifacts/hyperparameters.py | 2 +- .../jumpstart/artifacts/image_uris.py | 2 +- .../artifacts/incremental_training.py | 2 +- .../jumpstart/artifacts/instance_types.py | 4 +- src/sagemaker/jumpstart/artifacts/kwargs.py | 2 +- .../jumpstart/artifacts/metric_definitions.py | 2 +- .../jumpstart/artifacts/model_packages.py | 2 +- .../jumpstart/artifacts/model_uris.py | 4 +- src/sagemaker/jumpstart/artifacts/payloads.py | 2 +- .../jumpstart/artifacts/resource_names.py | 2 +- .../jumpstart/artifacts/script_uris.py | 2 +- src/sagemaker/jumpstart/cache.py | 12 +- src/sagemaker/jumpstart/curated_hub/types.py | 50 ++++++ src/sagemaker/jumpstart/curated_hub/utils.py | 110 ++++++++++++++ src/sagemaker/jumpstart/estimator.py | 12 +- src/sagemaker/jumpstart/types.py | 43 +----- src/sagemaker/jumpstart/utils.py | 91 +---------- src/sagemaker/jumpstart/validators.py | 2 +- src/sagemaker/metric_definitions.py | 2 +- src/sagemaker/model_uris.py | 2 +- src/sagemaker/script_uris.py | 2 +- .../jumpstart/curated_hub/test_utils.py | 142 ++++++++++++++++++ .../jumpstart/estimator/test_estimator.py | 2 +- tests/unit/sagemaker/jumpstart/test_utils.py | 125 --------------- tests/unit/sagemaker/jumpstart/utils.py | 6 +- 31 files changed, 354 insertions(+), 294 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/types.py create mode 100644 src/sagemaker/jumpstart/curated_hub/utils.py create mode 100644 tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 53da690a7a..17f8ebdf2c 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -48,7 +48,7 @@ def retrieve_default( model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 55ee28e073..83554b302b 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -48,7 +48,7 @@ def retrieve_default( model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -113,7 +113,7 @@ def validate( model_version (str): The version of the model for which to validate hyperparameters. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index ee33e2d8dc..2b6870a11c 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -103,7 +103,7 @@ def retrieve( model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 5770c08f5b..d277f8cf3b 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -46,7 +46,7 @@ def retrieve_default( model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -113,7 +113,7 @@ def retrieve( model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 34a1764110..456ba6baf6 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -18,7 +18,8 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs -from sagemaker.jumpstart import cache, utils +from sagemaker.jumpstart import cache +from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -265,7 +266,7 @@ def get_model_specs( JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) if hub_arn: - hub_model_arn = utils.construct_hub_model_arn_from_inputs( + hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 0f0522b8a9..b85cfe4572 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -48,7 +48,7 @@ def _retrieve_default_environment_variables( model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -151,7 +151,7 @@ def _retrieve_gated_model_uri_env_var_value( model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index f131d4a203..6b9689485f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -46,7 +46,7 @@ def _retrieve_default_hyperparameters( model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). include_container_hyperparameters (bool): True if container hyperparameters diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 7952048050..bb7b0c7528 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -59,7 +59,7 @@ def _retrieve_image_uri( model_version (str): Version of the JumpStart model for which to retrieve the image URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1c392199cb..f877068e77 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -45,7 +45,7 @@ def _model_supports_incremental_training( region (Optional[str]): Region for which to retrieve the support status for incremental training. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index cf283c0b9e..5f252f00ad 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -50,7 +50,7 @@ def _retrieve_default_instance_type( scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -140,7 +140,7 @@ def _retrieve_instance_types( scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9adb838549..69b8bfb51a 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -156,7 +156,7 @@ def _retrieve_estimator_init_kwargs( instance_type (str): Instance type of the training job, to determine if volume size is supported. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index fce09681f0..64e53dbbbe 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -47,7 +47,7 @@ def _retrieve_default_training_metric_definitions( region (Optional[str]): Region for which to retrieve default training metric definitions. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index e013ac584e..e49d14682d 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -126,7 +126,7 @@ def _retrieve_model_package_model_artifact_s3_uri( region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 30bf1e9521..df9e1fd507 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -107,7 +107,7 @@ def _retrieve_model_uri( model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -197,7 +197,7 @@ def _model_supports_training_model_uri( region (Optional[str]): Region for which to retrieve the support status for model uri with training. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index eaf740a8dc..21fca2abd2 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -47,7 +47,7 @@ def _retrieve_example_payloads( region (Optional[str]): Region for which to retrieve the example payloads. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 3f558219e7..ca3044068b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -45,7 +45,7 @@ def _retrieve_resource_name_base( region (Optional[str]): Region for which to retrieve the default resource name. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index ab4bd40d2c..c04ae88ca3 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -49,7 +49,7 @@ def _retrieve_script_uri( model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 905178ca6f..2fbc558aed 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -44,7 +44,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, - HubDataType, + HubContentType, ) from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -338,7 +338,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.MODEL: + if data_type == HubContentType.MODEL: info = utils.get_info_from_hub_resource_arn( id_info ) @@ -355,7 +355,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.HUB: + if data_type == HubContentType.HUB: info = utils.get_info_from_hub_resource_arn( id_info ) @@ -364,7 +364,7 @@ def _retrieval_function( return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( f"Bad value for key '{key}': must be in", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}" + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: @@ -478,7 +478,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """ details, _ = self._content_cache.get( - JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn) ) return details.formatted_content @@ -489,7 +489,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) return details.formatted_content def clear(self) -> None: diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py new file mode 100644 index 0000000000..3f941e21e3 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import +from typing import Optional + +from sagemaker.jumpstart.types import JumpStartDataHolderType + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py new file mode 100644 index 0000000000..ff71f96f26 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utilities related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import +import re +from typing import Optional +from sagemaker.jumpstart import constants + +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo +from sagemaker.jumpstart.types import HubContentType +from sagemaker.session import Session +from sagemaker.utils import aws_partition + +def get_info_from_hub_resource_arn( + arn: str, +) -> HubArnExtractedInfo: + """Extracts descriptive information from a Hub or HubContent Arn.""" + + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + hub_content_type = match.group(5) + hub_content_name = match.group(6) + hub_content_version = match.group(7) + + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + hub_content_version=hub_content_version, + ) + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + ) + + return None + + +def construct_hub_arn_from_name( + hub_name: str, + region: Optional[str] = None, + session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Constructs a Hub arn from the Hub name using default Session values.""" + + account_id = session.account_id() + region = region or session.boto_region_name + partition = aws_partition(region) + + return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + + +def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL}/{model_name}/{version}" + ) + + return arn + + +# TODO: Update to recognize JumpStartHub hub_name +def generate_hub_arn_for_estimator_init_kwargs( + hub_name: str, region: Optional[str] = None, session: Optional[Session] = None +): + """Generates the Hub Arn for JumpStartEstimator from a HubName or Arn. + + Args: + hub_name (str): HubName or HubArn from JumpStartEstimator args + region (str): Region from JumpStartEstimator args + session (Session): Custom SageMaker Session from JumpStartEstimator args + """ + + hub_arn = None + if hub_name: + match = re.match(constants.HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) + return hub_arn diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 2a8273004b..53d12b46a6 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,6 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_estimator_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -35,7 +36,6 @@ from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.session_utils import get_model_id_version_from_training_job from sagemaker.jumpstart.utils import ( - generate_hub_arn_for_estimator, is_valid_model_id, resolve_model_sagemaker_config_field, ) @@ -521,12 +521,16 @@ def _is_valid_model_id_hook(): if not _is_valid_model_id_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_estimator_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, - hub_arn=generate_hub_arn_for_estimator( - hub_name=hub_name, region=region, session=sagemaker_session - ), + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 947250fd66..0b43341339 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -106,15 +106,15 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -class HubDataType(str, Enum): +class HubContentType(str, Enum): """Enum for Hub data storage objects.""" - HUB = "hub" - MODEL = "model" - NOTEBOOK = "notebook" + HUB = "Hub" + MODEL = "Model" + NOTEBOOK = "Notebook" -JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType] +JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -1747,36 +1747,3 @@ def __init__( self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation - -class HubArnExtractedInfo(JumpStartDataHolderType): - """Data class for info extracted from Hub arn.""" - - __slots__ = [ - "partition", - "region", - "account_id", - "hub_name", - "hub_content_type", - "hub_content_name", - "hub_content_version", - ] - - def __init__( - self, - partition: str, - region: str, - account_id: str, - hub_name: str, - hub_content_type: Optional[str] = None, - hub_content_name: Optional[str] = None, - hub_content_version: Optional[str] = None, - ) -> None: - """Instantiates HubArnExtractedInfo object.""" - - self.partition = partition - self.region = region - self.account_id = account_id - self.hub_name = hub_name - self.hub_content_type = hub_content_type - self.hub_content_name = hub_content_name - self.hub_content_version = hub_content_version diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 3cea9ae7d9..d954b675db 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -36,7 +36,6 @@ get_old_model_version_msg, ) from sagemaker.jumpstart.types import ( - HubArnExtractedInfo, JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, @@ -560,7 +559,7 @@ def verify_model_region_and_return_specs( region (Optional[str]): region of the JumpStart model to verify and obtains specs. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -832,91 +831,3 @@ def get_jumpstart_model_id_version_from_resource_arn( return model_id, model_version - -def get_info_from_hub_resource_arn( - arn: str, -) -> HubArnExtractedInfo: - """Extracts descriptive information from a Hub or HubContent Arn.""" - - match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) - if match: - partition = match.group(1) - hub_region = match.group(2) - account_id = match.group(3) - hub_name = match.group(4) - hub_content_type = match.group(5) - hub_content_name = match.group(6) - hub_content_version = match.group(7) - - return HubArnExtractedInfo( - partition=partition, - region=hub_region, - account_id=account_id, - hub_name=hub_name, - hub_content_type=hub_content_type, - hub_content_name=hub_content_name, - hub_content_version=hub_content_version, - ) - - match = re.match(constants.HUB_ARN_REGEX, arn) - if match: - partition = match.group(1) - hub_region = match.group(2) - account_id = match.group(3) - hub_name = match.group(4) - return HubArnExtractedInfo( - partition=partition, - region=hub_region, - account_id=account_id, - hub_name=hub_name, - ) - - return None - - -def construct_hub_arn_from_name( - hub_name: str, - region: Optional[str] = None, - session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Constructs a Hub arn from the Hub name using default Session values.""" - - account_id = session.account_id() - region = region or session.boto_region_name - partition = aws_partition(region) - - return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" - - -def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: - """Constructs a HubContent model arn from the Hub name, model name, and model version.""" - - info = get_info_from_hub_resource_arn(hub_arn) - arn = ( - f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" - f"{info.hub_name}/Model/{model_name}/{version}" - ) - - return arn - - -# TODO: Update to recognize JumpStartHub hub_name -def generate_hub_arn_for_estimator( - hub_name: Optional[str] = None, region: Optional[str] = None, session: Optional[Session] = None -): - """Generates the Hub Arn for JumpStartEstimator from a HubName or Arn. - - Args: - hub_name (str): HubName or HubArn from JumpStartEstimator args - region (str): Region from JumpStartEstimator args - session (Session): Custom SageMaker Session from JumpStartEstimator args - """ - - hub_arn = None - if hub_name: - match = re.match(constants.HUB_ARN_REGEX, hub_name) - if match: - hub_arn = hub_name - else: - hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) - return hub_arn diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c06a3d1bbd..55cbdd90eb 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -181,7 +181,7 @@ def validate_hyperparameters( model_version (str): Version of the model for which to validate hyperparameters. hyperparameters (dict): Hyperparameters to validate. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided to this function will be validated, the missing hyperparameters will be ignored. diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index fff72a7d65..a31d5d930d 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -45,7 +45,7 @@ def retrieve_default( model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. tolerate_vulnerable_model (bool): True if vulnerable versions of model diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 49895bbd49..a2177c0ec5 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -45,7 +45,7 @@ def retrieve( model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 44327d6056..9341a4198f 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -44,7 +44,7 @@ def retrieve( model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. hub_arn (str): The arn of the SageMaker Hub for which to retrieve - model details from (default: None). + model details from. (default: None). script_scope (str): The script type. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py new file mode 100644 index 0000000000..40d3bbdab0 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -0,0 +1,142 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import Mock +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + +from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_estimator_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 96c87cdf4d..c8fc541816 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -287,7 +287,7 @@ def test_prepacked( @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs") @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs") - @mock.patch("sagemaker.jumpstart.utils.construct_hub_arn_from_name") + @mock.patch("sagemaker.jumpstart.curated_hub.utils.construct_hub_arn_from_name") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index df2dd445f7..c42c15ecf5 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -38,7 +38,6 @@ VulnerableJumpStartModelError, ) from sagemaker.jumpstart.types import ( - HubArnExtractedInfo, JumpStartModelHeader, JumpStartVersionedModelId, ) @@ -1219,130 +1218,6 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_get_info_from_hub_resource_arn(): - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" - ) - assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( - partition="aws", - region="us-west-2", - account_id="000000000000", - hub_name="MockHub", - hub_content_type="Model", - hub_content_name="my-mock-model", - hub_content_version="1.0.2", - ) - - notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" - assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( - partition="aws", - region="us-west-2", - account_id="000000000000", - hub_name="MockHub", - hub_content_type="Notebook", - hub_content_name="my-mock-notebook", - hub_content_version="1.0.2", - ) - - hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( - partition="aws", - region="us-west-2", - account_id="000000000000", - hub_name="MockHub", - ) - - invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert None is utils.get_info_from_hub_resource_arn(invalid_arn) - - invalid_arn = "nonsense-string" - assert None is utils.get_info_from_hub_resource_arn(invalid_arn) - - invalid_arn = "" - assert None is utils.get_info_from_hub_resource_arn(invalid_arn) - - -def test_construct_hub_arn_from_name(): - mock_sagemaker_session = Mock() - mock_sagemaker_session.account_id.return_value = "123456789123" - mock_sagemaker_session.boto_region_name = "us-west-2" - hub_name = "my-cool-hub" - - assert ( - utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) - == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" - ) - - assert ( - utils.construct_hub_arn_from_name( - hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session - ) - == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" - ) - - -def test_construct_hub_model_arn_from_inputs(): - model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" - hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" - - assert ( - utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) - == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" - ) - - version = "*" - assert ( - utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) - == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" - ) - - -def test_generate_hub_arn_for_estimator(): - hub_name = "my-hub-name" - hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" - # Mock default session with default values - mock_default_session = Mock() - mock_default_session.account_id.return_value = "123456789123" - mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME - # Mock custom session with custom values - mock_custom_session = Mock() - mock_custom_session.account_id.return_value = "000000000000" - mock_custom_session.boto_region_name = "us-east-2" - - assert ( - utils.generate_hub_arn_for_estimator(hub_name, session=mock_default_session) - == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" - ) - - assert ( - utils.generate_hub_arn_for_estimator(hub_name, "us-east-1", session=mock_default_session) - == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" - ) - - assert ( - utils.generate_hub_arn_for_estimator(hub_name, "eu-west-1", mock_custom_session) - == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" - ) - - assert ( - utils.generate_hub_arn_for_estimator(hub_name, None, mock_custom_session) - == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" - ) - - assert utils.generate_hub_arn_for_estimator(hub_arn, session=mock_default_session) == hub_arn - - assert ( - utils.generate_hub_arn_for_estimator(hub_arn, "us-east-1", session=mock_default_session) - == hub_arn - ) - - assert ( - utils.generate_hub_arn_for_estimator(hub_arn, "us-east-1", mock_custom_session) == hub_arn - ) - - assert utils.generate_hub_arn_for_estimator(hub_arn, None, mock_custom_session) == hub_arn - - class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index af3d075943..81526485f9 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -203,14 +203,14 @@ def patched_retrieval_function( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if datatype == HubDataType.MODEL: + if datatype == HubContentType.MODEL: _, _, _, model_name, model_version = id_info.split("/") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) ) # TODO: Implement - if datatype == HubDataType.HUB: + if datatype == HubContentType.HUB: return None raise ValueError(f"Bad value for filetype: {datatype}") From a61dfb4a4b5261885742571d60a8b9aa68478591 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Fri, 23 Feb 2024 20:36:01 +0000 Subject: [PATCH 31/31] fix linter --- src/sagemaker/jumpstart/cache.py | 5 ++-- src/sagemaker/jumpstart/curated_hub/types.py | 1 + src/sagemaker/jumpstart/curated_hub/utils.py | 1 + src/sagemaker/jumpstart/types.py | 1 - src/sagemaker/jumpstart/utils.py | 4 +--- .../jumpstart/curated_hub/test_utils.py | 23 ++++++++++++++----- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 2fbc558aed..9dc505a2ff 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -30,6 +30,7 @@ MODEL_ID_LIST_WEB_URL, ) from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub +from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -339,7 +340,7 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubContentType.MODEL: - info = utils.get_info_from_hub_resource_arn( + info = get_info_from_hub_resource_arn( id_info ) hub = CuratedHub(hub_name=info.hub_name, region=info.region) @@ -356,7 +357,7 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubContentType.HUB: - info = utils.get_info_from_hub_resource_arn( + info = get_info_from_hub_resource_arn( id_info ) hub = CuratedHub(hub_name=info.hub_name, region=info.region) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index 3f941e21e3..d400137905 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -16,6 +16,7 @@ from sagemaker.jumpstart.types import JumpStartDataHolderType + class HubArnExtractedInfo(JumpStartDataHolderType): """Data class for info extracted from Hub arn.""" diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index ff71f96f26..5c7a91382b 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -21,6 +21,7 @@ from sagemaker.session import Session from sagemaker.utils import aws_partition + def get_info_from_hub_resource_arn( arn: str, ) -> HubArnExtractedInfo: diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0b43341339..22cc70fcab 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1746,4 +1746,3 @@ def __init__( self.nearest_model_name = nearest_model_name self.data_input_configuration = data_input_configuration self.skip_model_validation = skip_model_validation - diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index d954b675db..1e2bb11d45 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import logging import os -import re from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import boto3 @@ -42,7 +41,7 @@ ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import aws_partition, resolve_value_from_config, TagsDict +from sagemaker.utils import resolve_value_from_config, TagsDict from sagemaker.workflow import is_pipeline_variable @@ -830,4 +829,3 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version - diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 40d3bbdab0..2f0841b4ea 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -17,6 +17,7 @@ from sagemaker.jumpstart.curated_hub import utils from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo + def test_get_info_from_hub_resource_arn(): model_arn = ( "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" @@ -113,7 +114,9 @@ def test_generate_hub_arn_for_estimator_init_kwargs(): ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + utils.generate_hub_arn_for_estimator_init_kwargs( + hub_name, "us-east-1", session=mock_default_session + ) == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" ) @@ -127,16 +130,24 @@ def test_generate_hub_arn_for_estimator_init_kwargs(): == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" ) - assert utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, session=mock_default_session) == hub_arn - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, session=mock_default_session) == hub_arn ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + utils.generate_hub_arn_for_estimator_init_kwargs( + hub_arn, "us-east-1", session=mock_default_session + ) + == hub_arn ) - assert utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", mock_custom_session) + == hub_arn + ) + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) + == hub_arn + )