From b1f5cd833e07592d53eb61a3ec62f325e3855e7e Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 24 Jun 2024 17:04:24 -0400 Subject: [PATCH 01/25] fix: list_models() for python3.8 --- src/sagemaker/jumpstart/hub/hub.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 1545fe3a36..aa6904d107 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -184,13 +184,16 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: **{ "hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL_REFERENCE.value, + **kwargs } - | kwargs ) hub_model_summaries = self._list_and_paginate_models( - **{"hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL.value} - | kwargs + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL.value, + **kwargs + } ) response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries response["next_token"] = None # Temporary until pagination is implemented From 6b9f390d9dcb44f5c47bb8d41a07af6cef85c2b7 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 24 Jun 2024 23:05:41 -0400 Subject: [PATCH 02/25] fix linting --- src/sagemaker/jumpstart/hub/hub.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index aa6904d107..d208220965 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -184,15 +184,15 @@ def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: **{ "hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL_REFERENCE.value, - **kwargs + **kwargs, } ) hub_model_summaries = self._list_and_paginate_models( **{ - "hub_name": self.hub_name, + "hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL.value, - **kwargs + **kwargs, } ) 
response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries From cb6660841ca4607a6beaa1fcc802677db3e0f961 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 27 Jun 2024 10:45:58 -0400 Subject: [PATCH 03/25] fix: Address nits and improvements --- src/sagemaker/jumpstart/accessors.py | 21 ++++++++++++--------- src/sagemaker/jumpstart/factory/model.py | 16 ++++++++++++++-- src/sagemaker/jumpstart/utils.py | 4 ++-- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 66003c9f03..e3400defbc 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -14,6 +14,7 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools +import logging from typing import Any, Dict, List, Optional import boto3 @@ -289,15 +290,6 @@ def get_model_specs( if hub_arn: try: - hub_model_arn = construct_hub_model_arn_from_inputs( - hub_arn=hub_arn, model_name=model_id, version=version - ) - model_specs = JumpStartModelsAccessor._cache.get_hub_model( - hub_model_arn=hub_model_arn - ) - model_specs.set_hub_content_type(HubContentType.MODEL) - return model_specs - except: # noqa: E722 hub_model_arn = construct_hub_model_reference_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) @@ -307,6 +299,17 @@ def get_model_specs( model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) return model_specs + except Exception as ex: + logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + return model_specs + return JumpStartModelsAccessor._cache.get_specs( 
# type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 55dfa1394a..a4b0fcac1f 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,7 +34,7 @@ JUMPSTART_LOGGER, ) from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard -from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs +from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -550,7 +550,19 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: ) if kwargs.hub_arn: - kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, + kwargs.model_id, + kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, + kwargs.model_id, + kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) return kwargs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 989ca426b5..4684e6c5c1 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -384,12 +384,12 @@ def add_jumpstart_model_id_version_tags( def add_hub_content_arn_tags( tags: Optional[List[TagsDict]], - hub_arn: str, + hub_content_arn: str, ) -> Optional[List[TagsDict]]: """Adds custom Hub arn tag to JumpStart related resources.""" tags = add_single_jumpstart_tag( - hub_arn, + hub_content_arn, enums.JumpStartTag.HUB_CONTENT_ARN, tags, is_uri=False, From 
5392504b84f766e00c9eb5e4c041fed3459ba801 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 27 Jun 2024 11:25:11 -0400 Subject: [PATCH 04/25] fix codestyle issues --- src/sagemaker/jumpstart/accessors.py | 4 +++- src/sagemaker/jumpstart/factory/model.py | 13 ++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e3400defbc..4049af2ed8 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -300,7 +300,9 @@ def get_model_specs( return model_specs except Exception as ex: - logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) + logging.info( + "Recieved expection while calling APIs for ContentType Model: " + str(ex) + ) hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version ) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 99c6836e14..638f48a96d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,7 +34,10 @@ JUMPSTART_LOGGER, ) from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard -from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines @@ -552,15 +555,11 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: if kwargs.hub_arn: if kwargs.model_reference_arn: hub_content_arn = construct_hub_model_reference_arn_from_inputs( - kwargs.hub_arn, - kwargs.model_id, - kwargs.model_version + kwargs.hub_arn, 
kwargs.model_id, kwargs.model_version ) else: hub_content_arn = construct_hub_model_arn_from_inputs( - kwargs.hub_arn, - kwargs.model_id, - kwargs.model_version + kwargs.hub_arn, kwargs.model_id, kwargs.model_version ) kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) From 269dc084dd3e74edf8238edc636abfb80111d5b6 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 27 Jun 2024 13:05:56 -0400 Subject: [PATCH 05/25] fix: don't force automatic bucket creation if user don't specify it --- src/sagemaker/jumpstart/hub/hub.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index d208220965..49025d7ac4 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -90,25 +90,22 @@ def _fetch_hub_bucket_name(self) -> str: if hub_output_location: location = create_s3_object_reference_from_uri(hub_output_location) return location.bucket - default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. Using %s", + "There is not a Hub bucket associated with %s.", self.hub_name, - default_bucket_name, ) - return default_bucket_name + return None except exceptions.ClientError: - hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s. 
Using %s", + "There is not a Hub bucket associated with %s.", self.hub_name, - hub_bucket_name, ) - return hub_bucket_name - + return None def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: """Generates an ``S3ObjectLocation`` given a Hub name.""" hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + if hub_bucket_name is None: + return curr_timestamp = datetime.now().timestamp() return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") @@ -131,16 +128,16 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - create_hub_bucket_if_it_does_not_exist( - self.hub_storage_location.bucket, self._sagemaker_session - ) + s3_storage_config = { + "S3OutputPath": self.hub_storage_location.get_uri() + } if self.hub_storage_location else None return self._sagemaker_session.create_hub( hub_name=self.hub_name, hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, + s3_storage_config=s3_storage_config, tags=tags, ) From 502063f6bd9b6effd44b1570ed93d55565f9dd88 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 27 Jun 2024 13:06:40 -0400 Subject: [PATCH 06/25] fix formatting --- src/sagemaker/jumpstart/hub/hub.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 49025d7ac4..324bb99f9e 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -101,6 +101,7 @@ def _fetch_hub_bucket_name(self) -> str: self.hub_name, ) return None + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: """Generates an ``S3ObjectLocation`` given a Hub name.""" hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() @@ -128,9 +129,11 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given 
description""" - s3_storage_config = { - "S3OutputPath": self.hub_storage_location.get_uri() - } if self.hub_storage_location else None + s3_storage_config = ( + {"S3OutputPath": self.hub_storage_location.get_uri()} + if self.hub_storage_location + else None + ) return self._sagemaker_session.create_hub( hub_name=self.hub_name, From f553357e642237fc7809b57fdfdcc11f4c4806f6 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Thu, 27 Jun 2024 13:33:20 -0400 Subject: [PATCH 07/25] fix flake8 --- src/sagemaker/jumpstart/hub/hub.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 324bb99f9e..7d98946c87 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -33,8 +33,6 @@ from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, - create_hub_bucket_if_it_does_not_exist, - generate_default_hub_bucket_name, create_s3_object_reference_from_uri, construct_hub_arn_from_name, ) From 5ab02e4fcbacae0bb16ae97d1d37283e89175942 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Wed, 3 Jul 2024 15:28:32 -0400 Subject: [PATCH 08/25] address nits --- src/sagemaker/jumpstart/accessors.py | 2 +- src/sagemaker/jumpstart/hub/hub.py | 2 +- src/sagemaker/jumpstart/hub/utils.py | 6 +++++- src/sagemaker/jumpstart/types.py | 6 ++---- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 4049af2ed8..87e5415181 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -301,7 +301,7 @@ def get_model_specs( except Exception as ex: logging.info( - "Recieved expection while calling APIs for ContentType Model: " + str(ex) + "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + str(ex) ) hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, 
model_name=model_id, version=version diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 7d98946c87..6a4e425123 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -286,7 +286,7 @@ def describe_model( ) except Exception as ex: - logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) + logging.info("Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + str(ex)) model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 3dfe99a8c4..fc68ccb0fd 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -193,7 +193,11 @@ def get_hub_model_version( hub_model_version: Optional[str] = None, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: - """Returns available Jumpstart hub model version""" + """Returns available Jumpstart hub model version + + Raises: + ResourceNotFound: If the specified model is not found in the hub. 
+ """ try: hub_content_summaries = sagemaker_session.list_hub_content_versions( diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 171d9ce8a1..e34c019293 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -17,6 +17,7 @@ from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.jumpstart.constants import HUB_ARN_REGEX, HUB_CONTENT_ARN_REGEX from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict from sagemaker.model_metrics import ModelMetrics @@ -1412,6 +1413,7 @@ def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False) Args: spec (Dict[str, Any]): Dictionary representation of spec. + is_hub_content (Optional[bool]): Whether the model is from a private hub. """ super().__init__(spec, is_hub_content) self.from_json(spec) @@ -1666,10 +1668,6 @@ def __init__( def extract_region_from_arn(arn: str) -> Optional[str]: """Extracts hub_name, content_name, and content_version from a HubContentArn""" - HUB_CONTENT_ARN_REGEX = ( - r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" - ) - HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" match = re.match(HUB_CONTENT_ARN_REGEX, arn) hub_region = None From 37a36c8f6df7d324ba12d86d694d0578f00f0fcc Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Sun, 7 Jul 2024 18:28:46 -0400 Subject: [PATCH 09/25] revert HUB_ARN_REGEX and HUB_CONTENT_ARN_REGEX constants from types.py due to the circular dependency issue --- src/sagemaker/jumpstart/types.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index e34c019293..8017004ad9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -17,7 +17,6 @@ from copy import
deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union -from sagemaker.jumpstart.constants import HUB_ARN_REGEX, HUB_CONTENT_ARN_REGEX from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard from sagemaker.utils import get_instance_type_family, format_tags, Tags, deep_override_dict from sagemaker.model_metrics import ModelMetrics @@ -1668,6 +1667,10 @@ def __init__( def extract_region_from_arn(arn: str) -> Optional[str]: """Extracts hub_name, content_name, and content_version from a HubContentArn""" + HUB_CONTENT_ARN_REGEX = ( + r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" + ) + HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" match = re.match(HUB_CONTENT_ARN_REGEX, arn) hub_region = None From 3fe2774c1000fa1d86ed6015de0f7efdf0e30804 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 8 Jul 2024 18:05:47 -0400 Subject: [PATCH 10/25] revert: don't force automatic bucket creation if user don't specify it --- src/sagemaker/jumpstart/hub/hub.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 6a4e425123..6bac175cd4 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -33,6 +33,8 @@ from sagemaker.jumpstart.hub.utils import ( get_hub_model_version, get_info_from_hub_resource_arn, + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, create_s3_object_reference_from_uri, construct_hub_arn_from_name, ) @@ -88,23 +90,25 @@ def _fetch_hub_bucket_name(self) -> str: if hub_output_location: location = create_s3_object_reference_from_uri(hub_output_location) return location.bucket + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s.", + "There is not a Hub bucket associated with %s. 
Using %s", self.hub_name, + default_bucket_name, ) - return None + return default_bucket_name except exceptions.ClientError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) JUMPSTART_LOGGER.warning( - "There is not a Hub bucket associated with %s.", + "There is not a Hub bucket associated with %s. Using %s", self.hub_name, + hub_bucket_name, ) - return None + return hub_bucket_name def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: """Generates an ``S3ObjectLocation`` given a Hub name.""" hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() - if hub_bucket_name is None: - return curr_timestamp = datetime.now().timestamp() return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") @@ -127,10 +131,8 @@ def create( ) -> Dict[str, str]: """Creates a hub with the given description""" - s3_storage_config = ( - {"S3OutputPath": self.hub_storage_location.get_uri()} - if self.hub_storage_location - else None + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session ) return self._sagemaker_session.create_hub( @@ -138,7 +140,7 @@ def create( hub_description=description, hub_display_name=display_name, hub_search_keywords=search_keywords, - s3_storage_config=s3_storage_config, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, tags=tags, ) From 10dba2ceec16d1df115d701221279e3c3f9a1012 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 8 Jul 2024 22:38:16 -0400 Subject: [PATCH 11/25] fix: fix _add_tags_to_kwargs to use hub_content_arn instead of hub_arn --- src/sagemaker/jumpstart/factory/estimator.py | 43 +++++++++++++++++++- src/sagemaker/jumpstart/factory/model.py | 1 + src/sagemaker/jumpstart/types.py | 3 ++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d3e597c395..a3c88b04a7 100644 --- 
a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,6 +29,7 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base +from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -52,6 +53,7 @@ from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.factory import model from sagemaker.jumpstart.types import ( + HubContentType, JumpStartEstimatorDeployKwargs, JumpStartEstimatorFitKwargs, JumpStartEstimatorInitKwargs, @@ -201,6 +203,11 @@ def get_init_kwargs( estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs) + if hub_arn: + estimator_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=estimator_init_kwargs) + else: + estimator_init_kwargs.model_reference_arn = None + estimator_init_kwargs.hub_content_type = None estimator_init_kwargs = _add_model_uri_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_source_dir_to_kwargs(estimator_init_kwargs) estimator_init_kwargs = _add_entry_point_to_kwargs(estimator_init_kwargs) @@ -511,7 +518,15 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima ) if kwargs.hub_arn: - kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + if kwargs.model_reference_arn: + hub_content_arn = construct_hub_model_reference_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, kwargs.model_version + ) + else: + hub_content_arn = construct_hub_model_arn_from_inputs( + kwargs.hub_arn, kwargs.model_id, 
kwargs.model_version + ) + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, hub_content_arn=hub_content_arn) return kwargs @@ -534,6 +549,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE return kwargs +def _add_model_reference_arn_to_kwargs( + kwargs: JumpStartEstimatorInitKwargs, +) -> JumpStartEstimatorInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + + hub_content_type = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs: """Sets model uri in kwargs based on default or override, returns full kwargs.""" diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 638f48a96d..f4eb7a3331 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -268,6 +268,7 @@ def _add_model_reference_arn_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + hub_content_type = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, diff 
--git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 8017004ad9..56f8df1e74 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2026,6 +2026,8 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "enable_infra_check", "enable_remote_debug", "enable_session_tag_chaining", + "hub_content_type", + "model_reference_arn", ] SERIALIZATION_EXCLUSION_SET = { @@ -2036,6 +2038,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_content_type", } def __init__( From 5449eb54c72a593813c59391953b41f8a2568b18 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 8 Jul 2024 23:04:47 -0400 Subject: [PATCH 12/25] fix codestyle issues --- src/sagemaker/jumpstart/accessors.py | 3 ++- src/sagemaker/jumpstart/factory/estimator.py | 8 ++++++-- src/sagemaker/jumpstart/factory/model.py | 2 +- src/sagemaker/jumpstart/hub/hub.py | 5 ++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 87e5415181..0ff072a9c0 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -301,7 +301,8 @@ def get_model_specs( except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + str(ex) + "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + + str(ex) ) hub_model_arn = construct_hub_model_arn_from_inputs( hub_arn=hub_arn, model_name=model_id, version=version diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index a3c88b04a7..b07e643796 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -29,7 +29,10 @@ _retrieve_model_package_model_artifact_s3_uri, ) from sagemaker.jumpstart.artifacts.resource_names 
import _retrieve_resource_name_base -from sagemaker.jumpstart.hub.utils import construct_hub_model_arn_from_inputs, construct_hub_model_reference_arn_from_inputs +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.session import Session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig from sagemaker.base_deserializers import BaseDeserializer @@ -549,11 +552,12 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE return kwargs + def _add_model_reference_arn_to_kwargs( kwargs: JumpStartEstimatorInitKwargs, ) -> JumpStartEstimatorInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" - + hub_content_type = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index f4eb7a3331..29148eea75 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -268,7 +268,7 @@ def _add_model_reference_arn_to_kwargs( kwargs: JumpStartModelInitKwargs, ) -> JumpStartModelInitKwargs: """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" - + hub_content_type = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 6bac175cd4..4ac27b1003 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -288,7 +288,10 @@ def describe_model( ) except Exception as ex: - logging.info("Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + str(ex)) + logging.info( + "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType 
ModelReference: " + + str(ex) + ) model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, From 559ef2e3b5c757f66e0402cc58e05993aa902851 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Mon, 8 Jul 2024 23:47:01 -0400 Subject: [PATCH 13/25] feat: Add support for Hub in model attach functionality --- src/sagemaker/jumpstart/model.py | 2 ++ tests/unit/sagemaker/jumpstart/model/test_model.py | 1 + 2 files changed, 3 insertions(+) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index b482d4fefd..45544e5202 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -429,6 +429,7 @@ def attach( cls, endpoint_name: str, inference_component_name: Optional[str] = None, + hub_name: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -457,6 +458,7 @@ def attach( model_id=model_id, model_version=model_version, sagemaker_session=sagemaker_session, + hub_name=hub_name, ) model.endpoint_name = endpoint_name model.inference_component_name = inference_component_name diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 15c2c43bf0..29cc1c9773 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1360,6 +1360,7 @@ def test_attach( model_id="model-id", model_version="model-version", sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name=None, ) assert isinstance(val, JumpStartModel) From 7e307bf6b2ca0b309af188e73b03790cc51b12d7 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 02:15:16 -0400 Subject: [PATCH 14/25] feat: Add curatedHub telemetry support --- src/sagemaker/jumpstart/factory/estimator.py | 2 +- src/sagemaker/jumpstart/factory/model.py | 4 +++- src/sagemaker/jumpstart/utils.py | 9 +++++---- 
3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index b07e643796..9c063a1f4a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -432,7 +432,7 @@ def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version + kwargs.model_id, kwargs.model_version, kwargs.hub_arn ) ) return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 29148eea75..5f9f6ab038 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -152,12 +152,14 @@ def _add_sagemaker_session_to_kwargs( kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" + kwargs.sagemaker_session = ( kwargs.sagemaker_session or get_default_jumpstart_session_with_user_agent_suffix( - kwargs.model_id, kwargs.model_version + kwargs.model_id, kwargs.model_version, kwargs.hub_arn ) ) + return kwargs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 4684e6c5c1..ac8a8165e8 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1012,24 +1012,25 @@ def get_jumpstart_configs( ) -def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str: +def get_jumpstart_user_agent_extra_suffix(model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]) -> str: """Returns the model-specific user agent string to be added to requests.""" sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} 
md/js_model_ver#{model_version}" + hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" return ( sagemaker_python_sdk_headers if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None) - else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" + else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" ) def get_default_jumpstart_session_with_user_agent_suffix( - model_id: str, model_version: str + model_id: Optional[str] = None, model_version: Optional[str] = None, is_hub_content: Optional[bool] = False ) -> Session: """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( - user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version), + user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version, is_hub_content), ) botocore_session.set_default_client_config(botocore_config) # shallow copy to not affect default session constant From 5f7e955b2693875b3f361c2497fec67a81afcc27 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 02:18:16 -0400 Subject: [PATCH 15/25] Address codestyledoc issues --- src/sagemaker/jumpstart/accessors.py | 3 ++- src/sagemaker/jumpstart/utils.py | 28 ++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 0ff072a9c0..1b3da337b2 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -301,7 +301,8 @@ def get_model_specs( except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + "Recieved exeption while calling APIs for ContentType Model, \ + retrying with ContentType ModelReference: " + str(ex) ) hub_model_arn = construct_hub_model_arn_from_inputs( diff --git 
a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index ac8a8165e8..ff7b71f0fc 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1012,25 +1012,37 @@ def get_jumpstart_configs( ) -def get_jumpstart_user_agent_extra_suffix(model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]) -> str: +def get_jumpstart_user_agent_extra_suffix( + model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool] +) -> str: """Returns the model-specific user agent string to be added to requests.""" sagemaker_python_sdk_headers = get_user_agent_extra_suffix() jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}" hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}" - return ( - sagemaker_python_sdk_headers - if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None) - else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" - ) + + if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): + headers = sagemaker_python_sdk_headers + elif model_id is None and model_version is None: + headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" + else: + headers = ( + f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" + ) + + return headers def get_default_jumpstart_session_with_user_agent_suffix( - model_id: Optional[str] = None, model_version: Optional[str] = None, is_hub_content: Optional[bool] = False + model_id: Optional[str] = None, + model_version: Optional[str] = None, + is_hub_content: Optional[bool] = False, ) -> Session: """Returns default JumpStart SageMaker Session with model-specific user agent suffix.""" botocore_session = botocore.session.get_session() botocore_config = botocore.config.Config( - user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version, is_hub_content), + 
user_agent_extra=get_jumpstart_user_agent_extra_suffix( + model_id, model_version, is_hub_content + ), ) botocore_session.set_default_client_config(botocore_config) # shallow copy to not affect default session constant From 38495dc1ecfdf6185bc9ed4d3d26e3a0494e52f8 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 02:30:38 -0400 Subject: [PATCH 16/25] fix failing unit tests --- tests/unit/sagemaker/jumpstart/model/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 29cc1c9773..15c2c43bf0 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1360,7 +1360,6 @@ def test_attach( model_id="model-id", model_version="model-version", sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - hub_name=None, ) assert isinstance(val, JumpStartModel) From 90006f6466b0491744d4763a31914d30afb5fb2d Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 09:52:41 -0400 Subject: [PATCH 17/25] fix failing tests --- .../sagemaker/jumpstart/model/test_model.py | 1 + tests/unit/sagemaker/jumpstart/test_utils.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 15c2c43bf0..29cc1c9773 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -1360,6 +1360,7 @@ def test_attach( model_id="model-id", model_version="model-version", sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name=None, ) assert isinstance(val, JumpStartModel) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 941e2797ea..4c86dd6cf8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ 
b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1714,21 +1714,21 @@ class TestUserAgent: @patch("sagemaker.jumpstart.utils.os.getenv") def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): mock_getenv.return_value = False - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#False") mock_getenv.return_value = None - assert utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) + assert utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "False" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#False") mock_getenv.return_value = "True" - assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") mock_getenv.return_value = True - assert not utils.get_jumpstart_user_agent_extra_suffix("some-id", "some-version").endswith( - "md/js_model_id#some-id md/js_model_ver#some-version" - ) + assert not utils.get_jumpstart_user_agent_extra_suffix( + "some-id", "some-version", "True" + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#True") @patch("sagemaker.jumpstart.utils.botocore.session") @patch("sagemaker.jumpstart.utils.botocore.config.Config") @@ -1748,7 +1748,7 @@ def test_get_default_jumpstart_session_with_user_agent_suffix( utils.get_default_jumpstart_session_with_user_agent_suffix("model_id", "model_version") 
mock_boto3_session.get_session.assert_called_once_with() mock_get_jumpstart_user_agent_extra_suffix.assert_called_once_with( - "model_id", "model_version" + "model_id", "model_version", False ) mock_botocore_config.assert_called_once_with( user_agent_extra=mock_get_jumpstart_user_agent_extra_suffix.return_value From c331b0c053a8b02d75bde096d0db5375327b4792 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 09:54:54 -0400 Subject: [PATCH 18/25] change default session object in hub class to one with user agent string --- src/sagemaker/jumpstart/hub/hub.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 4ac27b1003..5cb4506a84 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -68,7 +68,9 @@ def __init__( self, hub_name: str, bucket_name: Optional[str] = None, - sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + sagemaker_session: Optional[ + Session + ] = utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True), ) -> None: """Instantiates a SageMaker ``Hub``. 
From ac45eeadafce4c6fb74c4187be62698a44754687 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 10:03:23 -0400 Subject: [PATCH 19/25] fix flake8 --- src/sagemaker/jumpstart/hub/hub.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 5cb4506a84..9244400fbf 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -23,7 +23,6 @@ from sagemaker.session import Session from sagemaker.jumpstart.constants import ( - DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER, ) from sagemaker.jumpstart.types import ( From 2f071301d0faefabf9004ed33484f8bdc8fb25d2 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 14:26:09 -0400 Subject: [PATCH 20/25] address comments: moving get default JS session to constructor body --- src/sagemaker/jumpstart/hub/hub.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 9244400fbf..21fdfa30b9 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -67,9 +67,7 @@ def __init__( self, hub_name: str, bucket_name: Optional[str] = None, - sagemaker_session: Optional[ - Session - ] = utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True), + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates a SageMaker ``Hub``. 
@@ -80,7 +78,7 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name - self._sagemaker_session = sagemaker_session + self._sagemaker_session = sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) self.hub_storage_location = self._generate_hub_storage_location(bucket_name) def _fetch_hub_bucket_name(self) -> str: From 6fb32234d0bd5c566ce431592a0243623d7efa14 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 14:39:14 -0400 Subject: [PATCH 21/25] Address comments: only add is_hub_content to user aggent suffix if its available --- src/sagemaker/jumpstart/hub/hub.py | 5 ++++- src/sagemaker/jumpstart/utils.py | 13 ++++++++----- tests/unit/sagemaker/jumpstart/test_utils.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 21fdfa30b9..856fe72c94 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -78,7 +78,10 @@ def __init__( """ self.hub_name = hub_name self.region = sagemaker_session.boto_region_name - self._sagemaker_session = sagemaker_session or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) + self._sagemaker_session = ( + sagemaker_session + or utils.get_default_jumpstart_session_with_user_agent_suffix(is_hub_content=True) + ) self.hub_storage_location = self._generate_hub_storage_location(bucket_name) def _fetch_hub_bucket_name(self) -> str: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index ff7b71f0fc..1cb6006da3 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -1022,12 +1022,15 @@ def get_jumpstart_user_agent_extra_suffix( if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None): headers = sagemaker_python_sdk_headers - elif model_id is None and model_version is None: - headers = 
f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" + elif is_hub_content is True: + if model_id is None and model_version is None: + headers = f"{sagemaker_python_sdk_headers} {hub_specific_suffix}" + else: + headers = ( + f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" + ) else: - headers = ( - f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix} {hub_specific_suffix}" - ) + headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}" return headers diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 4c86dd6cf8..2363825141 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1716,11 +1716,11 @@ def test_get_jumpstart_user_agent_extra_suffix(self, mock_getenv): mock_getenv.return_value = False assert utils.get_jumpstart_user_agent_extra_suffix( "some-id", "some-version", "False" - ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#False") + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = None assert utils.get_jumpstart_user_agent_extra_suffix( "some-id", "some-version", "False" - ).endswith("md/js_model_id#some-id md/js_model_ver#some-version md/js_is_hub_content#False") + ).endswith("md/js_model_id#some-id md/js_model_ver#some-version") mock_getenv.return_value = "True" assert not utils.get_jumpstart_user_agent_extra_suffix( "some-id", "some-version", "True" From 0f3f4346008cd3d9bbae3a3a8b51154d86d34fd9 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 14:54:43 -0400 Subject: [PATCH 22/25] try with ModelReference first then with Model type --- src/sagemaker/jumpstart/accessors.py | 4 ++-- src/sagemaker/jumpstart/hub/hub.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 
1b3da337b2..4cf544b93e 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -301,8 +301,8 @@ def get_model_specs( except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType Model, \ - retrying with ContentType ModelReference: " + "Recieved exeption while calling APIs for ContentType ModelReference, \ + retrying with ContentType Model: " + str(ex) ) hub_model_arn = construct_hub_model_arn_from_inputs( diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index 856fe72c94..c61866e151 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -276,37 +276,37 @@ def describe_model( try: model_version = get_hub_model_version( hub_model_name=model_name, - hub_model_type=HubContentType.MODEL.value, + hub_model_type=HubContentType.MODEL_REFERENCE.value, hub_name=self.hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name if not hub_name else hub_name, + hub_name=self.hub_name, hub_content_name=model_name, hub_content_version=model_version, - hub_content_type=HubContentType.MODEL.value, + hub_content_type=HubContentType.MODEL_REFERENCE.value, ) except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType Model, retrying with ContentType ModelReference: " + "Recieved exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) model_version = get_hub_model_version( hub_model_name=model_name, - hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_model_type=HubContentType.MODEL.value, hub_name=self.hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name, + 
hub_name=self.hub_name if not hub_name else hub_name, hub_content_name=model_name, hub_content_version=model_version, - hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_type=HubContentType.MODEL.value, ) return DescribeHubContentResponse(hub_content_description) From 65b61a69becb46df25eb6e2660d0b019e5ae7466 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 15:04:37 -0400 Subject: [PATCH 23/25] fix: describe_model if hub_name has been explicitly provided --- src/sagemaker/jumpstart/hub/hub.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index c61866e151..c764fa20b1 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -277,13 +277,13 @@ def describe_model( model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL_REFERENCE.value, - hub_name=self.hub_name, + hub_name=self.hub_name if not hub_name else hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( - hub_name=self.hub_name, + hub_name=self.hub_name if not hub_name else hub_name, hub_content_name=model_name, hub_content_version=model_version, hub_content_type=HubContentType.MODEL_REFERENCE.value, @@ -297,7 +297,7 @@ def describe_model( model_version = get_hub_model_version( hub_model_name=model_name, hub_model_type=HubContentType.MODEL.value, - hub_name=self.hub_name, + hub_name=self.hub_name if not hub_name else hub_name, sagemaker_session=self._sagemaker_session, hub_model_version=model_version, ) From d8b173d813dbecfde00b9b8bdfb9d8733c266338 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Tue, 9 Jul 2024 17:17:15 -0400 Subject: [PATCH 24/25] Address comments --- src/sagemaker/jumpstart/accessors.py | 2 +- src/sagemaker/jumpstart/hub/hub.py | 2 +- 
src/sagemaker/jumpstart/hub/utils.py | 2 +- src/sagemaker/jumpstart/model.py | 2 +- tests/unit/sagemaker/jumpstart/hub/test_hub.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 4cf544b93e..20a2d16c15 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -301,7 +301,7 @@ def get_model_specs( except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType ModelReference, \ + "Received exeption while calling APIs for ContentType ModelReference, \ retrying with ContentType Model: " + str(ex) ) diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py index c764fa20b1..69d1dbb5c1 100644 --- a/src/sagemaker/jumpstart/hub/hub.py +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -291,7 +291,7 @@ def describe_model( except Exception as ex: logging.info( - "Recieved exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + "Received exeption while calling APIs for ContentType ModelReference, retrying with ContentType Model: " + str(ex) ) model_version = get_hub_model_version( diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index fc68ccb0fd..2624796b3f 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -196,7 +196,7 @@ def get_hub_model_version( """Returns available Jumpstart hub model version Raises: - ResourceNotFound: If the specified model is not found in the hub. + ClientError: If the specified model is not found in the hub. 
""" try: diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 45544e5202..6ab840f2bf 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -429,10 +429,10 @@ def attach( cls, endpoint_name: str, inference_component_name: Optional[str] = None, - hub_name: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_name: Optional[str] = None, ) -> "JumpStartModel": """Attaches a JumpStartModel object to an existing SageMaker Endpoint. diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py index e2085e5ab9..8522b33bc3 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_hub.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -182,13 +182,13 @@ def test_describe_model_success(mock_describe_hub_content_response, sagemaker_se hub.describe_model("test-model") mock_list_hub_content_versions.assert_called_with( - hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model" + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="ModelReference" ) sagemaker_session.describe_hub_content.assert_called_with( hub_name=HUB_NAME, hub_content_name="test-model", hub_content_version="3.0", - hub_content_type="Model", + hub_content_type="ModelReference", ) From 2f5f29be7e4e85bd60f8aec514a7e6abe47d0739 Mon Sep 17 00:00:00 2001 From: Malav Shastri Date: Wed, 10 Jul 2024 13:30:30 -0400 Subject: [PATCH 25/25] Address merge conflicts --- tests/unit/sagemaker/jumpstart/test_utils.py | 36 +++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 79654dacd4..07c49a308c 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,7 @@ from __future__ import 
absolute_import import os from unittest import TestCase +from unittest.mock import call from botocore.exceptions import ClientError from mock.mock import Mock, patch @@ -1932,7 +1933,40 @@ def test_get_default_jumpstart_session_with_user_agent_suffix( botocore_session=mock_boto3_session.get_session.return_value, ) mock_boto3_client.assert_has_calls( -======= + [ + call( + "sagemaker", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + call( + "sagemaker-runtime", + region_name=JUMPSTART_DEFAULT_REGION_NAME, + config=mock_botocore_config.return_value, + ), + ], + any_order=True, + ) + + @patch("botocore.client.BaseClient._make_request") + def test_get_default_jumpstart_session_with_user_agent_suffix_http_header( + self, + mock_make_request, + ): + session = utils.get_default_jumpstart_session_with_user_agent_suffix( + "model_id", "model_version" + ) + try: + session.sagemaker_client.list_endpoints() + except Exception: + pass + + assert ( + "md/js_model_id#model_id md/js_model_ver#model_version" + in mock_make_request.call_args[0][1]["headers"]["User-Agent"] + ) + + def test_extract_metrics_from_deployment_configs(): configs = get_base_deployment_configs_metadata() configs[0].benchmark_metrics = None