From 492251137e78f3aaac7cebdbfbe0c8dc2b59a898 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Thu, 22 Feb 2024 21:22:29 +0000 Subject: [PATCH 01/14] initial barebone for hub utils and curated hub --- .../jumpstart/curated_hub/curated_hub.py | 58 +++-- src/sagemaker/jumpstart/curated_hub/utils.py | 203 ++++++++++++++++++ 2 files changed, 248 insertions(+), 13 deletions(-) create mode 100644 src/sagemaker/jumpstart/curated_hub/utils.py diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 273deb097b..ea57fed298 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -14,36 +14,68 @@ from __future__ import absolute_import from typing import Optional, Dict, Any - +import boto3 from sagemaker.session import Session +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, +) + +from sagemaker.jumpstart.types import HubDataType +import sagemaker.jumpstart.curated_hub.utils as hubutils class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session] = None): - self.hub_name = hub_name + def __init__( + self, + name: str, + region: str = JUMPSTART_DEFAULT_REGION_NAME, + session: Optional[Session] = None, + ): + self.name = name + if session.boto_region_name != region: + # TODO: Handle error + pass self.region = region - self.session = session - self._sm_session = session or Session() + self._session = session or Session(boto3.Session(region_name=region)) + + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + bucket_name: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" + + return hubutils.create_hub( + hub_name=self.name, + hub_description=description, + hub_display_name=display_name, + hub_search_keywords=search_keywords, + hub_bucket_name=bucket_name, + tags=tags, + sagemaker_session=self._session, + ) def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" - hub_content = self._sm_session.describe_hub_content( - model_name, "Model", self.hub_name, model_version + hub_content = hubutils.describe_hub_content( + hub_name=self.name, + content_name=model_name, + content_type=HubDataType.MODEL, + content_version=model_version, + sagemaker_session=self._session, ) - # TODO: Parse HubContent - # TODO: Parse HubContentDocument - return hub_content def describe(self) -> Dict[str, Any]: """Returns descriptive information about the Hub""" - hub_info = self._sm_session.describe_hub(hub_name=self.hub_name) - - # TODO: Validations? + hub_info = hubutils.describe_hub(hub_name=self.name, sagemaker_session=self._session) return hub_info diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py new file mode 100644 index 0000000000..205c87329b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -0,0 +1,203 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Mid-level wrappers to HubService API. These utilities handles parsing, custom +errors, and validations on top of the low-level HubService API calls in Session.""" +from __future__ import absolute_import +from typing import Optional, Dict, Any, List + +from sagemaker.jumpstart.types import HubDataType +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) +from sagemaker.session import Session + + +# def _validate_hub_name(hub_name: str) -> bool: +# """Validates hub_name to be either a name or a full ARN""" +# pass + + +def _generate_default_hub_bucket_name( + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. + + This function will create the s3 bucket if it does not exist. + + Returns: + str: The name of the default bucket. If the name was not explicitly specified through + the Session or sagemaker_config, the bucket will take the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + account_id: str = sagemaker_session.account_id() + + # TODO: Validate and fast fail + + return f"sagemaker-hubs-{region}-{account_id}" + + +def create_hub( + hub_name: str, + hub_description: str, + hub_display_name: str = None, + hub_search_keywords: Optional[List[str]] = None, + hub_bucket_name: Optional[str] = None, + tags: Optional[List[Dict[str, Any]]] = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Creates a SageMaker Hub + + Returns: + (str): Arn of the created hub. + """ + + if hub_bucket_name is None: + hub_bucket_name = _generate_default_hub_bucket_name(sagemaker_session) + s3_storage_config = {"S3OutputPath": hub_bucket_name} + response = sagemaker_session.create_hub( + hub_name, hub_description, hub_display_name, hub_search_keywords, s3_storage_config, tags + ) + + # TODO: Custom error message + + hub_arn = response["HubArn"] + return hub_arn + + +def describe_hub( + hub_name: str, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION +) -> Dict[str, Any]: + """Returns descriptive information about the Hub""" + # TODO: hub_name validation and fast-fail + + response = sagemaker_session.describe_hub(hub_name=hub_name) + + # TODO: Make HubInfo and parse response? + # TODO: Custom error message + + return response + + +def delete_hub(hub_name, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION) -> None: + """Deletes a SageMaker Hub""" + response = sagemaker_session.delete_hub(hub_name=hub_name) + + # TODO: Custom error message + + return response + + +def import_hub_content( + document_schema_version: str, + hub_name: str, + hub_content_name: str, + hub_content_type: str, + hub_content_document: str, + hub_content_display_name: str = None, + hub_content_description: str = None, + hub_content_version: str = None, + hub_content_markdown: str = None, + hub_content_search_keywords: List[str] = None, + tags: List[Dict[str, Any]] = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Dict[str, str]: + """Imports a new HubContent into a SageMaker Hub + + Returns arns for the Hub and the HubContent where import was successful. + """ + + response = sagemaker_session.import_hub_content( + document_schema_version, + hub_name, + hub_content_name, + hub_content_type, + hub_content_document, + hub_content_display_name, + hub_content_description, + hub_content_version, + hub_content_markdown, + hub_content_search_keywords, + tags, + ) + return response + + +def list_hub_contents( + hub_name: str, + hub_content_type: HubDataType.MODEL or HubDataType.NOTEBOOK, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Dict[str, Any]: + """List contents of a hub.""" + + response = sagemaker_session.list_hub_contents( + hub_name, + hub_content_type, + creation_time_after, + creation_time_before, + max_results, + max_schema_version, + name_contains, + next_token, + sort_by, + sort_order, + ) + return response + + +def describe_hub_content( + hub_name: str, + content_name: str, + content_type: HubDataType, + content_version: str = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> Dict[str, Any]: + """Returns descriptive information about the content of a hub.""" + # TODO: hub_name validation and fast-fail + + hub_content: Dict[str, Any] = sagemaker_session.describe_hub_content( + hub_content_name=content_name, + hub_content_type=content_type, + hub_name=hub_name, + hub_content_version=content_version, + ) + + # TODO: Parse HubContent + # TODO: Parse HubContentDocument + + return hub_content + + +def delete_hub_content( + hub_content_name: str, + hub_content_version: str, + hub_content_type: str, + hub_name: str, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> None: + """Deletes a given HubContent in a SageMaker Hub""" + # TODO: Validate hub name + + response = sagemaker_session.delete_hub_content( + hub_content_name, hub_content_version, hub_content_type, hub_name + ) + return response From f4da2ad156fea1113dd884bc23bfb176e953b260 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Fri, 23 Feb 2024 01:00:03 +0000 Subject: [PATCH 02/14] refactor, add types for hub/hubcontent descriptions, add helpers. --- src/sagemaker/jumpstart/cache.py | 34 +-- .../jumpstart/curated_hub/curated_hub.py | 76 +++++-- src/sagemaker/jumpstart/curated_hub/utils.py | 200 ++--------------- src/sagemaker/jumpstart/session_utils.py | 42 ++++ src/sagemaker/jumpstart/types.py | 210 +++++++++++++++++- tests/unit/sagemaker/jumpstart/utils.py | 6 +- 6 files changed, 344 insertions(+), 224 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d733f39864..50ec12e894 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -29,7 +29,6 @@ JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, ) -from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -44,9 +43,12 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, - HubDataType, + HubContentType, + HubDescription, + HubContentDescription, ) from sagemaker.jumpstart import utils +from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.utilities.cache import LRUCache @@ -338,29 +340,33 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.MODEL: + if data_type == HubContentType.MODEL: hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( id_info ) - hub = CuratedHub(hub_name=hub_name, region=region) - hub_content = hub.describe_model(model_name=model_name, model_version=model_version) + hub_model_description: HubContentDescription = hub_utils.describe_model( + hub_name=hub_name, + region=region, + model_name=model_name, + model_version=model_version + ) + model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True) utils.emit_logs_based_on_model_specs( - hub_content.content_document, + model_specs, self.get_region(), self._s3_client ) - model_specs = JumpStartModelSpecs(hub_content.content_document, is_hub_content=True) + # TODO: Parse HubContentDescription return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.HUB: + if data_type == HubContentType.HUB: hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, region=region) - hub_info = hub.describe() - return JumpStartCachedContentValue(formatted_content=hub_info) + hub_description: HubDescription = hub_utils.describe(hub_name=hub_name, region=region) + return JumpStartCachedContentValue(formatted_content=hub_description) raise ValueError( f"Bad value for key '{key}': must be in", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}" + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: @@ -474,7 +480,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """ details, _ = self._content_cache.get( - JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn) ) return details.formatted_content @@ -485,7 +491,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) return details.formatted_content def clear(self) -> None: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index ea57fed298..5f717fd7ee 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -13,15 +13,15 @@ """This module provides the JumpStart Curated Hub class.""" from __future__ import absolute_import -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional import boto3 from sagemaker.session import Session from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, ) -from sagemaker.jumpstart.types import HubDataType -import sagemaker.jumpstart.curated_hub.utils as hubutils +from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription +import sagemaker.jumpstart.session_utils as session_utils class CuratedHub: @@ -29,16 +29,16 @@ class CuratedHub: def __init__( self, - name: str, + hub_name: str, region: str = JUMPSTART_DEFAULT_REGION_NAME, - session: Optional[Session] = None, + sagemaker_session: Optional[Session] = None, ): - self.name = name - if session.boto_region_name != region: + self.hub_name = hub_name + if sagemaker_session.boto_region_name != region: # TODO: Handle error pass self.region = region - self._session = session or Session(boto3.Session(region_name=region)) + self._sagemaker_session = sagemaker_session or Session(boto3.Session(region_name=region)) def create( self, @@ -50,32 +50,60 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - return hubutils.create_hub( - hub_name=self.name, + bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist( + bucket_name, self._sagemaker_session + ) + + return self._sagemaker_session.create_hub( + hub_name=self.hub_name, hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, hub_bucket_name=bucket_name, tags=tags, - sagemaker_session=self._session, ) - def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: - """Returns descriptive information about the Hub Model""" + def describe(self) -> HubDescription: + """Returns descriptive information about the Hub""" + + hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + + return HubDescription(hub_description) + + def list_models(self, **kwargs) -> Dict[str, Any]: + """Lists the models in this Curated Hub - hub_content = hubutils.describe_hub_content( - hub_name=self.name, - content_name=model_name, - content_type=HubDataType.MODEL, - content_version=model_version, - sagemaker_session=self._session, + **kwargs: Passed to invocation of ``Session:list_hub_contents``. + """ + # TODO: Validate kwargs and fast-fail? + + hub_content_summaries = self._sagemaker_session.list_hub_contents( + hub_name=self.hub_name, hub_content_type=HubContentType.MODEL, **kwargs ) + # TODO: Handle pagination + return hub_content_summaries - return hub_content + def describe_model(self, model_name: str, model_version: str = "*") -> HubContentDescription: + """Returns descriptive information about the Hub Model""" - def describe(self) -> Dict[str, Any]: - """Returns descriptive information about the Hub""" + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL, + ) + + return HubContentDescription(hub_content_description) - hub_info = hubutils.describe_hub(hub_name=self.name, sagemaker_session=self._session) + def delete_model(self, model_name: str, model_version: str = "*") -> None: + """Deletes a model from this CuratedHub.""" + return self._sagemaker_session.delete_hub_content( + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL, + hub_name=self.hub_name, + ) - return hub_info + def delete(self) -> None: + """Deletes this Curated Hub""" + return self._sagemaker_session.delete_hub(self.hub_name) diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 205c87329b..09307f5fb6 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -10,194 +10,32 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Mid-level wrappers to HubService API. These utilities handles parsing, custom -errors, and validations on top of the low-level HubService API calls in Session.""" -from __future__ import absolute_import -from typing import Optional, Dict, Any, List - -from sagemaker.jumpstart.types import HubDataType -from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) +"""Utilities to interact with Hub.""" +from typing import Any, Dict +import boto3 from sagemaker.session import Session +from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription -# def _validate_hub_name(hub_name: str) -> bool: -# """Validates hub_name to be either a name or a full ARN""" -# pass - - -def _generate_default_hub_bucket_name( - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Return the name of the default bucket to use in relevant Amazon SageMaker interactions. - - This function will create the s3 bucket if it does not exist. - - Returns: - str: The name of the default bucket. If the name was not explicitly specified through - the Session or sagemaker_config, the bucket will take the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - account_id: str = sagemaker_session.account_id() - - # TODO: Validate and fast fail - - return f"sagemaker-hubs-{region}-{account_id}" - - -def create_hub( - hub_name: str, - hub_description: str, - hub_display_name: str = None, - hub_search_keywords: Optional[List[str]] = None, - hub_bucket_name: Optional[str] = None, - tags: Optional[List[Dict[str, Any]]] = None, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Creates a SageMaker Hub - - Returns: - (str): Arn of the created hub. - """ - - if hub_bucket_name is None: - hub_bucket_name = _generate_default_hub_bucket_name(sagemaker_session) - s3_storage_config = {"S3OutputPath": hub_bucket_name} - response = sagemaker_session.create_hub( - hub_name, hub_description, hub_display_name, hub_search_keywords, s3_storage_config, tags - ) - - # TODO: Custom error message - - hub_arn = response["HubArn"] - return hub_arn - - -def describe_hub( - hub_name: str, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION -) -> Dict[str, Any]: - """Returns descriptive information about the Hub""" - # TODO: hub_name validation and fast-fail - - response = sagemaker_session.describe_hub(hub_name=hub_name) - - # TODO: Make HubInfo and parse response? - # TODO: Custom error message +def describe(hub_name: str, region: str) -> HubDescription: + """Returns descriptive information about the Hub.""" - return response + sagemaker_session = Session(boto3.Session(region_name=region)) + hub_description = sagemaker_session.describe_hub(hub_name=hub_name) + return HubDescription(hub_description) -def delete_hub(hub_name, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION) -> None: - """Deletes a SageMaker Hub""" - response = sagemaker_session.delete_hub(hub_name=hub_name) +def describe_model( + hub_name: str, region: str, model_name: str, model_version: str = "*" +) -> HubContentDescription: + """Returns descriptive information about the Hub model.""" - # TODO: Custom error message - - return response - - -def import_hub_content( - document_schema_version: str, - hub_name: str, - hub_content_name: str, - hub_content_type: str, - hub_content_document: str, - hub_content_display_name: str = None, - hub_content_description: str = None, - hub_content_version: str = None, - hub_content_markdown: str = None, - hub_content_search_keywords: List[str] = None, - tags: List[Dict[str, Any]] = None, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Dict[str, str]: - """Imports a new HubContent into a SageMaker Hub - - Returns arns for the Hub and the HubContent where import was successful. - """ - - response = sagemaker_session.import_hub_content( - document_schema_version, - hub_name, - hub_content_name, - hub_content_type, - hub_content_document, - hub_content_display_name, - hub_content_description, - hub_content_version, - hub_content_markdown, - hub_content_search_keywords, - tags, - ) - return response - - -def list_hub_contents( - hub_name: str, - hub_content_type: HubDataType.MODEL or HubDataType.NOTEBOOK, - creation_time_after: str = None, - creation_time_before: str = None, - max_results: int = None, - max_schema_version: str = None, - name_contains: str = None, - next_token: str = None, - sort_by: str = None, - sort_order: str = None, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Dict[str, Any]: - """List contents of a hub.""" - - response = sagemaker_session.list_hub_contents( - hub_name, - hub_content_type, - creation_time_after, - creation_time_before, - max_results, - max_schema_version, - name_contains, - next_token, - sort_by, - sort_order, - ) - return response - - -def describe_hub_content( - hub_name: str, - content_name: str, - content_type: HubDataType, - content_version: str = None, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> Dict[str, Any]: - """Returns descriptive information about the content of a hub.""" - # TODO: hub_name validation and fast-fail - - hub_content: Dict[str, Any] = sagemaker_session.describe_hub_content( - hub_content_name=content_name, - hub_content_type=content_type, + sagemaker_session = Session(boto3.Session(region_name=region)) + hub_content_description: Dict[str, Any] = sagemaker_session.describe_hub_content( hub_name=hub_name, - hub_content_version=content_version, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL, ) - # TODO: Parse HubContent - # TODO: Parse HubContentDocument - - return hub_content - - -def delete_hub_content( - hub_content_name: str, - hub_content_version: str, - hub_content_type: str, - hub_name: str, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> None: - """Deletes a given HubContent in a SageMaker Hub""" - # TODO: Validate hub name - - response = sagemaker_session.delete_hub_content( - hub_content_name, hub_content_version, hub_content_type, hub_name - ) - return response + return HubContentDescription(hub_content_description) diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..d5f90cff07 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -208,3 +208,45 @@ def get_model_id_version_from_training_job( ) return model_id, model_version + + +def generate_default_hub_bucket_name( + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. + + Returns: + str: The name of the default bucket. If the name was not explicitly specified through + the Session or sagemaker_config, the bucket will take the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + account_id: str = sagemaker_session.account_id() + + # TODO: Validate and fast fail + + return f"sagemaker-hubs-{region}-{account_id}" + + +def create_hub_bucket_if_it_does_not_exist( + bucket_name: Optional[str] = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Creates the default SageMaker Hub bucket if it does not exist. + + Returns: + str: The name of the default bucket. Takes the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + if bucket_name is None: + bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) + + sagemaker_session._create_s3_bucket_if_it_does_not_exist( + bucket_name=bucket_name, + region=region, + ) + + return bucket_name diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5a4e91d092..03cadd8758 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -106,15 +106,21 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -class HubDataType(str, Enum): +class HubContentType(str, Enum): """Enum for Hub data storage objects.""" HUB = "hub" MODEL = "model" NOTEBOOK = "notebook" + @classmethod + @property + def content_only(cls): + """Subset of HubContentType defined by SageMaker API Documentation""" + return cls.MODEL, cls.NOTEBOOK -JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType] + +JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -1001,6 +1007,206 @@ def __init__( self.id_info = id_info +class HubContentDependency(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") + self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDependency object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class HubContentDescription(JumpStartDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_name", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDescription object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.creation_time: int = int(json_obj["creation_time"]) + self.document_schema_version: str = json_obj["document_schema_version"] + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_content_arn: str = json_obj["hub_content_arn"] + self.hub_content_dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] + ] + self.hub_content_description: str = json_obj["hub_content_description"] + self.hub_content_display_name: str = json_obj["hub_content_display_name"] + self.hub_content_document: str = json_obj["hub_content_document"] + self.hub_content_markdown: str = json_obj["hub_content_markdown"] + self.hub_content_name: str = json_obj["hub_content_name"] + self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] + self.hub_content_status: str = json_obj["hub_content_status"] + self.hub_content_type: HubContentType.content_only = json_obj["hub_content_type"] + self.hub_content_version: str = json_obj["hub_content_version"] + self.hub_name: str = json_obj["hub_name"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDescription object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + +class HubS3StorageConfig(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks.""" + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubS3StorageConfig object.""" + return {"s3_output_path": self.s3_output_path} + + +class HubDescription(JumpStartDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubDescription object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: int = int(json_obj["creation_time"]) + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_description: str = json_obj["hub_description"] + self.hub_display_name: str = json_obj["hub_display_name"] + self.hub_name: str = json_obj["hub_name"] + self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] + self.hub_status: str = json_obj["hub_status"] + self.last_modified_time: int = int(json_obj["last_modified_time"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( + json_obj["s3_storage_config"] + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDescription object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 76682c0f9e..3806f42854 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -197,14 +197,14 @@ def patched_retrieval_function( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if datatype == HubDataType.MODEL: + if datatype == HubContentType.MODEL: _, _, _, model_name, model_version = id_info.split("/") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) ) # TODO: Implement - if datatype == HubDataType.HUB: + if datatype == HubContentType.HUB: return None raise ValueError(f"Bad value for filetype: {datatype}") From 0435ae978dc07c042e00a940f096e58bc9b548a4 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Mon, 26 Feb 2024 23:35:31 +0000 Subject: [PATCH 03/14] Refactor hub related stuff under curated_hub --- src/sagemaker/jumpstart/cache.py | 20 +- .../jumpstart/curated_hub/curated_hub.py | 20 +- src/sagemaker/jumpstart/curated_hub/types.py | 217 +++++++++++++++++- src/sagemaker/jumpstart/curated_hub/utils.py | 18 +- src/sagemaker/jumpstart/types.py | 217 +----------------- 5 files changed, 252 insertions(+), 240 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index c8d0f07590..83745b941e 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -29,8 +29,6 @@ JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, ) -from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub -from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -45,9 +43,11 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, +) +from sagemaker.jumpstart.curated_hub.types import ( HubContentType, - HubDescription, - HubContentDescription, + DescribeHubResponse, + DescribeHubContentsResponse, ) from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub import utils as hub_utils @@ -343,10 +343,10 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubContentType.MODEL: - hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( + hub_name, region, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( id_info ) - hub_model_description: HubContentDescription = hub_utils.describe_model( + hub_model_description: DescribeHubContentsResponse = hub_utils.describe_model( hub_name=hub_name, region=region, model_name=model_name, @@ -359,13 +359,15 @@ def _retrieval_function( self.get_region(), self._s3_client ) - # TODO: Parse HubContentDescription return JumpStartCachedContentValue( formatted_content=model_specs ) if data_type == HubContentType.HUB: - hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub_description: HubDescription = hub_utils.describe(hub_name=hub_name, region=region) + hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) + hub_description: DescribeHubResponse = hub_utils.describe( + hub_name=hub_name, + region=region + ) return JumpStartCachedContentValue(formatted_content=hub_description) raise ValueError( f"Bad value for key '{key}': must be in", diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index f142922a0c..da77806161 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -15,15 +15,15 @@ from typing import Any, Dict, Optional -import boto3 from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session -from sagemaker.jumpstart.constants import ( - JUMPSTART_DEFAULT_REGION_NAME, -) -from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription +from sagemaker.jumpstart.curated_hub.types import ( + DescribeHubResponse, + HubContentType, + DescribeHubContentsResponse, +) import sagemaker.jumpstart.session_utils as session_utils @@ -66,12 +66,12 @@ def create( tags=tags, ) - def describe(self) -> HubDescription: + def describe(self) -> DescribeHubResponse: """Returns descriptive information about the Hub""" hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name) - return HubDescription(hub_description) + return DescribeHubResponse(hub_description) def list_models(self, **kwargs) -> Dict[str, Any]: """Lists the models in this Curated Hub @@ -86,7 +86,9 @@ def list_models(self, **kwargs) -> Dict[str, Any]: # TODO: Handle pagination return hub_content_summaries - def describe_model(self, model_name: str, model_version: str = "*") -> HubContentDescription: + def describe_model( + self, model_name: str, model_version: str = "*" + ) -> DescribeHubContentsResponse: """Returns descriptive information about the Hub Model""" hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( @@ -96,7 +98,7 @@ def describe_model(self, model_name: str, model_version: str = "*") -> HubConten hub_content_type=HubContentType.MODEL, ) - return HubContentDescription(hub_content_description) + return DescribeHubContentsResponse(hub_content_description) def delete_model(self, model_name: str, model_version: str = "*") -> None: """Deletes a model from this CuratedHub.""" diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index d400137905..f68037e1ff 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -12,11 +12,26 @@ # language governing permissions and limitations under the License. """This module stores types related to SageMaker JumpStart CuratedHub.""" from __future__ import absolute_import -from typing import Optional +from enum import Enum +from typing import Any, Dict, List, Optional from sagemaker.jumpstart.types import JumpStartDataHolderType +class HubContentType(str, Enum): + """Enum for Hub data storage objects.""" + + HUB = "Hub" + MODEL = "Model" + NOTEBOOK = "Notebook" + + @classmethod + @property + def content_only(cls): + """Subset of HubContentType defined by SageMaker API Documentation""" + return cls.MODEL, cls.NOTEBOOK + + class HubArnExtractedInfo(JumpStartDataHolderType): """Data class for info extracted from Hub arn.""" @@ -49,3 +64,203 @@ def __init__( self.hub_content_type = hub_content_type self.hub_content_name = hub_content_name self.hub_content_version = hub_content_version + + +class HubContentDependency(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") + self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDependency object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class DescribeHubContentsResponse(JumpStartDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_name", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.creation_time: int = int(json_obj["creation_time"]) + self.document_schema_version: str = json_obj["document_schema_version"] + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_content_arn: str = json_obj["hub_content_arn"] + self.hub_content_dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] + ] + self.hub_content_description: str = json_obj["hub_content_description"] + self.hub_content_display_name: str = json_obj["hub_content_display_name"] + self.hub_content_document: str = json_obj["hub_content_document"] + self.hub_content_markdown: str = json_obj["hub_content_markdown"] + self.hub_content_name: str = json_obj["hub_content_name"] + self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] + self.hub_content_status: str = json_obj["hub_content_status"] + self.hub_content_type: HubContentType.content_only = json_obj["hub_content_type"] + self.hub_content_version: str = json_obj["hub_content_version"] + self.hub_name: str = json_obj["hub_name"] + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of DescribeHubContentsResponse object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + + +class HubS3StorageConfig(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks.""" + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubS3StorageConfig object.""" + return {"s3_output_path": self.s3_output_path} + + +class DescribeHubResponse(JumpStartDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: int = int(json_obj["creation_time"]) + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_description: str = json_obj["hub_description"] + self.hub_display_name: str = json_obj["hub_display_name"] + self.hub_name: str = json_obj["hub_name"] + self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] + self.hub_status: str = json_obj["hub_status"] + self.last_modified_time: int = int(json_obj["last_modified_time"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( + json_obj["s3_storage_config"] + ) + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of DescribeHubContentsResponse object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 5039427757..00df687a1a 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -18,20 +18,25 @@ from sagemaker.session import Session from sagemaker.jumpstart import constants from sagemaker.utils import aws_partition -from sagemaker.jumpstart.types import HubDescription, HubContentType, HubContentDescription -from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo +from sagemaker.jumpstart.curated_hub.types import ( + DescribeHubResponse, + HubContentType, + DescribeHubContentsResponse, + HubArnExtractedInfo, +) -def describe(hub_name: str, region: str) -> HubDescription: + +def describe(hub_name: str, region: str) -> DescribeHubResponse: """Returns descriptive information about the Hub.""" sagemaker_session = Session(boto3.Session(region_name=region)) hub_description = sagemaker_session.describe_hub(hub_name=hub_name) - return HubDescription(hub_description) + return DescribeHubResponse(hub_description) def describe_model( hub_name: str, region: str, model_name: str, model_version: str = "*" -) -> HubContentDescription: +) -> DescribeHubContentsResponse: """Returns descriptive information about the Hub model.""" sagemaker_session = Session(boto3.Session(region_name=region)) @@ -42,7 +47,8 @@ def describe_model( hub_content_type=HubContentType.MODEL, ) - return HubContentDescription(hub_content_description) + return DescribeHubContentsResponse(hub_content_description) + def get_info_from_hub_resource_arn( arn: str, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 6e5973600f..8798792682 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -24,6 +24,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType +from sagemaker.jumpstart.curated_hub import types as hub_types class JumpStartDataHolderType: @@ -106,21 +107,7 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -class HubContentType(str, Enum): - """Enum for Hub data storage objects.""" - - HUB = "Hub" - MODEL = "Model" - NOTEBOOK = "Notebook" - - @classmethod - @property - def content_only(cls): - """Subset of HubContentType defined by SageMaker API Documentation""" - return cls.MODEL, cls.NOTEBOOK - - -JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] +JumpStartContentDataType = Union[JumpStartS3FileType, hub_types.HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -1007,206 +994,6 @@ def __init__( self.id_info = id_info -class HubContentDependency(JumpStartDataHolderType): - """Data class for any dependencies related to hub content. - - Content can be scripts, model artifacts, datasets, or notebooks. - """ - - __slots__ = ["dependency_copy_path", "dependency_origin_path"] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubContentDependency object - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - - self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") - self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubContentDependency object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} - return json_obj - - -class HubContentDescription(JumpStartDataHolderType): - """Data class for the Hub Content from session.describe_hub_contents()""" - - __slots__ = [ - "creation_time", - "document_schema_version", - "failure_reason", - "hub_arn", - "hub_content_arn", - "hub_content_dependencies", - "hub_content_description", - "hub_content_display_name", - "hub_content_document", - "hub_content_markdown", - "hub_content_name", - "hub_content_search_keywords", - "hub_content_status", - "hub_content_type", - "hub_content_version", - "hub_name", - ] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubContentDescription object. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.creation_time: int = int(json_obj["creation_time"]) - self.document_schema_version: str = json_obj["document_schema_version"] - self.failure_reason: str = json_obj["failure_reason"] - self.hub_arn: str = json_obj["hub_arn"] - self.hub_content_arn: str = json_obj["hub_content_arn"] - self.hub_content_dependencies: List[HubContentDependency] = [ - HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] - ] - self.hub_content_description: str = json_obj["hub_content_description"] - self.hub_content_display_name: str = json_obj["hub_content_display_name"] - self.hub_content_document: str = json_obj["hub_content_document"] - self.hub_content_markdown: str = json_obj["hub_content_markdown"] - self.hub_content_name: str = json_obj["hub_content_name"] - self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] - self.hub_content_status: str = json_obj["hub_content_status"] - self.hub_content_type: HubContentType.content_only = json_obj["hub_content_type"] - self.hub_content_version: str = json_obj["hub_content_version"] - self.hub_name: str = json_obj["hub_name"] - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubContentDescription object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - - -class HubS3StorageConfig(JumpStartDataHolderType): - """Data class for any dependencies related to hub content. - - Includes scripts, model artifacts, datasets, or notebooks.""" - - __slots__ = ["s3_output_path"] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubS3StorageConfig object - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - - self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubS3StorageConfig object.""" - return {"s3_output_path": self.s3_output_path} - - -class HubDescription(JumpStartDataHolderType): - """Data class for the Hub from session.describe_hub()""" - - __slots__ = [ - "creation_time", - "failure_reason", - "hub_arn", - "hub_description", - "hub_display_name", - "hub_name", - "hub_search_keywords", - "hub_status", - "last_modified_time", - "s3_storage_config", - ] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubDescription object. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub description. - """ - - self.creation_time: int = int(json_obj["creation_time"]) - self.failure_reason: str = json_obj["failure_reason"] - self.hub_arn: str = json_obj["hub_arn"] - self.hub_description: str = json_obj["hub_description"] - self.hub_display_name: str = json_obj["hub_display_name"] - self.hub_name: str = json_obj["hub_name"] - self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] - self.hub_status: str = json_obj["hub_status"] - self.last_modified_time: int = int(json_obj["last_modified_time"]) - self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( - json_obj["s3_storage_config"] - ) - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubContentDescription object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - - class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" From a4c67e8c66e0acd63341a54cc90214a515920094 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Mon, 26 Feb 2024 23:40:31 +0000 Subject: [PATCH 04/14] Use CuratedHub class instead of using a middleware utilities --- src/sagemaker/jumpstart/cache.py | 12 ++++---- src/sagemaker/jumpstart/curated_hub/utils.py | 29 +------------------- 2 files changed, 6 insertions(+), 35 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 83745b941e..a1f0ed193b 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -51,6 +51,7 @@ ) from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub import utils as hub_utils +from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub from sagemaker.utilities.cache import LRUCache @@ -346,9 +347,8 @@ def _retrieval_function( hub_name, region, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( id_info ) - hub_model_description: DescribeHubContentsResponse = hub_utils.describe_model( - hub_name=hub_name, - region=region, + hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region) + hub_model_description: DescribeHubContentsResponse = hub.describe_model( model_name=model_name, model_version=model_version ) @@ -364,10 +364,8 @@ def _retrieval_function( ) if data_type == HubContentType.HUB: hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) - hub_description: DescribeHubResponse = hub_utils.describe( - hub_name=hub_name, - region=region - ) + hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region) + hub_description: DescribeHubResponse = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_description) raise ValueError( f"Bad value for key '{key}': must be in", diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 00df687a1a..5aa4cb314d 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -13,43 +13,16 @@ """This module contains utilities related to SageMaker JumpStart CuratedHub.""" from __future__ import absolute_import import re -from typing import Any, Dict, Optional -import boto3 +from typing import Optional from sagemaker.session import Session from sagemaker.jumpstart import constants from sagemaker.utils import aws_partition from sagemaker.jumpstart.curated_hub.types import ( - DescribeHubResponse, HubContentType, - DescribeHubContentsResponse, HubArnExtractedInfo, ) -def describe(hub_name: str, region: str) -> DescribeHubResponse: - """Returns descriptive information about the Hub.""" - - sagemaker_session = Session(boto3.Session(region_name=region)) - hub_description = sagemaker_session.describe_hub(hub_name=hub_name) - return DescribeHubResponse(hub_description) - - -def describe_model( - hub_name: str, region: str, model_name: str, model_version: str = "*" -) -> DescribeHubContentsResponse: - """Returns descriptive information about the Hub model.""" - - sagemaker_session = Session(boto3.Session(region_name=region)) - hub_content_description: Dict[str, Any] = sagemaker_session.describe_hub_content( - hub_name=hub_name, - hub_content_name=model_name, - hub_content_version=model_version, - hub_content_type=HubContentType.MODEL, - ) - - return DescribeHubContentsResponse(hub_content_description) - - def get_info_from_hub_resource_arn( arn: str, ) -> HubArnExtractedInfo: From 4799ccb583400aa60641bd48202a7ac16c6ac197 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Mon, 26 Feb 2024 23:59:14 +0000 Subject: [PATCH 05/14] Refactor hub related stuff from session utils --- .../jumpstart/curated_hub/curated_hub.py | 8 ++-- src/sagemaker/jumpstart/curated_hub/utils.py | 42 +++++++++++++++++++ src/sagemaker/jumpstart/session_utils.py | 42 ------------------- 3 files changed, 45 insertions(+), 47 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index da77806161..8a9df85d14 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -15,16 +15,14 @@ from typing import Any, Dict, Optional -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION - from sagemaker.session import Session - +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.jumpstart.curated_hub.types import ( DescribeHubResponse, HubContentType, DescribeHubContentsResponse, ) -import sagemaker.jumpstart.session_utils as session_utils class CuratedHub: @@ -53,7 +51,7 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - bucket_name = session_utils.create_hub_bucket_if_it_does_not_exist( + bucket_name = hub_utils.create_hub_bucket_if_it_does_not_exist( bucket_name, self._sagemaker_session ) diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 5aa4cb314d..9be424caec 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -110,3 +110,45 @@ def generate_hub_arn_for_estimator_init_kwargs( else: hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) return hub_arn + + +def generate_default_hub_bucket_name( + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. + + Returns: + str: The name of the default bucket. If the name was not explicitly specified through + the Session or sagemaker_config, the bucket will take the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + account_id: str = sagemaker_session.account_id() + + # TODO: Validate and fast fail + + return f"sagemaker-hubs-{region}-{account_id}" + + +def create_hub_bucket_if_it_does_not_exist( + bucket_name: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Creates the default SageMaker Hub bucket if it does not exist. + + Returns: + str: The name of the default bucket. Takes the form: + ``sagemaker-hubs-{region}-{AWS account ID}``. + """ + + region: str = sagemaker_session.boto_region_name + if bucket_name is None: + bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) + + sagemaker_session._create_s3_bucket_if_it_does_not_exist( + bucket_name=bucket_name, + region=region, + ) + + return bucket_name diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index d5f90cff07..e511a052d1 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -208,45 +208,3 @@ def get_model_id_version_from_training_job( ) return model_id, model_version - - -def generate_default_hub_bucket_name( - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. - - Returns: - str: The name of the default bucket. If the name was not explicitly specified through - the Session or sagemaker_config, the bucket will take the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - account_id: str = sagemaker_session.account_id() - - # TODO: Validate and fast fail - - return f"sagemaker-hubs-{region}-{account_id}" - - -def create_hub_bucket_if_it_does_not_exist( - bucket_name: Optional[str] = None, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, -) -> str: - """Creates the default SageMaker Hub bucket if it does not exist. - - Returns: - str: The name of the default bucket. Takes the form: - ``sagemaker-hubs-{region}-{AWS account ID}``. - """ - - region: str = sagemaker_session.boto_region_name - if bucket_name is None: - bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) - - sagemaker_session._create_s3_bucket_if_it_does_not_exist( - bucket_name=bucket_name, - region=region, - ) - - return bucket_name From 086bf9240699326828867ed514103e38f7a526e3 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Tue, 27 Feb 2024 00:05:03 +0000 Subject: [PATCH 06/14] to_json to inherit from JumpStartDataHolderType --- src/sagemaker/jumpstart/curated_hub/types.py | 41 +------------------- src/sagemaker/jumpstart/types.py | 38 +++++++++--------- 2 files changed, 21 insertions(+), 58 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index f68037e1ff..5aed66d68c 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -153,30 +153,12 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hub_content_version: str = json_obj["hub_content_version"] self.hub_name: str = json_obj["hub_name"] - def to_json(self) -> Dict[str, Any]: - """Returns json representation of DescribeHubContentsResponse object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - class HubS3StorageConfig(JumpStartDataHolderType): """Data class for any dependencies related to hub content. - Includes scripts, model artifacts, datasets, or notebooks.""" + Includes scripts, model artifacts, datasets, or notebooks. + """ __slots__ = ["s3_output_path"] @@ -245,22 +227,3 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( json_obj["s3_storage_config"] ) - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of DescribeHubContentsResponse object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8798792682..33ddfe1b97 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -99,6 +99,25 @@ def __repr__(self) -> str: } return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if hasattr(self, att): + cur_val = getattr(self, att) + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + else: + json_obj[att] = cur_val + return json_obj + class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" @@ -911,25 +930,6 @@ def from_hub_content_doc(self, hub_content_doc: Dict[str, Any]) -> None: """ # TODO: Implement - def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartModelSpecs object.""" - json_obj = {} - for att in self.__slots__: - if hasattr(self, att): - cur_val = getattr(self, att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - else: - json_obj[att] = cur_val - return json_obj - def supports_prepacked_inference(self) -> bool: """Returns True if the model has a prepacked inference artifact.""" return getattr(self, "hosting_prepacked_artifact_key", None) is not None From b600dd1ad25e2fda786e5781a1c3235bdceb036c Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Tue, 27 Feb 2024 23:23:15 +0000 Subject: [PATCH 07/14] Add curated_hub utils unit tests --- src/sagemaker/jumpstart/cache.py | 6 +-- .../jumpstart/curated_hub/curated_hub.py | 8 ++-- src/sagemaker/jumpstart/curated_hub/types.py | 1 + src/sagemaker/jumpstart/curated_hub/utils.py | 2 +- src/sagemaker/jumpstart/session_utils.py | 4 +- src/sagemaker/jumpstart/types.py | 11 +++--- .../jumpstart/curated_hub/test_utils.py | 37 ++++++++++++++++++- tests/unit/sagemaker/jumpstart/utils.py | 3 +- 8 files changed, 53 insertions(+), 19 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a1f0ed193b..032dc11adc 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -21,6 +21,7 @@ import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier +from sagemaker.utilities.cache import LRUCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -36,6 +37,7 @@ JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON, JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) +from sagemaker.jumpstart import utils from sagemaker.jumpstart.types import ( JumpStartCachedContentKey, JumpStartCachedContentValue, @@ -45,14 +47,12 @@ JumpStartVersionedModelId, ) from sagemaker.jumpstart.curated_hub.types import ( - HubContentType, DescribeHubResponse, DescribeHubContentsResponse, + HubContentType, ) -from sagemaker.jumpstart import utils from sagemaker.jumpstart.curated_hub import utils as hub_utils from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub -from sagemaker.utilities.cache import LRUCache class JumpStartModelsCache: diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 8a9df85d14..1611a88f64 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -17,11 +17,11 @@ from typing import Any, Dict, Optional from sagemaker.session import Session from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.curated_hub import utils as hub_utils +from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist from sagemaker.jumpstart.curated_hub.types import ( DescribeHubResponse, - HubContentType, DescribeHubContentsResponse, + HubContentType, ) @@ -51,9 +51,7 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - bucket_name = hub_utils.create_hub_bucket_if_it_does_not_exist( - bucket_name, self._sagemaker_session - ) + bucket_name = create_hub_bucket_if_it_does_not_exist(bucket_name, self._sagemaker_session) return self._sagemaker_session.create_hub( hub_name=self.hub_name, diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py index 5aed66d68c..85b98a25f3 100644 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -15,6 +15,7 @@ from enum import Enum from typing import Any, Dict, List, Optional + from sagemaker.jumpstart.types import JumpStartDataHolderType diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 9be424caec..08a1b7131f 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -15,8 +15,8 @@ import re from typing import Optional from sagemaker.session import Session -from sagemaker.jumpstart import constants from sagemaker.utils import aws_partition +from sagemaker.jumpstart import constants from sagemaker.jumpstart.curated_hub.types import ( HubContentType, HubArnExtractedInfo, diff --git a/src/sagemaker/jumpstart/session_utils.py b/src/sagemaker/jumpstart/session_utils.py index e511a052d1..93e1bc3bd0 100644 --- a/src/sagemaker/jumpstart/session_utils.py +++ b/src/sagemaker/jumpstart/session_utils.py @@ -15,11 +15,11 @@ from __future__ import absolute_import from typing import Optional, Tuple -from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn from sagemaker.session import Session from sagemaker.utils import aws_partition +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.utils import get_jumpstart_model_id_version_from_resource_arn def get_model_id_version_from_endpoint( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 33ddfe1b97..e0762f1519 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -15,16 +15,14 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.session import Session from sagemaker.utils import get_instance_type_family, format_tags, Tags +from sagemaker.enums import EndpointType from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines - -from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements -from sagemaker.enums import EndpointType -from sagemaker.jumpstart.curated_hub import types as hub_types class JumpStartDataHolderType: @@ -119,6 +117,9 @@ def to_json(self) -> Dict[str, Any]: return json_obj +from sagemaker.jumpstart.curated_hub.types import HubContentType # noqa: E402 + + class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" @@ -126,7 +127,7 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -JumpStartContentDataType = Union[JumpStartS3FileType, hub_types.HubContentType] +JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 2f0841b4ea..868c1231a2 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -11,10 +11,11 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from unittest.mock import Mock -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from unittest.mock import Mock +from botocore.exceptions import ClientError from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo @@ -151,3 +152,35 @@ def test_generate_hub_arn_for_estimator_init_kwargs(): utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn ) + + +def test_generate_default_hub_bucket_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-east-1" + + assert ( + utils.generate_default_hub_bucket_name(sagemaker_session=mock_sagemaker_session) + == "sagemaker-hubs-us-east-1-123456789123" + ) + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + error = ClientError( + error_response={"Error": {"Code": "404", "Message": "Not Found"}}, + operation_name="foo", + ) + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 81526485f9..c4eb995e5c 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,13 +22,13 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, ) + from sagemaker.jumpstart.utils import get_formatted_manifest from tests.unit.sagemaker.jumpstart.constants import ( PROTOTYPICAL_MODEL_SPECS_DICT, @@ -37,6 +37,7 @@ BASE_HEADER, SPECIAL_MODEL_SPECS_DICT, ) +from sagemaker.jumpstart.curated_hub.types import HubContentType def get_header_from_base_header( From 15061472cf05143afd2cc64a776dc03ddb42793a Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 02:07:08 +0000 Subject: [PATCH 08/14] add unittests --- src/sagemaker/jumpstart/cache.py | 2 +- .../jumpstart/curated_hub/curated_hub.py | 14 +- .../jumpstart/curated_hub/test_curated_hub.py | 158 ++++++++++++++++++ .../jumpstart/curated_hub/test_utils.py | 5 - 4 files changed, 171 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 032dc11adc..781b8314d3 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -368,7 +368,7 @@ def _retrieval_function( hub_description: DescribeHubResponse = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_description) raise ValueError( - f"Bad value for key '{key}': must be in", + f"Bad value for key '{key}': must be in ", f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" ) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 1611a88f64..31242a4f53 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -34,10 +34,20 @@ def __init__( region: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): + """Instantiates a SageMaker ``CuratedHub``. + + Args: + hub_name (str): The name of the Hub to create. + region (str): The region in which the CuratedHub is in. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. + """ self.hub_name = hub_name if sagemaker_session.boto_region_name != region: - # TODO: Handle error - pass + raise ValueError( + f"Cannot have conflicting regions for region=[{region}] and ", + f"sagemaker_session region=[{str(sagemaker_session.boto_region_name)}].", + ) self.region = region self._sagemaker_session = sagemaker_session diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index e69de29bb2..1be6496219 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -0,0 +1,158 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from mock import Mock +from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +# @pytest.fixture() +# def sagemaker_session(): +# boto_mock = Mock(name="boto_session", region_name=REGION) +# session_mock = Mock( +# name="sagemaker_session", +# boto_session=boto_mock, +# boto_region_name=REGION, +# config=None, +# local_mode=False, +# default_bucket_prefix=None, +# ) +# session_mock.return_value.sagemkaer_client = Mock(name="sagemaker_client") +# session_mock.sts_client.get_caller_identity = Mock(return_value={"Account": ACCOUNT_ID}) +# create_hub = {"HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name"} +# session_mock.sagemaker_client.create_hub = Mock(return_value=create_hub) +# print(session_mock.sagemaker_client) +# return session_mock + + +def test_instantiates(sagemaker_session): + hub = CuratedHub(hub_name=HUB_NAME, region=REGION, sagemaker_session=sagemaker_session) + assert hub.hub_name == HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +def test_instantiates_handles_conflicting_regions(sagemaker_session): + conflicting_region = "us-east-2" + + with pytest.raises(ValueError): + CuratedHub( + hub_name=HUB_NAME, region=conflicting_region, sagemaker_session=sagemaker_session + ) + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + None, + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +def test_create_with_no_bucket_name( + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_bucket_name": "sagemaker-hubs-us-east-1-123456789123", + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + bucket_name=hub_bucket_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +def test_create_with_bucket_name( + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_bucket_name": hub_bucket_name, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + bucket_name=hub_bucket_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 868c1231a2..0787a50228 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -13,7 +13,6 @@ from __future__ import absolute_import from unittest.mock import Mock -from botocore.exceptions import ClientError from sagemaker.jumpstart.curated_hub import utils from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo @@ -173,10 +172,6 @@ def test_create_hub_bucket_if_it_does_not_exist(): } mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None mock_sagemaker_session.boto_region_name = "us-east-1" - error = ClientError( - error_response={"Error": {"Code": "404", "Message": "Not Found"}}, - operation_name="foo", - ) bucket_name = "sagemaker-hubs-us-east-1-123456789123" created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( sagemaker_session=mock_sagemaker_session From a8d6664ed4807e2d2d0b982ed96461d07e5c9685 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 02:09:58 +0000 Subject: [PATCH 09/14] add noqa --- src/sagemaker/jumpstart/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e0762f1519..16cbef4acb 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -117,7 +117,7 @@ def to_json(self) -> Dict[str, Any]: return json_obj -from sagemaker.jumpstart.curated_hub.types import HubContentType # noqa: E402 +from sagemaker.jumpstart.curated_hub.types import HubContentType # noqa: E402, C0413 class JumpStartS3FileType(str, Enum): From edb887e6a0d10e37ce79c417a550899f48abf04c Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 02:11:47 +0000 Subject: [PATCH 10/14] lint --- .../jumpstart/curated_hub/test_curated_hub.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index 1be6496219..f5db923919 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -33,25 +33,6 @@ def sagemaker_session(): return sagemaker_session_mock -# @pytest.fixture() -# def sagemaker_session(): -# boto_mock = Mock(name="boto_session", region_name=REGION) -# session_mock = Mock( -# name="sagemaker_session", -# boto_session=boto_mock, -# boto_region_name=REGION, -# config=None, -# local_mode=False, -# default_bucket_prefix=None, -# ) -# session_mock.return_value.sagemkaer_client = Mock(name="sagemaker_client") -# session_mock.sts_client.get_caller_identity = Mock(return_value={"Account": ACCOUNT_ID}) -# create_hub = {"HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name"} -# session_mock.sagemaker_client.create_hub = Mock(return_value=create_hub) -# print(session_mock.sagemaker_client) -# return session_mock - - def test_instantiates(sagemaker_session): hub = CuratedHub(hub_name=HUB_NAME, region=REGION, sagemaker_session=sagemaker_session) assert hub.hub_name == HUB_NAME From 5f2403627edefe65bfb57db2e6045b92fa0aad0d Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 18:51:43 +0000 Subject: [PATCH 11/14] curated_hub/types.py to types.py --- src/sagemaker/jumpstart/cache.py | 5 +- .../jumpstart/curated_hub/curated_hub.py | 4 +- src/sagemaker/jumpstart/curated_hub/types.py | 230 ------------------ src/sagemaker/jumpstart/curated_hub/utils.py | 4 +- src/sagemaker/jumpstart/types.py | 215 +++++++++++++++- .../jumpstart/curated_hub/test_utils.py | 4 +- tests/unit/sagemaker/jumpstart/utils.py | 2 +- 7 files changed, 220 insertions(+), 244 deletions(-) delete mode 100644 src/sagemaker/jumpstart/curated_hub/types.py diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 781b8314d3..7765b4752a 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -45,10 +45,9 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, -) -from sagemaker.jumpstart.curated_hub.types import ( DescribeHubResponse, DescribeHubContentsResponse, + HubType, HubContentType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils @@ -362,7 +361,7 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubContentType.HUB: + if data_type == HubType.HUB: hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region) hub_description: DescribeHubResponse = hub.describe() diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 31242a4f53..9033958921 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -17,12 +17,12 @@ from typing import Any, Dict, Optional from sagemaker.session import Session from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist -from sagemaker.jumpstart.curated_hub.types import ( +from sagemaker.jumpstart.types import ( DescribeHubResponse, DescribeHubContentsResponse, HubContentType, ) +from sagemaker.jumpstart.curated_hub.utils import create_hub_bucket_if_it_does_not_exist class CuratedHub: diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py deleted file mode 100644 index 85b98a25f3..0000000000 --- a/src/sagemaker/jumpstart/curated_hub/types.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module stores types related to SageMaker JumpStart CuratedHub.""" -from __future__ import absolute_import -from enum import Enum -from typing import Any, Dict, List, Optional - - -from sagemaker.jumpstart.types import JumpStartDataHolderType - - -class HubContentType(str, Enum): - """Enum for Hub data storage objects.""" - - HUB = "Hub" - MODEL = "Model" - NOTEBOOK = "Notebook" - - @classmethod - @property - def content_only(cls): - """Subset of HubContentType defined by SageMaker API Documentation""" - return cls.MODEL, cls.NOTEBOOK - - -class HubArnExtractedInfo(JumpStartDataHolderType): - """Data class for info extracted from Hub arn.""" - - __slots__ = [ - "partition", - "region", - "account_id", - "hub_name", - "hub_content_type", - "hub_content_name", - "hub_content_version", - ] - - def __init__( - self, - partition: str, - region: str, - account_id: str, - hub_name: str, - hub_content_type: Optional[str] = None, - hub_content_name: Optional[str] = None, - hub_content_version: Optional[str] = None, - ) -> None: - """Instantiates HubArnExtractedInfo object.""" - - self.partition = partition - self.region = region - self.account_id = account_id - self.hub_name = hub_name - self.hub_content_type = hub_content_type - self.hub_content_name = hub_content_name - self.hub_content_version = hub_content_version - - -class HubContentDependency(JumpStartDataHolderType): - """Data class for any dependencies related to hub content. - - Content can be scripts, model artifacts, datasets, or notebooks. - """ - - __slots__ = ["dependency_copy_path", "dependency_origin_path"] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubContentDependency object - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - - self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") - self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubContentDependency object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} - return json_obj - - -class DescribeHubContentsResponse(JumpStartDataHolderType): - """Data class for the Hub Content from session.describe_hub_contents()""" - - __slots__ = [ - "creation_time", - "document_schema_version", - "failure_reason", - "hub_arn", - "hub_content_arn", - "hub_content_dependencies", - "hub_content_description", - "hub_content_display_name", - "hub_content_document", - "hub_content_markdown", - "hub_content_name", - "hub_content_search_keywords", - "hub_content_status", - "hub_content_type", - "hub_content_version", - "hub_name", - ] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates DescribeHubContentsResponse object. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.creation_time: int = int(json_obj["creation_time"]) - self.document_schema_version: str = json_obj["document_schema_version"] - self.failure_reason: str = json_obj["failure_reason"] - self.hub_arn: str = json_obj["hub_arn"] - self.hub_content_arn: str = json_obj["hub_content_arn"] - self.hub_content_dependencies: List[HubContentDependency] = [ - HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] - ] - self.hub_content_description: str = json_obj["hub_content_description"] - self.hub_content_display_name: str = json_obj["hub_content_display_name"] - self.hub_content_document: str = json_obj["hub_content_document"] - self.hub_content_markdown: str = json_obj["hub_content_markdown"] - self.hub_content_name: str = json_obj["hub_content_name"] - self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] - self.hub_content_status: str = json_obj["hub_content_status"] - self.hub_content_type: HubContentType.content_only = json_obj["hub_content_type"] - self.hub_content_version: str = json_obj["hub_content_version"] - self.hub_name: str = json_obj["hub_name"] - - -class HubS3StorageConfig(JumpStartDataHolderType): - """Data class for any dependencies related to hub content. - - Includes scripts, model artifacts, datasets, or notebooks. - """ - - __slots__ = ["s3_output_path"] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates HubS3StorageConfig object - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub content description. - """ - - self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") - - def to_json(self) -> Dict[str, Any]: - """Returns json representation of HubS3StorageConfig object.""" - return {"s3_output_path": self.s3_output_path} - - -class DescribeHubResponse(JumpStartDataHolderType): - """Data class for the Hub from session.describe_hub()""" - - __slots__ = [ - "creation_time", - "failure_reason", - "hub_arn", - "hub_description", - "hub_display_name", - "hub_name", - "hub_search_keywords", - "hub_status", - "last_modified_time", - "s3_storage_config", - ] - - def __init__(self, json_obj: Dict[str, Any]) -> None: - """Instantiates DescribeHubResponse object. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub description. - """ - self.from_json(json_obj) - - def from_json(self, json_obj: Dict[str, Any]) -> None: - """Sets fields in object based on json. - - Args: - json_obj (Dict[str, Any]): Dictionary representation of hub description. - """ - - self.creation_time: int = int(json_obj["creation_time"]) - self.failure_reason: str = json_obj["failure_reason"] - self.hub_arn: str = json_obj["hub_arn"] - self.hub_description: str = json_obj["hub_description"] - self.hub_display_name: str = json_obj["hub_display_name"] - self.hub_name: str = json_obj["hub_name"] - self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] - self.hub_status: str = json_obj["hub_status"] - self.last_modified_time: int = int(json_obj["last_modified_time"]) - self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( - json_obj["s3_storage_config"] - ) diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 08a1b7131f..027452d7ed 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -16,11 +16,11 @@ from typing import Optional from sagemaker.session import Session from sagemaker.utils import aws_partition -from sagemaker.jumpstart import constants -from sagemaker.jumpstart.curated_hub.types import ( +from sagemaker.jumpstart.types import ( HubContentType, HubArnExtractedInfo, ) +from sagemaker.jumpstart import constants def get_info_from_hub_resource_arn( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 16cbef4acb..ca9d48eccf 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -117,9 +117,6 @@ def to_json(self) -> Dict[str, Any]: return json_obj -from sagemaker.jumpstart.curated_hub.types import HubContentType # noqa: E402, C0413 - - class JumpStartS3FileType(str, Enum): """Type of files published in JumpStart S3 distribution buckets.""" @@ -127,7 +124,20 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] +class HubType(str, Enum): + """Enum for Hub objects.""" + + HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + + MODEL = "Model" + NOTEBOOK = "Notebook" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -1021,6 +1031,203 @@ def __init__( self.md5_hash = md5_hash +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version + + +class HubContentDependency(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.dependency_copy_path: Optional[str] = json_obj.get("dependency_copy_path", "") + self.dependency_origin_path: Optional[str] = json_obj.get("dependency_origin_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubContentDependency object.""" + json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + return json_obj + + +class DescribeHubContentsResponse(JumpStartDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_name", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.creation_time: int = int(json_obj["creation_time"]) + self.document_schema_version: str = json_obj["document_schema_version"] + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_content_arn: str = json_obj["hub_content_arn"] + self.hub_content_dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["hub_content_dependencies"] + ] + self.hub_content_description: str = json_obj["hub_content_description"] + self.hub_content_display_name: str = json_obj["hub_content_display_name"] + self.hub_content_document: str = json_obj["hub_content_document"] + self.hub_content_markdown: str = json_obj["hub_content_markdown"] + self.hub_content_name: str = json_obj["hub_content_name"] + self.hub_content_search_keywords: str = json_obj["hub_content_search_keywords"] + self.hub_content_status: str = json_obj["hub_content_status"] + self.hub_content_type: HubContentType = json_obj["hub_content_type"] + self.hub_content_version: str = json_obj["hub_content_version"] + self.hub_name: str = json_obj["hub_name"] + + +class HubS3StorageConfig(JumpStartDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("s3_output_path", "") + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of HubS3StorageConfig object.""" + return {"s3_output_path": self.s3_output_path} + + +class DescribeHubResponse(JumpStartDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + + self.creation_time: int = int(json_obj["creation_time"]) + self.failure_reason: str = json_obj["failure_reason"] + self.hub_arn: str = json_obj["hub_arn"] + self.hub_description: str = json_obj["hub_description"] + self.hub_display_name: str = json_obj["hub_display_name"] + self.hub_name: str = json_obj["hub_name"] + self.hub_search_keywords: List[str] = json_obj["hub_search_keywords"] + self.hub_status: str = json_obj["hub_status"] + self.last_modified_time: int = int(json_obj["last_modified_time"]) + self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig( + json_obj["s3_storage_config"] + ) + + class JumpStartKwargs(JumpStartDataHolderType): """Data class for JumpStart object kwargs.""" diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 0787a50228..af2c09417e 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -13,9 +13,9 @@ from __future__ import absolute_import from unittest.mock import Mock -from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.types import HubArnExtractedInfo from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME -from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo +from sagemaker.jumpstart.curated_hub import utils def test_get_info_from_hub_resource_arn(): diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index c4eb995e5c..f38a9d6ed4 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -27,6 +27,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + HubContentType, ) from sagemaker.jumpstart.utils import get_formatted_manifest @@ -37,7 +38,6 @@ BASE_HEADER, SPECIAL_MODEL_SPECS_DICT, ) -from sagemaker.jumpstart.curated_hub.types import HubContentType def get_header_from_base_header( From b0ce6247ca12601b03e6b030d22ec2ef8b68201a Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 21:35:55 +0000 Subject: [PATCH 12/14] remove region dependency for curatedhub --- .../jumpstart/curated_hub/curated_hub.py | 15 +++++---------- .../jumpstart/curated_hub/test_curated_hub.py | 15 +++------------ tests/unit/sagemaker/jumpstart/test_cache.py | 6 +++--- tests/unit/sagemaker/jumpstart/utils.py | 3 ++- 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 9033958921..59a11df577 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -31,24 +31,17 @@ class CuratedHub: def __init__( self, hub_name: str, - region: str, sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Instantiates a SageMaker ``CuratedHub``. Args: hub_name (str): The name of the Hub to create. - region (str): The region in which the CuratedHub is in. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. """ self.hub_name = hub_name - if sagemaker_session.boto_region_name != region: - raise ValueError( - f"Cannot have conflicting regions for region=[{region}] and ", - f"sagemaker_session region=[{str(sagemaker_session.boto_region_name)}].", - ) - self.region = region + self.region = sagemaker_session.boto_region_name self._sagemaker_session = sagemaker_session def create( @@ -75,9 +68,11 @@ def create( def describe(self) -> DescribeHubResponse: """Returns descriptive information about the Hub""" - hub_description = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( + hub_name=self.hub_name + ) - return DescribeHubResponse(hub_description) + return hub_description def list_models(self, **kwargs) -> Dict[str, Any]: """Lists the models in this Curated Hub diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py index f5db923919..2448721520 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_curated_hub.py @@ -34,21 +34,12 @@ def sagemaker_session(): def test_instantiates(sagemaker_session): - hub = CuratedHub(hub_name=HUB_NAME, region=REGION, sagemaker_session=sagemaker_session) + hub = CuratedHub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) assert hub.hub_name == HUB_NAME assert hub.region == "us-east-1" assert hub._sagemaker_session == sagemaker_session -def test_instantiates_handles_conflicting_regions(sagemaker_session): - conflicting_region = "us-east-2" - - with pytest.raises(ValueError): - CuratedHub( - hub_name=HUB_NAME, region=conflicting_region, sagemaker_session=sagemaker_session - ) - - @pytest.mark.parametrize( ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), [ @@ -74,7 +65,7 @@ def test_create_with_no_bucket_name( ): create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session) + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, @@ -119,7 +110,7 @@ def test_create_with_bucket_name( ): create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} sagemaker_session.create_hub = Mock(return_value=create_hub) - hub = CuratedHub(hub_name=hub_name, region=REGION, sagemaker_session=sagemaker_session) + hub = CuratedHub(hub_name=hub_name, sagemaker_session=sagemaker_session) request = { "hub_name": hub_name, "hub_description": hub_description, diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 69c8659148..27f5518834 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -252,14 +252,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @patch("boto3.client") def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache = JumpStartModelsCache( - s3_bucket_name="some_bucket", region="some_region", manifest_file_s3_key="some_key" + s3_bucket_name="some_bucket", region="us-west-2", manifest_file_s3_key="some_key" ) cache.clear = MagicMock() cache.set_s3_bucket_name("some_bucket") cache.clear.assert_not_called() cache.clear.reset_mock() - cache.set_region("some_region") + cache.set_region("us-west-2") cache.clear.assert_not_called() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key") @@ -270,7 +270,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_s3_bucket_name("some_bucket1") cache.clear.assert_called_once() cache.clear.reset_mock() - cache.set_region("some_region1") + cache.set_region("us-east-1") cache.clear.assert_called_once() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key1") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index f38a9d6ed4..091821a421 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -27,6 +27,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, + HubType, HubContentType, ) @@ -211,7 +212,7 @@ def patched_retrieval_function( ) # TODO: Implement - if datatype == HubContentType.HUB: + if datatype == HubType.HUB: return None raise ValueError(f"Bad value for filetype: {datatype}") From 0937c7490a2728221a32cddc0690dbdc8666e8e0 Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 22:51:24 +0000 Subject: [PATCH 13/14] Use sagemaker.session.Session to call HubAPIs in cache --- src/sagemaker/jumpstart/cache.py | 33 +++++++++------ tests/unit/sagemaker/jumpstart/test_cache.py | 43 +++++++++++++++----- tests/unit/sagemaker/jumpstart/utils.py | 1 - 3 files changed, 52 insertions(+), 25 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 7765b4752a..6cbc5b30cb 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -21,6 +21,7 @@ import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier +from sagemaker.session import Session from sagemaker.utilities.cache import LRUCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, @@ -29,6 +30,7 @@ JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, MODEL_ID_LIST_WEB_URL, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( @@ -51,7 +53,6 @@ HubContentType, ) from sagemaker.jumpstart.curated_hub import utils as hub_utils -from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub class JumpStartModelsCache: @@ -77,6 +78,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: # fmt: on """Initialize a ``JumpStartModelsCache`` instance. @@ -98,6 +100,8 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session (Optional[sagemaker.session.Session]): A SageMaker Session object, + used for SageMaker interactions. Default: Session in region associated with boto3 session. """ self._region = region @@ -124,6 +128,7 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + self._sagemaker_session = sagemaker_session def set_region(self, region: str) -> None: """Set region for cache. Clears cache after new region is set.""" @@ -343,15 +348,17 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubContentType.MODEL: - hub_name, region, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( + hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn( id_info ) - hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region) - hub_model_description: DescribeHubContentsResponse = hub.describe_model( - model_name=model_name, - model_version=model_version + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=data_type ) - model_specs = JumpStartModelSpecs(hub_model_description, is_hub_content=True) + + model_specs = JumpStartModelSpecs(DescribeHubContentsResponse(hub_model_description), is_hub_content=True) utils.emit_logs_based_on_model_specs( model_specs, @@ -362,13 +369,13 @@ def _retrieval_function( formatted_content=model_specs ) if data_type == HubType.HUB: - hub_name, region, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) - hub: CuratedHub = CuratedHub(hub_name=hub_name, region=region) - hub_description: DescribeHubResponse = hub.describe() - return JumpStartCachedContentValue(formatted_content=hub_description) + hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info) + response: Dict[str, Any] = self._sagemaker_session.describe_hub(hub_name=hub_name) + hub_description = DescribeHubResponse(response) + return JumpStartCachedContentValue(formatted_content=DescribeHubResponse(hub_description)) raise ValueError( f"Bad value for key '{key}': must be in ", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubType.HUB, HubContentType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: @@ -493,7 +500,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubType.HUB, hub_arn)) return details.formatted_content def clear(self) -> None: diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 27f5518834..b9eebbfc6d 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -22,7 +22,7 @@ from mock.mock import MagicMock import pytest from mock import patch - +from sagemaker.session_settings import SessionSettings from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, @@ -45,6 +45,27 @@ from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +REGION = "us-east-1" +REGION2 = "us-east-2" +ACCOUNT_ID = "123456789123" + + +@pytest.fixture() +def sagemaker_session(): + mocked_boto_session = Mock(name="boto_session") + mocked_s3_client= Mock(name="s3_client") + mocked_sagemaker_session = Mock( + name="sagemaker_session", boto_session=mocked_boto_session, s3_client= mocked_s3_client, boto_region_name=REGION, config=None, + ) + mocked_sagemaker_session.sagemaker_config = {} + mocked_sagemaker_session._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + mocked_sagemaker_session.account_id.return_value = ACCOUNT_ID + return mocked_sagemaker_session + + + @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -252,14 +273,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client): @patch("boto3.client") def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache = JumpStartModelsCache( - s3_bucket_name="some_bucket", region="us-west-2", manifest_file_s3_key="some_key" + s3_bucket_name="some_bucket", region=REGION, manifest_file_s3_key="some_key" ) cache.clear = MagicMock() cache.set_s3_bucket_name("some_bucket") cache.clear.assert_not_called() cache.clear.reset_mock() - cache.set_region("us-west-2") + cache.set_region(REGION) cache.clear.assert_not_called() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key") @@ -270,7 +291,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client): cache.set_s3_bucket_name("some_bucket1") cache.clear.assert_called_once() cache.clear.reset_mock() - cache.set_region("us-east-1") + cache.set_region(REGION2) cache.clear.assert_called_once() cache.clear.reset_mock() cache.set_manifest_file_s3_key("some_key1") @@ -399,7 +420,6 @@ def test_jumpstart_cache_handles_boto3_client_errors(): def test_jumpstart_cache_accepts_input_parameters(): - region = "us-east-1" max_s3_cache_items = 1 s3_cache_expiration_horizon = datetime.timedelta(weeks=2) max_semantic_version_cache_items = 3 @@ -408,7 +428,7 @@ def test_jumpstart_cache_accepts_input_parameters(): manifest_file_key = "some_s3_key" cache = JumpStartModelsCache( - region=region, + region=REGION, max_s3_cache_items=max_s3_cache_items, s3_cache_expiration_horizon=s3_cache_expiration_horizon, max_semantic_version_cache_items=max_semantic_version_cache_items, @@ -418,7 +438,7 @@ def test_jumpstart_cache_accepts_input_parameters(): ) assert cache.get_manifest_file_s3_key() == manifest_file_key - assert cache.get_region() == region + assert cache.get_region() == REGION assert cache.get_bucket() == bucket assert cache._content_cache._max_cache_items == max_s3_cache_items assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon @@ -741,7 +761,7 @@ def test_jumpstart_cache_get_specs(): @patch("sagemaker.jumpstart.cache.os.path.isdir") @patch("builtins.open") def test_jumpstart_local_metadata_override_header( - mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock + mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, sagemaker_session: Mock ): mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) mocked_is_dir.return_value = True @@ -760,7 +780,7 @@ def test_jumpstart_local_metadata_override_header( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_called_once_with( + mocked_open.assert_called_with( "/some/directory/metadata/manifest/root/models_manifest.json", "r" ) mocked_get_json_file_and_etag_from_s3.assert_not_called() @@ -783,6 +803,7 @@ def test_jumpstart_local_metadata_override_specs( mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, mock_emit_logs_based_on_model_specs, + sagemaker_session, ): mocked_open.side_effect = [ @@ -791,7 +812,7 @@ def test_jumpstart_local_metadata_override_specs( ] mocked_is_dir.return_value = True - cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + cache = JumpStartModelsCache(s3_bucket_name="some_bucket", s3_client=Mock(), sagemaker_session=sagemaker_session) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( @@ -845,7 +866,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") assert mocked_is_dir.call_count == 2 - mocked_open.assert_not_called() + assert mocked_open.call_count == 2 mocked_get_json_file_and_etag_from_s3.assert_has_calls( calls=[ call("models_manifest.json"), diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 091821a421..a809b32a24 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -195,7 +195,6 @@ def patched_retrieval_function( datatype, id_info = key.data_type, key.id_info if datatype == JumpStartS3FileType.MANIFEST: - return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) if datatype == JumpStartS3FileType.SPECS: From dddd0a69c0cd74dc72c97259eb6a48780ce25f0e Mon Sep 17 00:00:00 2001 From: JJ Lim Date: Wed, 28 Feb 2024 22:56:14 +0000 Subject: [PATCH 14/14] lint --- tests/unit/sagemaker/jumpstart/test_cache.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index b9eebbfc6d..423dbf5e02 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -53,9 +53,13 @@ @pytest.fixture() def sagemaker_session(): mocked_boto_session = Mock(name="boto_session") - mocked_s3_client= Mock(name="s3_client") + mocked_s3_client = Mock(name="s3_client") mocked_sagemaker_session = Mock( - name="sagemaker_session", boto_session=mocked_boto_session, s3_client= mocked_s3_client, boto_region_name=REGION, config=None, + name="sagemaker_session", + boto_session=mocked_boto_session, + s3_client=mocked_s3_client, + boto_region_name=REGION, + config=None, ) mocked_sagemaker_session.sagemaker_config = {} mocked_sagemaker_session._client_config.user_agent = ( @@ -65,7 +69,6 @@ def sagemaker_session(): return mocked_sagemaker_session - @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -761,7 +764,10 @@ def test_jumpstart_cache_get_specs(): @patch("sagemaker.jumpstart.cache.os.path.isdir") @patch("builtins.open") def test_jumpstart_local_metadata_override_header( - mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock, sagemaker_session: Mock + mocked_open: Mock, + mocked_is_dir: Mock, + mocked_get_json_file_and_etag_from_s3: Mock, + sagemaker_session: Mock, ): mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) mocked_is_dir.return_value = True @@ -812,7 +818,9 @@ def test_jumpstart_local_metadata_override_specs( ] mocked_is_dir.return_value = True - cache = JumpStartModelsCache(s3_bucket_name="some_bucket", s3_client=Mock(), sagemaker_session=sagemaker_session) + cache = JumpStartModelsCache( + s3_bucket_name="some_bucket", s3_client=Mock(), sagemaker_session=sagemaker_session + ) model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs(