From 082f727cab85693f53b854f3daa520dfdc9c1b6b Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 28 Feb 2024 15:56:45 +0000 Subject: [PATCH 1/2] add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor --- src/sagemaker/accept_types.py | 4 +++ src/sagemaker/content_types.py | 4 +++ src/sagemaker/deserializers.py | 4 +++ src/sagemaker/jumpstart/artifacts/kwargs.py | 8 +++++ .../jumpstart/artifacts/predictors.py | 16 ++++++++++ .../jumpstart/artifacts/script_uris.py | 4 +++ src/sagemaker/jumpstart/curated_hub/utils.py | 10 +++---- src/sagemaker/jumpstart/estimator.py | 5 ++-- src/sagemaker/jumpstart/factory/model.py | 19 ++++++++++++ src/sagemaker/jumpstart/model.py | 12 ++++++++ src/sagemaker/jumpstart/types.py | 4 +++ src/sagemaker/predictor.py | 4 +++ src/sagemaker/serializers.py | 4 +++ .../jumpstart/curated_hub/test_utils.py | 29 ++++++------------- .../jumpstart/estimator/test_estimator.py | 1 + .../sagemaker/jumpstart/model/test_model.py | 2 ++ .../sagemaker/jumpstart/test_predictor.py | 1 + 17 files changed, 104 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index bf081365ab..3f8d3171f3 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -72,6 +72,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -85,6 +86,8 @@ def retrieve_default( retrieve the default accept type. (Default: None). model_version (str): The version of the model for which to retrieve the default accept type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,6 +113,7 @@ def retrieve_default( return artifacts._retrieve_default_accept_type( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index e43e96be17..9b70ecce96 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -72,6 +72,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -85,6 +86,8 @@ def retrieve_default( retrieve the default content type. (Default: None). model_version (str): The version of the model for which to retrieve the default content type. (Default: None). + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,6 +113,7 @@ def retrieve_default( return artifacts._retrieve_default_content_type( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 706ae56bda..51ad335ffe 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -92,6 +92,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -105,6 +106,8 @@ def retrieve_default( retrieve the default deserializer. (Default: None). model_version (str): The version of the model for which to retrieve the default deserializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -131,6 +134,7 @@ def retrieve_default( return artifacts._retrieve_default_deserializer( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 69b8bfb51a..c9edeb2e76 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -31,6 +31,7 @@ def _retrieve_model_init_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def _retrieve_model_init_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -66,6 +69,7 @@ def _retrieve_model_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -85,6 +89,7 @@ def _retrieve_model_deploy_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -99,6 +104,8 @@ def _retrieve_model_deploy_kwargs( kwargs. instance_type (str): Instance type of the hosting endpoint, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). 
region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -123,6 +130,7 @@ def _retrieve_model_deploy_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 8d599c89cc..2b24a8ac8c 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -72,6 +72,7 @@ def _retrieve_deserializer_from_accept_type( def _retrieve_default_deserializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -84,6 +85,8 @@ def _retrieve_default_deserializer( retrieve the default deserializer. model_version (str): Version of the JumpStart model for which to retrieve the default deserializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default deserializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -104,6 +107,7 @@ def _retrieve_default_deserializer( default_accept_type = _retrieve_default_accept_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -116,6 +120,7 @@ def _retrieve_default_deserializer( def _retrieve_default_serializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -128,6 +133,8 @@ def _retrieve_default_serializer( retrieve the default serializer. model_version (str): Version of the JumpStart model for which to retrieve the default serializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default serializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -147,6 +154,7 @@ def _retrieve_default_serializer( default_content_type = _retrieve_default_content_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -273,6 +281,7 @@ def _retrieve_serializer_options( def _retrieve_default_content_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -285,6 +294,8 @@ def _retrieve_default_content_type( retrieve the default content type. model_version (str): Version of the JumpStart model for which to retrieve the default content type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). 
region (Optional[str]): Region for which to retrieve default content type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -307,6 +318,7 @@ def _retrieve_default_content_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -321,6 +333,7 @@ def _retrieve_default_content_type( def _retrieve_default_accept_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -333,6 +346,8 @@ def _retrieve_default_accept_type( retrieve the default accept type. model_version (str): Version of the JumpStart model for which to retrieve the default accept type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default accept type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -355,6 +370,7 @@ def _retrieve_default_accept_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c04ae88ca3..f36cf1272f 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -107,6 +107,7 @@ def _retrieve_script_uri( def _model_supports_inference_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -119,6 +120,8 @@ def _model_supports_inference_script_uri( retrieve the support status for script uri with inference. model_version (str): Version of the JumpStart model for which to retrieve the support status for script uri with inference. + hub_arn (Optional[str]): The ARN of the SageMaker Hub from which to retrieve + model details. region (Optional[str]): Region for which to retrieve the support status for script uri with inference. 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -142,6 +145,7 @@ def _model_supports_inference_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py index 5c7a91382b..7758277ee1 100644 --- a/src/sagemaker/jumpstart/curated_hub/utils.py +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -90,15 +90,15 @@ def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: # TODO: Update to recognize JumpStartHub hub_name -def generate_hub_arn_for_estimator_init_kwargs( +def generate_hub_arn_for_init_kwargs( hub_name: str, region: Optional[str] = None, session: Optional[Session] = None ): - """Generates the Hub Arn for JumpStartEstimator from a HubName or Arn. + """Generates the Hub Arn for JumpStart class args from a HubName or Arn. 
Args: - hub_name (str): HubName or HubArn from JumpStartEstimator args - region (str): Region from JumpStartEstimator args - session (Session): Custom SageMaker Session from JumpStartEstimator args + hub_name (str): HubName or HubArn from JumpStart class args + region (str): Region from JumpStart class args + session (Session): Custom SageMaker Session from JumpStart class args """ hub_arn = None diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 53d12b46a6..d0706a9aab 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,7 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION -from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_estimator_init_kwargs +from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -523,7 +523,7 @@ def _is_valid_model_id_hook(): hub_arn = None if hub_name: - hub_arn = generate_hub_arn_for_estimator_init_kwargs( + hub_arn = generate_hub_arn_for_init_kwargs( hub_name=hub_name, region=region, session=sagemaker_session ) @@ -1081,6 +1081,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1dfe9ef5e2..a586081981 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -68,6 +68,7 @@ def get_default_predictor( predictor: Predictor, model_id: str, model_version: str, + hub_arn: 
Optional[str], region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, @@ -89,6 +90,7 @@ def get_default_predictor( predictor.serializer = serializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -97,6 +99,7 @@ def get_default_predictor( predictor.deserializer = deserializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -105,6 +108,7 @@ def get_default_predictor( predictor.accept = accept_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -113,6 +117,7 @@ def get_default_predictor( predictor.content_type = content_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -183,6 +188,7 @@ def _add_instance_type_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -208,6 +214,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel image_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -224,6 +231,7 @@ def 
_add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -261,6 +269,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -270,6 +279,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode script_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -289,6 +299,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -313,6 +324,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -361,6 +373,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI model_kwargs_to_add = _retrieve_model_init_kwargs( 
model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -396,6 +409,7 @@ def _add_endpoint_name_to_kwargs( default_endpoint_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -417,6 +431,7 @@ def _add_model_name_to_kwargs( default_model_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -436,6 +451,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -460,6 +476,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -660,6 +677,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, instance_type: Optional[str] = None, @@ -691,6 +709,7 @@ def get_init_kwargs( model_init_kwargs: 
JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, image_uri=image_uri, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 1742f860e4..68a5174b41 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -21,6 +21,7 @@ from sagemaker.base_serializers import BaseSerializer from sagemaker.explainer.explainer_config import ExplainerConfig from sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG from sagemaker.jumpstart.factory.model import ( @@ -58,6 +59,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -95,6 +97,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with @@ -284,12 +287,19 @@ def _is_valid_model_id_hook(): if not _is_valid_model_id_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + self._model_data_is_set = model_data is not None model_init_kwargs = get_init_kwargs( model_id=model_id, model_from_estimator=False, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -319,6 +329,7 @@ def _is_valid_model_id_hook(): self.model_id = model_init_kwargs.model_id self.model_version = model_init_kwargs.model_version + self.hub_arn = model_init_kwargs.hub_arn self.instance_type = model_init_kwargs.instance_type self.resources = model_init_kwargs.resources self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model @@ -599,6 +610,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 22cc70fcab..68ec952ec7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1049,6 +1049,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -1079,6 +1080,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "instance_type", "model_id", "model_version", + "hub_arn", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", @@ -1090,6 +1092,7 @@ def __init__( self, 
model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, instance_type: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1119,6 +1122,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.region = region self.image_uri = image_uri diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 42c2af0917..aaa1c1d797 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -39,6 +39,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, ) -> Predictor: @@ -56,6 +57,8 @@ def retrieve_default( retrieve the default predictor. (Default: None). model_version (str): The version of the model for which to retrieve the default predictor. (Default: None). + hub_arn (str): The ARN of the SageMaker Hub from which to retrieve + model details. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -103,6 +106,7 @@ def retrieve_default( predictor=predictor, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index fc76c0fa76..7f1d9413b0 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -91,6 +91,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -104,6 +105,8 @@ def retrieve_default( retrieve the default serializer. (Default: None). model_version (str): The version of the model for which to retrieve the default serializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -130,6 +133,7 @@ def retrieve_default( return artifacts._retrieve_default_serializer( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 2f0841b4ea..892a2ed980 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -96,7 +96,7 @@ def test_construct_hub_model_arn_from_inputs(): ) -def test_generate_hub_arn_for_estimator_init_kwargs(): +def test_generate_hub_arn_for_init_kwargs(): hub_name = "my-hub-name" hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" # Mock default session with default values @@ -109,45 +109,34 @@ def test_generate_hub_arn_for_estimator_init_kwargs(): mock_custom_session.boto_region_name = "us-east-2" assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, session=mock_default_session) + utils.generate_hub_arn_for_init_kwargs(hub_name, session=mock_default_session) == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs( - hub_name, "us-east-1", session=mock_default_session - ) + utils.generate_hub_arn_for_init_kwargs(hub_name, "us-east-1", session=mock_default_session) == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + utils.generate_hub_arn_for_init_kwargs(hub_name, "eu-west-1", mock_custom_session) == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, None, mock_custom_session) + utils.generate_hub_arn_for_init_kwargs(hub_name, None, mock_custom_session) == 
"arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" ) - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, session=mock_default_session) - == hub_arn - ) + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, session=mock_default_session) == hub_arn assert ( - utils.generate_hub_arn_for_estimator_init_kwargs( - hub_arn, "us-east-1", session=mock_default_session - ) + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) == hub_arn ) assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", mock_custom_session) - == hub_arn + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn ) - assert ( - utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) - == hub_arn - ) + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index c8fc541816..d88961ebb7 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1322,6 +1322,7 @@ def test_no_predictor_returns_default_predictor( predictor=default_predictor, model_id=model_id, model_version="*", + hub_arn=None, region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index f45283935b..e7c00887fd 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -642,6 +642,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): assert js_class_init_args - parent_class_init_args == { "model_id", "model_version", + "hub_name", "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", @@ -713,6 +714,7 @@ def 
test_no_predictor_returns_default_predictor( predictor=default_predictor, model_id=model_id, model_version="*", + hub_arn=None, region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 7ab9cdd1cc..3cc2314a59 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -88,6 +88,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( predictor=patched_predictor.return_value, model_id="predictor-specs-model", model_version="1.2.3", + hub_arn=None, region=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, From 5ecf9e4dfc1dc6e6fdc25af8cf9203cb99676503 Mon Sep 17 00:00:00 2001 From: Benjamin Crabtree Date: Wed, 28 Feb 2024 17:50:45 +0000 Subject: [PATCH 2/2] update for resource requirements and model package --- src/sagemaker/accept_types.py | 26 +++++++++++-------- src/sagemaker/content_types.py | 26 +++++++++++-------- src/sagemaker/deserializers.py | 26 +++++++++++-------- .../jumpstart/artifacts/model_packages.py | 4 +++ .../jumpstart/artifacts/predictors.py | 16 ++++++++++++ .../artifacts/resource_requirements.py | 4 +++ src/sagemaker/jumpstart/factory/model.py | 4 +++ src/sagemaker/jumpstart/model.py | 1 + src/sagemaker/jumpstart/types.py | 4 +++ src/sagemaker/resource_requirements.py | 16 +++++++----- src/sagemaker/serializers.py | 26 +++++++++++-------- 11 files changed, 103 insertions(+), 50 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 3f8d3171f3..14212fd991 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -23,6 +23,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool 
= False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -36,6 +37,8 @@ def retrieve_options( retrieve the supported accept types. (Default: None). model_version (str): The version of the model for which to retrieve the supported accept types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -59,11 +62,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_accept_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -111,11 +115,11 @@ def retrieve_default( ) return artifacts._retrieve_default_accept_type( - model_id, - model_version, - hub_arn, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 9b70ecce96..5e82201c31 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -23,6 +23,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, 
@@ -36,6 +37,8 @@ def retrieve_options( retrieve the supported content types. (Default: None). model_version (str): The version of the model for which to retrieve the supported content types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -59,11 +62,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_content_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -111,12 +115,12 @@ def retrieve_default( ) return artifacts._retrieve_default_content_type( - model_id, - model_version, - hub_arn, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 51ad335ffe..7bb08ce15a 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -42,6 +42,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -55,6 +56,8 @@ def retrieve_options( retrieve the supported 
deserializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported deserializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -79,11 +82,12 @@ def retrieve_options( ) return artifacts._retrieve_deserializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -132,11 +136,11 @@ def retrieve_default( ) return artifacts._retrieve_default_deserializer( - model_id, - model_version, - hub_arn, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index e49d14682d..4a0fc147d5 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -31,6 +31,7 @@ def _retrieve_model_package_arn( model_version: str, instance_type: Optional[str], region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def _retrieve_model_package_arn( instance_type (Optional[str]): An instance type to 
optionally supply in order to get an arn specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -69,6 +72,7 @@ def _retrieve_model_package_arn( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 2b24a8ac8c..35fe4e3dcf 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -167,6 +167,7 @@ def _retrieve_default_serializer( def _retrieve_deserializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -179,6 +180,8 @@ def _retrieve_deserializer_options( retrieve the supported deserializers. model_version (str): Version of the JumpStart model for which to retrieve the supported deserializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve deserializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -198,6 +201,7 @@ def _retrieve_deserializer_options( supported_accept_types = _retrieve_supported_accept_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -224,6 +228,7 @@ def _retrieve_deserializer_options( def _retrieve_serializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -236,6 +241,8 @@ def _retrieve_serializer_options( retrieve the supported serializers. model_version (str): Version of the JumpStart model for which to retrieve the supported serializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve serializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -255,6 +262,7 @@ def _retrieve_serializer_options( supported_content_types = _retrieve_supported_content_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -386,6 +394,7 @@ def _retrieve_default_accept_type( def _retrieve_supported_accept_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -398,6 +407,8 @@ def _retrieve_supported_accept_types( retrieve the supported accept types. model_version (str): Version of the JumpStart model for which to retrieve the supported accept types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). 
region (Optional[str]): Region for which to retrieve accept type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -420,6 +431,7 @@ def _retrieve_supported_accept_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -435,6 +447,7 @@ def _retrieve_supported_accept_types( def _retrieve_supported_content_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -447,6 +460,8 @@ def _retrieve_supported_content_types( retrieve the supported content types. model_version (str): Version of the JumpStart model for which to retrieve the supported content types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve content type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -469,6 +484,7 @@ def _retrieve_supported_content_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 8356d1efac..ecf6d1b5ea 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -33,6 +33,7 @@ def _retrieve_default_resources( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_default_resources( default resource requirements. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve default resource requirements. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -76,6 +79,7 @@ def _retrieve_default_resources( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index a586081981..0e1dbfe07d 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -355,6 +355,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, @@ -595,6 +596,7 @@ def get_deploy_kwargs( def get_register_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -626,6 +628,7 @@ def get_register_kwargs( register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -656,6 +659,7 @@ def get_register_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 68a5174b41..c0da00ac56 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -696,6 +696,7 @@ def register( register_kwargs = get_register_kwargs( 
model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 68ec952ec7..99753a3763 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1657,6 +1657,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", "content_types", "response_types", @@ -1687,6 +1688,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -1694,6 +1696,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -1724,6 +1727,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.region = region self.image_uri = image_uri self.sagemaker_session = sagemaker_session diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index 446d034bf3..f0be00ea09 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def retrieve_default( retrieve the default resource requirements. (Default: None). model_version (str): The version of the model for which to retrieve the default resource requirements. (Default: None). 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -72,11 +75,12 @@ def retrieve_default( raise ValueError("Must specify scope for resource requirements.") return artifacts._retrieve_default_resources( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope=scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 7f1d9413b0..43b5a9fa34 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -41,6 +41,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -54,6 +55,8 @@ def retrieve_options( retrieve the supported serializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported serializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -78,11 +81,12 @@ def retrieve_options( ) return artifacts._retrieve_serializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -131,11 +135,11 @@ def retrieve_default( ) return artifacts._retrieve_default_serializer( - model_id, - model_version, - hub_arn, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, )