From 63431d5d2131e481f27e20d586d0ca957a6067bc Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 7 Aug 2023 17:21:17 +0000 Subject: [PATCH 01/13] fix: jumpstart cache using sagemaker session s3 client --- src/sagemaker/accept_types.py | 13 ++ src/sagemaker/content_types.py | 13 ++ src/sagemaker/deserializers.py | 13 ++ src/sagemaker/environment_variables.py | 7 + src/sagemaker/hyperparameters.py | 24 ++++ src/sagemaker/image_uris.py | 7 + src/sagemaker/instance_types.py | 13 ++ src/sagemaker/jumpstart/accessors.py | 47 ++++++- .../artifacts/environment_variables.py | 8 +- .../jumpstart/artifacts/hyperparameters.py | 7 + .../jumpstart/artifacts/image_uris.py | 8 +- .../artifacts/incremental_training.py | 7 + .../jumpstart/artifacts/instance_types.py | 13 ++ src/sagemaker/jumpstart/artifacts/kwargs.py | 27 +++- .../jumpstart/artifacts/metric_definitions.py | 8 +- .../jumpstart/artifacts/model_packages.py | 14 +- .../jumpstart/artifacts/model_uris.py | 13 ++ .../jumpstart/artifacts/predictors.py | 56 +++++++- .../jumpstart/artifacts/resource_names.py | 7 + .../jumpstart/artifacts/script_uris.py | 13 ++ src/sagemaker/jumpstart/cache.py | 4 +- src/sagemaker/jumpstart/estimator.py | 3 + src/sagemaker/jumpstart/factory/estimator.py | 17 ++- src/sagemaker/jumpstart/factory/model.py | 25 +++- src/sagemaker/jumpstart/model.py | 3 + src/sagemaker/jumpstart/types.py | 9 ++ src/sagemaker/jumpstart/utils.py | 20 ++- src/sagemaker/jumpstart/validators.py | 16 +++ src/sagemaker/metric_definitions.py | 13 +- src/sagemaker/model_uris.py | 7 + src/sagemaker/predictor.py | 5 +- src/sagemaker/script_uris.py | 7 + src/sagemaker/serializers.py | 13 ++ .../jumpstart/test_accept_types.py | 13 +- .../jumpstart/test_content_types.py | 13 +- .../jumpstart/test_deserializers.py | 13 +- .../jumpstart/test_default.py | 32 +++-- .../hyperparameters/jumpstart/test_default.py | 31 ++++- .../jumpstart/test_validate.py | 21 ++- .../image_uris/jumpstart/test_common.py | 22 ++- .../jumpstart/test_instance_types.py | 47 +++++-- .../jumpstart/estimator/test_estimator.py | 7 + .../sagemaker/jumpstart/model/test_model.py | 5 + tests/unit/sagemaker/jumpstart/test_utils.py | 129 +++++++++++------- tests/unit/sagemaker/jumpstart/utils.py | 12 +- .../jumpstart/test_default.py | 16 ++- .../model_uris/jumpstart/test_common.py | 21 ++- .../script_uris/jumpstart/test_common.py | 18 ++- .../serializers/jumpstart/test_serializers.py | 19 ++- 49 files changed, 749 insertions(+), 130 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 22f41812ca..7fb3aaa441 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -15,6 +15,7 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.session import Session def retrieve_options( @@ -23,6 +24,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported accept types for the model matching the given arguments. @@ -40,6 +42,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: The supported accept types to use for the model. @@ -57,6 +63,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -66,6 +73,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -83,6 +91,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The default accept type to use for the model. @@ -100,4 +112,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 55655c4065..9bac693e9a 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -15,6 +15,7 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.session import Session def retrieve_options( @@ -23,6 +24,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported content types for the model matching the given arguments. @@ -40,6 +42,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: The supported content types to use for the model. @@ -57,6 +63,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -66,6 +73,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -83,6 +91,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The default content type to use for the model. @@ -100,6 +112,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a43c98589..0cb6ed4180 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -33,6 +33,7 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.session import Session def retrieve_options( @@ -41,6 +42,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model matching the given arguments. @@ -58,6 +60,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: List[BaseDeserializer]: The supported deserializers to use for the model. @@ -76,6 +82,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -85,6 +92,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -102,6 +110,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -120,4 +132,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 864bffc5bf..e0e6daa0cb 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, + sagemaker_session: Session = Session(), ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -51,6 +53,10 @@ def retrieve_default( should be included. The `Model` class of the SageMaker Python SDK inserts environment variables that would be required when making the low-level AWS API call. (Default: True). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: The variables to use for the model. @@ -70,4 +76,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, include_aws_sdk_env_vars, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 7fa4a14414..46f43f2e85 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -21,6 +21,7 @@ from sagemaker.jumpstart import artifacts from sagemaker.jumpstart.enums import HyperparameterValidationMode from sagemaker.jumpstart.validators import validate_hyperparameters +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -32,6 +33,7 @@ def retrieve_default( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -56,6 +58,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: The hyperparameters to use for the model. @@ -74,6 +80,7 @@ def retrieve_default( include_container_hyperparameters, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -83,6 +90,9 @@ def validate( model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> None: """Validates hyperparameters for models. @@ -100,6 +110,17 @@ def validate( If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. (Default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Raises: JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, @@ -125,4 +146,7 @@ def validate( hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index fc28d91c9f..72947a68e3 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -22,6 +22,7 @@ from sagemaker import utils from sagemaker.jumpstart.utils import is_jumpstart_model_input +from sagemaker.session import Session from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts from sagemaker.workflow import is_pipeline_variable @@ -60,6 +61,7 @@ def retrieve( sdk_version=None, inference_tool=None, serverless_inference_config=None, + sagemaker_session=Session(), ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -109,6 +111,10 @@ def retrieve( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -147,6 +153,7 @@ def retrieve( training_compiler_config, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 436ad7bb3f..7639bd8d59 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -49,6 +51,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The default instance type to use for the model. @@ -70,6 +76,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -80,6 +87,7 @@ def retrieve( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported training instance types for the model matching the given arguments. @@ -97,6 +105,10 @@ def retrieve( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: The supported instance types to use for the model. @@ -118,4 +130,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e07564d362..4292130311 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -13,6 +13,7 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional +import boto3 from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs @@ -74,14 +75,24 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: region (str): region for which to retrieve header/spec. cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``. """ - if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region: + new_cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + cache_kwargs, region + ) + if ( + JumpStartModelsAccessor._cache is None + or region != JumpStartModelsAccessor._curr_region + or new_cache_kwargs != JumpStartModelsAccessor._cache_kwargs + ): JumpStartModelsAccessor._cache = cache.JumpStartModelsCache( region=region, **cache_kwargs ) JumpStartModelsAccessor._curr_region = region + JumpStartModelsAccessor._cache_kwargs = new_cache_kwargs @staticmethod - def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]: + def _get_manifest( + region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. Raises: @@ -90,11 +101,20 @@ def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStart Args: region (str): Optional. The region to use for the cache. """ + + old_cache_kwargs = JumpStartModelsAccessor._cache_kwargs.copy() + + additional_kwargs = {} + if s3_client is not None: + additional_kwargs.update({"s3_client": s3_client}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - JumpStartModelsAccessor._cache_kwargs, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + manifest = JumpStartModelsAccessor._cache.get_manifest() # type: ignore + JumpStartModelsAccessor._cache_kwargs = old_cache_kwargs + return manifest @staticmethod def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: @@ -114,21 +134,34 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) @staticmethod - def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs: + def get_model_specs( + region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. Args: region (str): region for which to retrieve header. model_id (str): model ID to retrieve. version (str): semantic version to retrieve for the model ID. + s3_client (boto3.client): boto3 client to use for accessing JumpStart models s3 cache. + If not set, a default client will be made. """ + + old_cache_kwargs = JumpStartModelsAccessor._cache_kwargs.copy() + + additional_kwargs = {} + if s3_client is not None: + additional_kwargs.update({"s3_client": s3_client}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - JumpStartModelsAccessor._cache_kwargs, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - return JumpStartModelsAccessor._cache.get_specs( # type: ignore + specs = JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) + JumpStartModelsAccessor._cache_kwargs = old_cache_kwargs + return specs @staticmethod def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None: diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index a6c2ba0f58..4818f6c430 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_environment_variables( @@ -31,6 +32,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, + sagemaker_session: Session = Session(), ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -52,7 +54,10 @@ def _retrieve_default_environment_variables( should be included. The `Model` class of the SageMaker Python SDK inserts environment variables that would be required when making the low-level AWS API call. (Default: True). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the inference environment variables to use for the model. """ @@ -67,6 +72,7 @@ def _retrieve_default_environment_variables( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_environment_variables: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index fc3a93212c..403c94a060 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -23,6 +23,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_hyperparameters( @@ -32,6 +33,7 @@ def _retrieve_default_hyperparameters( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -56,6 +58,10 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the hyperparameters to use for the model. """ @@ -70,6 +76,7 @@ def _retrieve_default_hyperparameters( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 42819b6eec..d9f4917c42 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_image_uri( @@ -43,6 +44,7 @@ def _retrieve_image_uri( training_compiler_config: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ): """Retrieves the container image URI for JumpStart models. @@ -88,7 +90,10 @@ def _retrieve_image_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -108,6 +113,7 @@ def _retrieve_image_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if image_scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index fc78ef7b70..ba32e54764 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _model_supports_incremental_training( @@ -30,6 +31,7 @@ def _model_supports_incremental_training( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> bool: """Returns True if the model supports incremental training. @@ -47,6 +49,10 @@ def _model_supports_incremental_training( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: bool: the support status for incremental training. """ @@ -61,6 +67,7 @@ def _model_supports_incremental_training( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 841125bbbe..3cdf51073f 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -25,6 +25,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_instance_type( @@ -34,6 +35,7 @@ def _retrieve_default_instance_type( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default instance type for the model. @@ -53,6 +55,10 @@ def _retrieve_default_instance_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the default instance type to use for the model or None. @@ -71,6 +77,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: @@ -94,6 +101,7 @@ def _retrieve_instance_types( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported instance types for the model. @@ -113,6 +121,10 @@ def _retrieve_instance_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: the supported instance types to use for the model or None. @@ -131,6 +143,7 @@ def _retrieve_instance_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index b2db7f0f80..41855ad598 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from copy import deepcopy from typing import Optional +from sagemaker.session import Session from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, @@ -32,6 +33,7 @@ def _retrieve_model_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> dict: """Retrieves kwargs for `Model`. @@ -49,7 +51,10 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the kwargs to use for the use case. """ @@ -64,6 +69,7 @@ def _retrieve_model_init_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -81,6 +87,7 @@ def _retrieve_model_deploy_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -100,6 +107,10 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the kwargs to use for the use case. @@ -115,6 +126,7 @@ def _retrieve_model_deploy_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -130,6 +142,7 @@ def _retrieve_estimator_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> dict: """Retrieves kwargs for `Estimator`. @@ -149,7 +162,10 @@ def _retrieve_estimator_init_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the kwargs to use for the use case. """ @@ -164,6 +180,7 @@ def _retrieve_estimator_init_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -183,6 +200,7 @@ def _retrieve_estimator_fit_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -200,6 +218,10 @@ def _retrieve_estimator_fit_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: dict: the kwargs to use for the use case. @@ -215,6 +237,7 @@ def _retrieve_estimator_fit_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 6921cb8473..f32d82952f 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -23,6 +23,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_training_metric_definitions( @@ -31,6 +32,7 @@ def _retrieve_default_training_metric_definitions( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -48,7 +50,10 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: the default training metric definitions to use for the model or None. """ @@ -63,6 +68,7 @@ def _retrieve_default_training_metric_definitions( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return deepcopy(model_specs.metrics) if model_specs.metrics else None diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 0fbf1d74b3..67ac107091 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) +from sagemaker.session import Session def _retrieve_model_package_arn( @@ -31,6 +32,7 @@ def _retrieve_model_package_arn( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -48,6 +50,10 @@ def _retrieve_model_package_arn( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the model package arn to use for the model or None. @@ -63,6 +69,7 @@ def _retrieve_model_package_arn( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: @@ -84,6 +91,7 @@ def _retrieve_model_package_model_artifact_s3_uri( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -103,7 +111,10 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the model package artifact uri to use for the model or None. @@ -123,6 +134,7 @@ def _retrieve_model_package_model_artifact_s3_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 31505e0865..4f5f5ab6c2 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -26,6 +26,7 @@ get_jumpstart_content_bucket, verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_model_uri( @@ -35,6 +36,7 @@ def _retrieve_model_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -55,6 +57,10 @@ def _retrieve_model_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the model artifact S3 URI for the corresponding model. @@ -74,6 +80,7 @@ def _retrieve_model_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if model_scope == JumpStartScriptScope.INFERENCE: @@ -100,6 +107,7 @@ def _model_supports_training_model_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> bool: """Returns True if the model supports training with model uri field. @@ -117,6 +125,10 @@ def _model_supports_training_model_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: bool: the support status for model uri with training. """ @@ -131,6 +143,7 @@ def _model_supports_training_model_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 473739de0d..cd75b5b358 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -29,6 +29,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_serializer_from_content_type( @@ -73,6 +74,7 @@ def _retrieve_default_deserializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -89,6 +91,10 @@ def _retrieve_default_deserializer( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -100,6 +106,7 @@ def _retrieve_default_deserializer( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -111,6 +118,7 @@ def _retrieve_default_serializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -127,7 +135,10 @@ def _retrieve_default_serializer( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -138,6 +149,7 @@ def _retrieve_default_serializer( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -149,6 +161,7 @@ def _retrieve_deserializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -165,7 +178,10 @@ def _retrieve_deserializer_options( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -176,6 +192,7 @@ def _retrieve_deserializer_options( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) seen_classes: Set[Type] = set() @@ -201,6 +218,7 @@ def _retrieve_serializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -217,7 +235,10 @@ def _retrieve_serializer_options( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -228,6 +249,7 @@ def _retrieve_serializer_options( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) seen_classes: Set[Type] = set() @@ -253,6 +275,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default content type for the model. @@ -269,7 +292,10 @@ def _retrieve_default_content_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the default content type to use for the model. """ @@ -284,6 +310,7 @@ def _retrieve_default_content_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -296,6 +323,7 @@ def _retrieve_default_accept_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the default accept type for the model. @@ -312,7 +340,10 @@ def _retrieve_default_accept_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the default accept type to use for the model. """ @@ -327,6 +358,7 @@ def _retrieve_default_accept_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -340,6 +372,7 @@ def _retrieve_supported_accept_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported accept types for the model. @@ -356,7 +389,10 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: the supported accept types to use for the model. """ @@ -371,6 +407,7 @@ def _retrieve_supported_accept_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -384,6 +421,7 @@ def _retrieve_supported_content_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[str]: """Retrieves the supported content types for the model. @@ -400,7 +438,10 @@ def _retrieve_supported_content_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: the supported content types to use for the model. """ @@ -415,6 +456,7 @@ def _retrieve_supported_content_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index ca20cabdda..908e18b7ae 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -22,6 +22,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_resource_name_base( @@ -30,6 +31,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> bool: """Returns default resource name. @@ -47,6 +49,10 @@ def _retrieve_resource_name_base( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the default resource name. """ @@ -61,6 +67,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f1bc4f31dd..ceba059b60 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -25,6 +25,7 @@ get_jumpstart_content_bucket, verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_script_uri( @@ -34,6 +35,7 @@ def _retrieve_script_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -55,6 +57,10 @@ def _retrieve_script_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: the model script URI for the corresponding model. @@ -74,6 +80,7 @@ def _retrieve_script_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -98,6 +105,7 @@ def _model_supports_inference_script_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> bool: """Returns True if the model supports inference with script uri field. @@ -115,6 +123,10 @@ def _model_supports_inference_script_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: bool: the support status for script uri with inference. """ @@ -129,6 +141,7 @@ def _model_supports_inference_script_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 0a7a29a8cd..fe3e2224c8 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -68,6 +68,7 @@ def __init__( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, + s3_client: Optional[boto3.client] = None, ) -> None: # fmt: on """Initialize a ``JumpStartModelsCache`` instance. @@ -88,6 +89,7 @@ def __init__( Default: JumpStart-hosted content bucket for region. s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). + s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ self._region = region @@ -109,7 +111,7 @@ def __init__( if s3_bucket_name is None else s3_bucket_name ) - self._s3_client = ( + self._s3_client = s3_client if s3_client else ( boto3.client("s3", region_name=self._region, config=s3_client_config) if s3_client_config else boto3.client("s3", region_name=self._region) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 129692262b..3948bf5775 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -502,6 +502,7 @@ def _is_valid_model_id_hook(): model_version=model_version, region=region, script=JumpStartScriptScope.TRAINING, + sagemaker_session=sagemaker_session, ) if not _is_valid_model_id_hook(): @@ -649,6 +650,7 @@ def fit( experiment_config=experiment_config, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, + sagemaker_session=self.sagemaker_session, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -991,6 +993,7 @@ def deploy( region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 92fccb39b0..612c40f7f5 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -208,6 +208,7 @@ def get_fit_kwargs( experiment_config: Optional[Dict[str, str]] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -222,6 +223,7 @@ def get_fit_kwargs( experiment_config=experiment_config, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -298,6 +300,7 @@ def get_deploy_kwargs( explainer_config=explainer_config, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -314,7 +317,7 @@ def get_deploy_kwargs( role=role, name=model_name, vpc_config=vpc_config, - sagemaker_session=sagemaker_session, + sagemaker_session=model_deploy_kwargs.sagemaker_session, enable_network_isolation=enable_network_isolation, model_kms_key=model_kms_key, image_config=image_config, @@ -420,6 +423,7 @@ def _add_instance_type_and_count_to_kwargs( scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -444,6 +448,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -458,6 +463,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, @@ -465,6 +471,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if ( @@ -476,6 +483,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) ): JUMPSTART_LOGGER.warning( @@ -509,6 +517,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -526,6 +535,7 @@ def _add_env_to_kwargs( scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if model_package_artifact_uri: @@ -560,6 +570,7 @@ def _add_training_job_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.job_name = kwargs.job_name or ( @@ -584,6 +595,7 @@ def _add_hyperparameters_to_kwargs( model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in default_hyperparameters.items(): @@ -615,6 +627,7 @@ def _add_metric_definitions_to_kwargs( model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) or [] ) @@ -643,6 +656,7 @@ def _add_estimator_extra_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in estimator_kwargs_to_add.items(): @@ -666,6 +680,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in fit_kwargs_to_add.items(): diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ac5640ff24..5455179831 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -59,6 +59,7 @@ def get_default_predictor( region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, + sagemaker_session: Session, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -79,6 +80,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -86,6 +88,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -93,6 +96,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -100,6 +104,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) return predictor @@ -113,7 +118,9 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni return kwargs -def _add_sagemaker_session_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: +def _add_sagemaker_session_to_kwargs( + kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] +) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" kwargs.sagemaker_session = kwargs.sagemaker_session or Session() return kwargs @@ -165,6 +172,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if orig_instance_type is None: @@ -188,6 +196,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -205,6 +214,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -221,6 +231,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -229,6 +240,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.source_dir = source_dir @@ -247,6 +259,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -271,6 +284,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in extra_env_vars.items(): @@ -298,6 +312,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.model_package_arn = model_package_arn @@ -313,6 +328,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in model_kwargs_to_add.items(): @@ -347,6 +363,7 @@ def _add_endpoint_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -367,6 +384,7 @@ def _add_model_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.name = kwargs.name or ( @@ -386,6 +404,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in deploy_kwargs_to_add.items(): @@ -418,6 +437,7 @@ def get_deploy_kwargs( explainer_config: Optional[ExplainerConfig] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -444,8 +464,11 @@ def get_deploy_kwargs( explainer_config=explainer_config, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) + deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index c6b663c0fa..76e2816879 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -261,6 +261,7 @@ def _is_valid_model_id_hook(): model_version=model_version, region=region, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=sagemaker_session, ) if not _is_valid_model_id_hook(): @@ -307,6 +308,7 @@ def _is_valid_model_id_hook(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.model_package_arn = model_init_kwargs.model_package_arn + self.sagemaker_session = model_init_kwargs.sagemaker_session super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) @@ -488,6 +490,7 @@ def deploy( region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 839dff7474..e643a2cda1 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -16,6 +16,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.session import Session + class JumpStartDataHolderType: """Base class for many JumpStart types. @@ -698,6 +700,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "explainer_config", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "sagemaker_session", ] SERIALIZATION_EXCLUSION_SET = { @@ -706,6 +709,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", + "sagemaker_session", } def __init__( @@ -732,6 +736,7 @@ def __init__( explainer_config: Optional[Any] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -757,6 +762,7 @@ def __init__( self.explainer_config = explainer_config self.tolerate_vulnerable_model = tolerate_vulnerable_model self.tolerate_deprecated_model = tolerate_deprecated_model + self.sagemaker_session = sagemaker_session class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -949,6 +955,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "experiment_config", "tolerate_deprecated_model", "tolerate_vulnerable_model", + "sagemaker_session", ] SERIALIZATION_EXCLUSION_SET = { @@ -972,6 +979,7 @@ def __init__( experiment_config: Optional[Dict[str, str]] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -985,6 +993,7 @@ def __init__( self.experiment_config = experiment_config self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model + self.sagemaker_session = sagemaker_session class JumpStartEstimatorDeployKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index caac3b9c2f..efa9bf34d2 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -381,6 +381,7 @@ def verify_model_region_and_return_specs( region: str, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -399,7 +400,10 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Raises: NotImplementedError: If the scope is not supported. @@ -421,8 +425,11 @@ def verify_model_region_and_return_specs( f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=version # type: ignore + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore + region=region, + model_id=model_id, + version=version, + s3_client=sagemaker_session.s3_client, ) if ( @@ -575,6 +582,7 @@ def is_valid_model_id( region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, + sagemaker_session: Optional[Session] = Session(), ) -> bool: """Returns True if the model ID is supported for the given script. @@ -586,10 +594,13 @@ def is_valid_model_id( if not isinstance(model_id, str): return False + s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client + ) model_id_set = {model.model_id for model in models_manifest_list} if script == enums.JumpStartScriptScope.INFERENCE: return model_id in model_id_set @@ -600,6 +611,7 @@ def is_valid_model_id( region=region, model_id=model_id, version=model_version, + s3_client=s3_client, ).training_supported ) raise ValueError(f"Unsupported script: {script}") diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index fecd02a2eb..539cb21cc8 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -13,6 +13,7 @@ """This module contains validators related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional +from sagemaker import session from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import ( @@ -168,6 +169,9 @@ def validate_hyperparameters( hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Optional[session.Session] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> None: """Validate hyperparameters for JumpStart models. @@ -182,6 +186,15 @@ def validate_hyperparameters( If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. region (str): Region for which to validate hyperparameters. (Default: JumpStart default region). + sagemaker_session (Optional[Session]): Custom SageMaker Session to use. + (Default: Session()). + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -200,6 +213,9 @@ def validate_hyperparameters( version=model_version, region=region, scope=JumpStartScriptScope.TRAINING, + sagemaker_session=sagemaker_session, + tolerate_deprecated_model=tolerate_deprecated_model, + tolerate_vulnerable_model=tolerate_vulnerable_model, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index ec6ae387d5..1e8b192e39 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -29,6 +30,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -46,6 +48,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: list: The default metric definitions to use for the model or None. @@ -59,5 +65,10 @@ def retrieve_default( ) return artifacts._retrieve_default_training_metric_definitions( - model_id, model_version, region, tolerate_vulnerable_model, tolerate_deprecated_model + model_id, + model_version, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 982f3ca908..fe3209a39e 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -18,6 +18,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ def retrieve( model_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -48,6 +50,10 @@ def retrieve( tolerate_deprecated_model (bool): ``True`` if deprecated versions of model specifications should be tolerated without raising an exception. If ``False``, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The model artifact S3 URI for the corresponding model. @@ -70,4 +76,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index ac4cb78df1..695a51c5f8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -32,7 +32,7 @@ def retrieve_default( endpoint_name: str, - sagemaker_session: Optional[Session] = None, + sagemaker_session: Session = Session(), region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, @@ -44,7 +44,7 @@ def retrieve_default( Args: endpoint_name (str): Endpoint name for which to create a predictor. sagemaker_session (Session): The SageMaker Session to attach to the Predictor. - (Default: None). + (Default: Session()). region (str): The AWS Region for which to retrieve the default predictor. (Default: None). model_id (str): The model ID of the model for which to @@ -80,4 +80,5 @@ def retrieve_default( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 4bdf64060d..38ff180787 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ def retrieve( script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -48,6 +50,10 @@ def retrieve( tolerate_deprecated_model (bool): ``True`` if deprecated models should be tolerated without raising an exception. ``False`` if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: str: The model script URI for the corresponding model. @@ -70,4 +76,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 54091e71a6..5a1c5aa47b 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -31,6 +31,7 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.session import Session def retrieve_options( @@ -39,6 +40,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -56,6 +58,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -74,6 +80,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -83,6 +90,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = Session(), ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -100,6 +108,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions (Default: None). If not + specified, one is created using the default AWS configuration + chain. (Default: Session()). Returns: SimpleBaseSerializer: The default serializer to use for the model. @@ -118,4 +130,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index f5a4798374..28211d06f1 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -12,14 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import - -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -37,11 +40,12 @@ def test_jumpstart_default_accept_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,6 +65,7 @@ def test_jumpstart_supported_accept_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert supported_accept_types == [ "application/json;verbose", @@ -68,5 +73,5 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index a0076e8a78..4b2db7d7f4 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -12,14 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import - -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -37,11 +40,12 @@ def test_jumpstart_default_content_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,11 +65,12 @@ def test_jumpstart_supported_content_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert supported_content_types == [ "application/x-text", ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 60784cd19e..9d6e2f21de 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -13,7 +13,8 @@ from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs @@ -21,6 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( @@ -37,11 +42,12 @@ def test_jumpstart_default_deserializers( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,6 +67,7 @@ def test_jumpstart_deserializer_options( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert len(deserializer_options) == 1 @@ -72,5 +79,5 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 169bd74758..da9d1559b7 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -13,13 +13,17 @@ from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock import pytest from sagemaker import environment_variables from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_environment_variables(patched_get_model_specs): @@ -30,9 +34,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): region = "us-west-2" vars = environment_variables.retrieve_default( - region=region, - model_id=model_id, - model_version="*", + region=region, model_id=model_id, model_version="*", sagemaker_session=mock_session ) assert vars == { "MODEL_CACHE_ROOT": "/opt/ml/model", @@ -45,14 +47,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() vars = environment_variables.retrieve_default( - region=region, - model_id=model_id, - model_version="1.*", + region=region, model_id=model_id, model_version="1.*", sagemaker_session=mock_session ) assert vars == { "MODEL_CACHE_ROOT": "/opt/ml/model", @@ -65,7 +67,9 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -107,6 +111,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): model_id=model_id, model_version="*", include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, ) assert vars == { "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -116,7 +121,9 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): "SAGEMAKER_PROGRAM": "inference.py", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -125,6 +132,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): model_id=model_id, model_version="1.*", include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, ) assert vars == { "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -134,7 +142,9 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): "SAGEMAKER_PROGRAM": "inference.py", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index b7776b02a9..2d8f7d8166 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -12,8 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import boto3 -from mock.mock import patch +from mock.mock import patch, Mock import pytest from sagemaker import hyperparameters @@ -21,6 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_hyperparameters(patched_get_model_specs): @@ -33,10 +38,16 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, model_version="*", + sagemaker_session=mock_session, ) assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() @@ -44,10 +55,16 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, model_version="1.*", + sagemaker_session=mock_session, ) assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() @@ -56,6 +73,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, model_version="1.*", include_container_hyperparameters=True, + sagemaker_session=mock_session, ) assert params == { "adam-learning-rate": "0.05", @@ -66,7 +84,12 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 83092f74e5..0054ed9dbd 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -13,9 +13,9 @@ from __future__ import absolute_import -from mock.mock import patch +from mock.mock import patch, Mock import pytest - +import boto3 from sagemaker import hyperparameters from sagemaker.jumpstart.enums import HyperparameterValidationMode from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError @@ -23,6 +23,9 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): @@ -129,10 +132,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -144,6 +151,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, + sagemaker_session=mock_session, ) hyperparameter_to_test["batch-size"] = "0" @@ -425,10 +433,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() @@ -440,6 +449,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + sagemaker_session=mock_session, ) del hyperparameter_to_test["adam-learning-rate"] @@ -477,10 +487,11 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index a16145585b..8a41891280 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -12,7 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock import pytest from sagemaker import image_uris @@ -31,6 +32,9 @@ def test_jumpstart_common_image_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + image_uris.retrieve( framework=None, region="us-west-2", @@ -38,9 +42,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -54,9 +62,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="1.*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,11 +82,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -88,11 +102,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="1.*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index b21025968b..f13121aa94 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -11,8 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock - +import boto3 from mock.mock import patch import pytest @@ -29,30 +30,51 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + default_training_instance_types = instance_types.retrieve_default( - region=region, model_id=model_id, model_version=model_version, scope="training" + region=region, + model_id=model_id, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, ) assert default_training_instance_types == "ml.p3.2xlarge" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_inference_instance_types = instance_types.retrieve_default( - region=region, model_id=model_id, model_version=model_version, scope="inference" + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, ) assert default_inference_instance_types == "ml.p2.xlarge" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_training_instance_types = instance_types.retrieve( - region=region, model_id=model_id, model_version=model_version, scope="training" + region=region, + model_id=model_id, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, ) assert default_training_instance_types == [ "ml.p3.2xlarge", @@ -63,13 +85,20 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_inference_instance_types = instance_types.retrieve( - region=region, model_id=model_id, model_version=model_version, scope="inference" + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, ) assert default_inference_instance_types == [ "ml.p2.xlarge", @@ -82,7 +111,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 198839475b..18bc10bcec 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -747,6 +747,7 @@ def test_no_predictor_returns_default_predictor( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=estimator.sagemaker_session, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -903,6 +904,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=sagemaker_session, ) @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @@ -953,6 +955,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=sagemaker_session, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1146,12 +1149,14 @@ def test_model_id_not_found_refeshes_cache_training( model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), ] ) @@ -1172,12 +1177,14 @@ def test_model_id_not_found_refeshes_cache_training( model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index c5197e5399..86bdff0b43 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -416,6 +416,7 @@ def test_no_predictor_returns_default_predictor( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=model.sagemaker_session, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -532,12 +533,14 @@ def test_model_id_not_found_refeshes_cach_inference( model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), ] ) @@ -558,12 +561,14 @@ def test_model_id_not_found_refeshes_cach_inference( model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c375563a94..2bb3c996a2 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -25,12 +25,15 @@ JUMPSTART_RESOURCE_BASE_NAME, JumpStartScriptScope, ) + +from functools import partial from sagemaker.jumpstart.enums import JumpStartTag, MIMEType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, VulnerableJumpStartModelError, ) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec @@ -975,25 +978,39 @@ def test_is_valid_model_id_true( Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_not_called() - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() + mock_session_value = Session() + mock_s3_client_value = mock_session_value.s3_client - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] + patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) - mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, model_id="bee", version="*" - ) + with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + self.assertTrue(utils.is_valid_model_id("bee")) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_not_called() + + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() + + mock_get_manifest.return_value = [ + Mock(model_id="ay"), + Mock(model_id="bee"), + Mock(model_id="see"), + ] + + mock_get_model_specs.return_value = Mock(training_supported=True) + self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + model_id="bee", + version="*", + s3_client=mock_s3_client_value, + ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1003,41 +1020,61 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_not_called() + mock_session_value = Session() + mock_s3_client_value = mock_session_value.s3_client - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() + patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) + with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): - mock_get_manifest.reset_mock() + self.assertFalse(utils.is_valid_model_id("dee")) + self.assertFalse(utils.is_valid_model_id("")) + self.assertFalse(utils.is_valid_model_id(None)) + self.assertFalse(utils.is_valid_model_id(set())) - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) - mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) + mock_get_model_specs.assert_not_called() - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() - mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, model_id="ay", version="*" - ) + mock_get_manifest.return_value = [ + Mock(model_id="ay"), + Mock(model_id="bee"), + Mock(model_id="see"), + ] + self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + + mock_get_manifest.reset_mock() + + self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + + mock_get_model_specs.assert_not_called() + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() + + mock_get_model_specs.return_value = Mock(training_supported=False) + self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + model_id="ay", + version="*", + s3_client=mock_s3_client_value, + ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 1c7b61c554..6be6f8251a 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import copy from typing import List +import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_REGION_NAME_SET @@ -83,7 +84,10 @@ def get_prototype_manifest( def get_prototype_model_spec( - region: str = None, model_id: str = None, version: str = None + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -94,7 +98,10 @@ def get_prototype_model_spec( def get_special_model_spec( - region: str = None, model_id: str = None, version: str = None + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -111,6 +118,7 @@ def get_spec_from_base_spec( model_id: str = None, semantic_version_str: str = None, version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: if version and semantic_version_str: diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ee1dd3aed4..bea68dd713 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -11,8 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock - +import boto3 from mock.mock import patch import pytest @@ -26,6 +27,9 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" @@ -33,12 +37,15 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): region=region, model_id=model_id, model_version="*", + sagemaker_session=mock_session, ) assert definitions == [ {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} ] - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -46,12 +53,15 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): region=region, model_id=model_id, model_version="1.*", + sagemaker_session=mock_session, ) assert definitions == [ {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} ] - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 6efdef1830..000540e12e 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -11,7 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock +import boto3 from mock.mock import patch import pytest @@ -31,15 +33,20 @@ def test_jumpstart_common_model_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_uris.retrieve( model_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -50,11 +57,13 @@ def test_jumpstart_common_model_uri( model_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,9 +75,13 @@ def test_jumpstart_common_model_uri( model_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -80,9 +93,13 @@ def test_jumpstart_common_model_uri( model_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index ef33eff974..3f38326608 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -11,7 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock import pytest +import boto3 from mock.mock import patch @@ -31,15 +33,20 @@ def test_jumpstart_common_script_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + script_uris.retrieve( script_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -50,11 +57,13 @@ def test_jumpstart_common_script_uri( script_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,9 +75,10 @@ def test_jumpstart_common_script_uri( script_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -80,9 +90,13 @@ def test_jumpstart_common_script_uri( script_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index dcca3a7b4d..b22b61dc40 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock +import boto3 from mock.mock import patch @@ -32,16 +34,22 @@ def test_jumpstart_default_serializers( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) default_serializer = serializers.retrieve_default( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert isinstance(default_serializer, base_serializers.IdentitySerializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -56,6 +64,9 @@ def test_jumpstart_serializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -63,6 +74,7 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert len(serializer_options) == 1 assert all( @@ -73,5 +85,8 @@ def test_jumpstart_serializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) From e99cb3c04a78efafc176cbe22ac0d80bc13dae27 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 7 Aug 2023 17:25:45 +0000 Subject: [PATCH 02/13] fix: docstring --- src/sagemaker/accept_types.py | 4 ++-- src/sagemaker/content_types.py | 4 ++-- src/sagemaker/deserializers.py | 4 ++-- src/sagemaker/environment_variables.py | 2 +- src/sagemaker/hyperparameters.py | 4 ++-- src/sagemaker/image_uris.py | 2 +- src/sagemaker/instance_types.py | 4 ++-- .../jumpstart/artifacts/environment_variables.py | 2 +- .../jumpstart/artifacts/hyperparameters.py | 2 +- src/sagemaker/jumpstart/artifacts/image_uris.py | 2 +- .../jumpstart/artifacts/incremental_training.py | 2 +- .../jumpstart/artifacts/instance_types.py | 4 ++-- src/sagemaker/jumpstart/artifacts/kwargs.py | 8 ++++---- .../jumpstart/artifacts/metric_definitions.py | 2 +- .../jumpstart/artifacts/model_packages.py | 4 ++-- src/sagemaker/jumpstart/artifacts/model_uris.py | 4 ++-- src/sagemaker/jumpstart/artifacts/predictors.py | 16 ++++++++-------- .../jumpstart/artifacts/resource_names.py | 2 +- src/sagemaker/jumpstart/artifacts/script_uris.py | 4 ++-- src/sagemaker/jumpstart/utils.py | 2 +- src/sagemaker/metric_definitions.py | 2 +- src/sagemaker/model_uris.py | 2 +- src/sagemaker/script_uris.py | 2 +- src/sagemaker/serializers.py | 4 ++-- 24 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 7fb3aaa441..dda66597c7 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -43,7 +43,7 @@ def retrieve_options( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -92,7 +92,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 9bac693e9a..f2c8a6efa6 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -43,7 +43,7 @@ def retrieve_options( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -92,7 +92,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 0cb6ed4180..9cac71ccae 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -61,7 +61,7 @@ def retrieve_options( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -111,7 +111,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index e0e6daa0cb..b74a395e43 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -54,7 +54,7 @@ def retrieve_default( variables that would be required when making the low-level AWS API call. (Default: True). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 46f43f2e85..9a9717fd06 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -59,7 +59,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -118,7 +118,7 @@ def validate( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 72947a68e3..fd10a8e522 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -112,7 +112,7 @@ def retrieve( Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type. sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 7639bd8d59..a625d1bc6c 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -52,7 +52,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -106,7 +106,7 @@ def retrieve( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 4818f6c430..33bd20ba5d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -55,7 +55,7 @@ def _retrieve_default_environment_variables( variables that would be required when making the low-level AWS API call. (Default: True). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index 403c94a060..0acb8cd641 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -59,7 +59,7 @@ def _retrieve_default_hyperparameters( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index d9f4917c42..985ad4cea2 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -91,7 +91,7 @@ def _retrieve_image_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index ba32e54764..901f0e70d8 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -50,7 +50,7 @@ def _model_supports_incremental_training( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 3cdf51073f..33b9609994 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -56,7 +56,7 @@ def _retrieve_default_instance_type( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -122,7 +122,7 @@ def _retrieve_instance_types( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 41855ad598..f4f38bbb9a 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -52,7 +52,7 @@ def _retrieve_model_init_kwargs( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -108,7 +108,7 @@ def _retrieve_model_deploy_kwargs( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). @@ -163,7 +163,7 @@ def _retrieve_estimator_init_kwargs( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -219,7 +219,7 @@ def _retrieve_estimator_fit_kwargs( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index f32d82952f..85eeb29316 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -51,7 +51,7 @@ def _retrieve_default_training_metric_definitions( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 67ac107091..b64ac194cc 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -51,7 +51,7 @@ def _retrieve_model_package_arn( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). @@ -112,7 +112,7 @@ def _retrieve_model_package_model_artifact_s3_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 4f5f5ab6c2..d2f26c099c 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -58,7 +58,7 @@ def _retrieve_model_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -126,7 +126,7 @@ def _model_supports_training_model_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index cd75b5b358..e4ac8238e4 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -92,7 +92,7 @@ def _retrieve_default_deserializer( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). @@ -136,7 +136,7 @@ def _retrieve_default_serializer( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -179,7 +179,7 @@ def _retrieve_deserializer_options( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -236,7 +236,7 @@ def _retrieve_serializer_options( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -293,7 +293,7 @@ def _retrieve_default_content_type( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -341,7 +341,7 @@ def _retrieve_default_accept_type( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -390,7 +390,7 @@ def _retrieve_supported_accept_types( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -439,7 +439,7 @@ def _retrieve_supported_content_types( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 908e18b7ae..817df05818 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -50,7 +50,7 @@ def _retrieve_resource_name_base( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index ceba059b60..bb396815ed 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -58,7 +58,7 @@ def _retrieve_script_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -124,7 +124,7 @@ def _model_supports_inference_script_uri( specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index efa9bf34d2..1ff2c6b217 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -401,7 +401,7 @@ def verify_model_region_and_return_specs( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 1e8b192e39..1fbb523941 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -49,7 +49,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index fe3209a39e..58ef6f51fb 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -51,7 +51,7 @@ def retrieve( specifications should be tolerated without raising an exception. If ``False``, raises an exception if the version of the model is deprecated. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 38ff180787..4de23043a4 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -51,7 +51,7 @@ def retrieve( without raising an exception. ``False`` if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 5a1c5aa47b..adcf9c653d 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -59,7 +59,7 @@ def retrieve_options( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: @@ -109,7 +109,7 @@ def retrieve_default( (exception not raised). False if these models should raise an exception. (Default: False). sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions (Default: None). If not + object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: Session()). Returns: From c7d3bd2195a74bad97f98ffd02561c1b22b5f3b5 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 7 Aug 2023 22:26:29 +0000 Subject: [PATCH 03/13] fix: unit tests --- src/sagemaker/accept_types.py | 5 +-- src/sagemaker/content_types.py | 5 +-- src/sagemaker/deserializers.py | 5 +-- src/sagemaker/environment_variables.py | 3 +- src/sagemaker/hyperparameters.py | 5 +-- src/sagemaker/image_uris.py | 6 ++-- src/sagemaker/instance_types.py | 5 +-- .../artifacts/environment_variables.py | 5 +-- .../jumpstart/artifacts/hyperparameters.py | 5 +-- .../jumpstart/artifacts/image_uris.py | 5 +-- .../artifacts/incremental_training.py | 5 +-- .../jumpstart/artifacts/instance_types.py | 9 ++--- src/sagemaker/jumpstart/artifacts/kwargs.py | 17 +++++----- .../jumpstart/artifacts/metric_definitions.py | 5 +-- .../jumpstart/artifacts/model_packages.py | 9 ++--- .../jumpstart/artifacts/model_uris.py | 9 ++--- .../jumpstart/artifacts/predictors.py | 33 ++++++++++--------- .../jumpstart/artifacts/resource_names.py | 5 +-- .../jumpstart/artifacts/script_uris.py | 9 ++--- src/sagemaker/jumpstart/constants.py | 5 +++ src/sagemaker/jumpstart/factory/estimator.py | 3 +- src/sagemaker/jumpstart/factory/model.py | 3 +- src/sagemaker/jumpstart/utils.py | 6 ++-- src/sagemaker/jumpstart/validators.py | 2 +- src/sagemaker/metric_definitions.py | 3 +- src/sagemaker/model_uris.py | 3 +- src/sagemaker/predictor.py | 5 +-- src/sagemaker/script_uris.py | 3 +- src/sagemaker/serializers.py | 5 +-- .../jumpstart/estimator/test_estimator.py | 22 ++++++------- .../estimator/test_sagemaker_config.py | 3 +- .../sagemaker/jumpstart/model/test_model.py | 25 +++++++------- .../jumpstart/model/test_sagemaker_config.py | 5 +-- tests/unit/sagemaker/jumpstart/test_utils.py | 6 ++-- 34 files changed, 141 insertions(+), 108 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index dda66597c7..85f8d22771 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -15,6 +15,7 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -24,7 +25,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported accept types for the model matching the given arguments. @@ -73,7 +74,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default accept type for the model matching the given arguments. diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index f2c8a6efa6..6b7e474cf3 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -15,6 +15,7 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -24,7 +25,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported content types for the model matching the given arguments. @@ -73,7 +74,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model matching the given arguments. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 9cac71ccae..70d7d94d7e 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -33,6 +33,7 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -42,7 +43,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model matching the given arguments. @@ -92,7 +93,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b74a395e43..dbf413057f 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 9a9717fd06..3a0a5dfa55 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import HyperparameterValidationMode from sagemaker.jumpstart.validators import validate_hyperparameters from sagemaker.session import Session @@ -33,7 +34,7 @@ def retrieve_default( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -92,7 +93,7 @@ def validate( validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Validates hyperparameters for models. diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index fd10a8e522..dd356893a2 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -21,8 +21,8 @@ from packaging.version import Version from sagemaker import utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.utils import is_jumpstart_model_input -from sagemaker.session import Session from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts from sagemaker.workflow import is_pipeline_variable @@ -61,7 +61,7 @@ def retrieve( sdk_version=None, inference_tool=None, serverless_inference_config=None, - sagemaker_session=Session(), + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -114,7 +114,7 @@ def retrieve( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index a625d1bc6c..c7cd269585 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -87,7 +88,7 @@ def retrieve( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported training instance types for the model matching the given arguments. diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 33bd20ba5d..63e8a39a0d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Dict, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -32,7 +33,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -57,7 +58,7 @@ def _retrieve_default_environment_variables( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the inference environment variables to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index 0acb8cd641..a61440a3f2 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Dict, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -33,7 +34,7 @@ def _retrieve_default_hyperparameters( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -61,7 +62,7 @@ def _retrieve_default_hyperparameters( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the hyperparameters to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 985ad4cea2..4aeb1c4da4 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -16,6 +16,7 @@ from typing import Optional from sagemaker import image_uris from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -44,7 +45,7 @@ def _retrieve_image_uri( training_compiler_config: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the container image URI for JumpStart models. @@ -93,7 +94,7 @@ def _retrieve_image_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 901f0e70d8..f8748439a3 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -31,7 +32,7 @@ def _model_supports_incremental_training( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports incremental training. @@ -52,7 +53,7 @@ def _model_supports_incremental_training( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for incremental training. """ diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 33b9609994..30201f7c25 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -17,6 +17,7 @@ from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -35,7 +36,7 @@ def _retrieve_default_instance_type( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default instance type for the model. @@ -58,7 +59,7 @@ def _retrieve_default_instance_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default instance type to use for the model or None. @@ -101,7 +102,7 @@ def _retrieve_instance_types( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -124,7 +125,7 @@ def _retrieve_instance_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported instance types to use for the model or None. diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index f4f38bbb9a..97eab454fa 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -17,6 +17,7 @@ from sagemaker.session import Session from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -33,7 +34,7 @@ def _retrieve_model_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Model`. @@ -54,7 +55,7 @@ def _retrieve_model_init_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -87,7 +88,7 @@ def _retrieve_model_deploy_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -110,7 +111,7 @@ def _retrieve_model_deploy_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. @@ -142,7 +143,7 @@ def _retrieve_estimator_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -165,7 +166,7 @@ def _retrieve_estimator_init_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -200,7 +201,7 @@ def _retrieve_estimator_fit_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -221,7 +222,7 @@ def _retrieve_estimator_fit_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 85eeb29316..79b8c8063b 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -15,6 +15,7 @@ from copy import deepcopy from typing import Dict, List, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -32,7 +33,7 @@ def _retrieve_default_training_metric_definitions( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -53,7 +54,7 @@ def _retrieve_default_training_metric_definitions( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the default training metric definitions to use for the model or None. """ diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index b64ac194cc..e9913ec7f3 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.utils import ( @@ -32,7 +33,7 @@ def _retrieve_model_package_arn( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -53,7 +54,7 @@ def _retrieve_model_package_arn( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package arn to use for the model or None. @@ -91,7 +92,7 @@ def _retrieve_model_package_model_artifact_s3_uri( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -114,7 +115,7 @@ def _retrieve_model_package_model_artifact_s3_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package artifact uri to use for the model or None. diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index d2f26c099c..d953f7e809 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -16,6 +16,7 @@ from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -36,7 +37,7 @@ def _retrieve_model_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -60,7 +61,7 @@ def _retrieve_model_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model artifact S3 URI for the corresponding model. @@ -107,7 +108,7 @@ def _model_supports_training_model_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports training with model uri field. @@ -128,7 +129,7 @@ def _model_supports_training_model_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for model uri with training. """ diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index e4ac8238e4..0920ed9ed2 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -18,6 +18,7 @@ from sagemaker.jumpstart.constants import ( ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP, CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, DESERIALIZER_TYPE_TO_CLASS_MAP, JUMPSTART_DEFAULT_REGION_NAME, SERIALIZER_TYPE_TO_CLASS_MAP, @@ -74,7 +75,7 @@ def _retrieve_default_deserializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -94,7 +95,7 @@ def _retrieve_default_deserializer( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -118,7 +119,7 @@ def _retrieve_default_serializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -138,7 +139,7 @@ def _retrieve_default_serializer( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -161,7 +162,7 @@ def _retrieve_deserializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -181,7 +182,7 @@ def _retrieve_deserializer_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -218,7 +219,7 @@ def _retrieve_serializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -238,7 +239,7 @@ def _retrieve_serializer_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -275,7 +276,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -295,7 +296,7 @@ def _retrieve_default_content_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default content type to use for the model. """ @@ -323,7 +324,7 @@ def _retrieve_default_accept_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default accept type for the model. @@ -343,7 +344,7 @@ def _retrieve_default_accept_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default accept type to use for the model. """ @@ -372,7 +373,7 @@ def _retrieve_supported_accept_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -392,7 +393,7 @@ def _retrieve_supported_accept_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported accept types to use for the model. """ @@ -421,7 +422,7 @@ def _retrieve_supported_content_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported content types for the model. @@ -441,7 +442,7 @@ def _retrieve_supported_content_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported content types to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 817df05818..29cf072595 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -31,7 +32,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -52,7 +53,7 @@ def _retrieve_resource_name_base( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default resource name. """ diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index bb396815ed..5c3dfa0408 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -15,6 +15,7 @@ import os from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -35,7 +36,7 @@ def _retrieve_script_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -60,7 +61,7 @@ def _retrieve_script_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model script URI for the corresponding model. @@ -105,7 +106,7 @@ def _model_supports_inference_script_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports inference with script uri field. @@ -126,7 +127,7 @@ def _model_supports_inference_script_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for script uri with inference. """ diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index dc22b57fce..ad764317c6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -30,6 +30,7 @@ IdentitySerializer, JSONSerializer, ) +from sagemaker.session import Session JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( @@ -177,3 +178,7 @@ MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") + +DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session( + boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) +) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 612c40f7f5..3071411183 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -43,6 +43,7 @@ _model_supports_training_model_uri, ) from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, @@ -384,7 +385,7 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or Session() + kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION return kwargs diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 5455179831..967dd7f8ec 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -28,6 +28,7 @@ ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, @@ -122,7 +123,7 @@ def _add_sagemaker_session_to_kwargs( kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] ) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or Session() + kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION return kwargs diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 1ff2c6b217..56386348af 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -381,7 +381,7 @@ def verify_model_region_and_return_specs( region: str, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -403,7 +403,7 @@ def verify_model_region_and_return_specs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Raises: NotImplementedError: If the scope is not supported. @@ -582,7 +582,7 @@ def is_valid_model_id( region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, - sagemaker_session: Optional[Session] = Session(), + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model ID is supported for the given script. diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 539cb21cc8..4c0465904a 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -187,7 +187,7 @@ def validate_hyperparameters( region (str): Region for which to validate hyperparameters. (Default: JumpStart default region). sagemaker_session (Optional[Session]): Custom SageMaker Session to use. - (Default: Session()). + (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 1fbb523941..1425bfba75 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,7 +31,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 58ef6f51fb..7e7ef89cc4 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -18,6 +18,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -31,7 +32,7 @@ def retrieve( model_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 695a51c5f8..cc9fea8287 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -14,6 +14,7 @@ from __future__ import print_function, absolute_import from typing import Optional +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.utils import is_jumpstart_model_input @@ -32,7 +33,7 @@ def retrieve_default( endpoint_name: str, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, @@ -44,7 +45,7 @@ def retrieve_default( Args: endpoint_name (str): Endpoint name for which to create a predictor. sagemaker_session (Session): The SageMaker Session to attach to the Predictor. - (Default: Session()). + (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). region (str): The AWS Region for which to retrieve the default predictor. (Default: None). model_id (str): The model ID of the model for which to diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 4de23043a4..950258878f 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -19,6 +19,7 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def retrieve( script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index adcf9c653d..4a8ae94344 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -31,6 +31,7 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -40,7 +41,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -90,7 +91,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, - sagemaker_session: Session = Session(), + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 18bc10bcec..6cfcc3420a 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -23,6 +23,7 @@ from sagemaker.debugger.profiler_config import ProfilerConfig from sagemaker.estimator import Estimator from sagemaker.instance_group import InstanceGroup +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -31,7 +32,6 @@ from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model from sagemaker.predictor import Predictor -from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import ( get_special_model_spec, overwrite_dictionary, @@ -40,7 +40,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("eiifccreeeiuchhnehtlbdecgeeelgjccjvvbbcncnhv", sagemaker_session) default_predictor_with_presets = Predictor( @@ -434,7 +434,7 @@ def test_estimator_use_kwargs(self): "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", "base_job_name": "Optional[str] = None", - "sagemaker_session": Session(), + "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [{"1": "hum"}], "subnets": ["1", "2"], @@ -1033,8 +1033,13 @@ def test_training_passes_role_to_deploy( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", + sagemaker_session, + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1047,8 +1052,6 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_session_estimator: mock.Mock, - mock_session_model: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): @@ -1062,11 +1065,8 @@ def test_training_passes_session_to_deploy( mock_get_model_specs.side_effect = get_special_model_spec - mock_session_estimator.return_value = sagemaker_session - mock_session_model.return_value = sagemaker_session - mock_role = f"dsfsdfsd{time.time()}" - mock_sagemaker_session = Session() + mock_sagemaker_session = mock.MagicMock(sagemaker_config={}) mock_sagemaker_session.get_caller_identity_arn = lambda: mock_role estimator = JumpStartEstimator( diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 1f2b11d6a6..d22e910a00 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -22,6 +22,7 @@ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.session import Session @@ -32,7 +33,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("eiifccreeeiujigjjdfgiujrcibigckbtregvkjeurru", sagemaker_session) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 86bdff0b43..dc5c980f90 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -18,19 +18,19 @@ from mock import MagicMock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model from sagemaker.predictor import Predictor -from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, overwrite_dictionary execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("blah", sagemaker_session) default_predictor_with_presets = Predictor( @@ -39,6 +39,9 @@ class ModelTest(unittest.TestCase): + + mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -574,13 +577,11 @@ def test_model_id_not_found_refeshes_cach_inference( ) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, ): @@ -590,14 +591,14 @@ def test_jumpstart_model_package_arn( mock_get_model_specs.side_effect = get_special_model_spec - mock_session.return_value = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}) - model = JumpStartModel(model_id=model_id) + model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) model.deploy() self.assertEqual( - mock_session.return_value.create_model.call_args[0][2], + mock_session.create_model.call_args[0][2], { "ModelPackageName": "arn:aws:sagemaker:us-west-2:594846645681:model-package" "/llama2-7b-f-e46eb8a833643ed58aaccd81498972c3" @@ -605,13 +606,11 @@ def test_jumpstart_model_package_arn( ) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, ): @@ -622,18 +621,20 @@ def test_jumpstart_model_package_arn_override( mock_get_model_specs.side_effect = get_special_model_spec - mock_session.return_value = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}) model_package_arn = ( "arn:aws:sagemaker:us-west-2:867530986753:model-package/" "llama2-ynnej-f-e46eb8a833643ed58aaccd81498972c3" ) - model = JumpStartModel(model_id=model_id, model_package_arn=model_package_arn) + model = JumpStartModel( + model_id=model_id, model_package_arn=model_package_arn, sagemaker_session=mock_session + ) model.deploy() self.assertEqual( - mock_session.return_value.create_model.call_args[0][2], + mock_session.create_model.call_args[0][2], { "ModelPackageName": model_package_arn, "Environment": { diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 00563ad106..727f3060b3 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -18,6 +18,7 @@ MODEL_ENABLE_NETWORK_ISOLATION_PATH, MODEL_EXECUTION_ROLE_ARN_PATH, ) +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel from sagemaker.session import Session @@ -28,7 +29,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role override_role = "fdsfsdfs" @@ -56,7 +57,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): execution_role = "fake role! do not use!" region = "us-west-2" - sagemaker_session = Session() + sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.model.Model.__init__") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 2bb3c996a2..a52884f68c 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -18,6 +18,7 @@ import random from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, JUMPSTART_BUCKET_NAME_SET, JUMPSTART_DEFAULT_REGION_NAME, @@ -33,7 +34,6 @@ VulnerableJumpStartModelError, ) from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId -from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec @@ -979,7 +979,7 @@ def test_is_valid_model_id_true( Mock(model_id="see"), ] - mock_session_value = Session() + mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) @@ -1021,7 +1021,7 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="see"), ] - mock_session_value = Session() + mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION mock_s3_client_value = mock_session_value.s3_client patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) From 5ce011be6e1b4efd62f931057940b684c68e00c2 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 8 Aug 2023 14:14:42 +0000 Subject: [PATCH 04/13] fix: docstring --- src/sagemaker/accept_types.py | 4 ++-- src/sagemaker/content_types.py | 4 ++-- src/sagemaker/deserializers.py | 4 ++-- src/sagemaker/environment_variables.py | 2 +- src/sagemaker/hyperparameters.py | 4 ++-- src/sagemaker/instance_types.py | 4 ++-- src/sagemaker/metric_definitions.py | 2 +- src/sagemaker/model_uris.py | 2 +- src/sagemaker/script_uris.py | 2 +- src/sagemaker/serializers.py | 4 ++-- 10 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 85f8d22771..bf081365ab 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -46,7 +46,7 @@ def retrieve_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported accept types to use for the model. @@ -95,7 +95,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default accept type to use for the model. diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 6b7e474cf3..e43e96be17 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -46,7 +46,7 @@ def retrieve_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported content types to use for the model. @@ -95,7 +95,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default content type to use for the model. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 70d7d94d7e..21174d9f77 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -64,7 +64,7 @@ def retrieve_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseDeserializer]: The supported deserializers to use for the model. @@ -114,7 +114,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseDeserializer: The default deserializer to use for the model. diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index dbf413057f..615fc92f16 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -57,7 +57,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: The variables to use for the model. diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 3a0a5dfa55..f53d9e4e2b 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -62,7 +62,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: The hyperparameters to use for the model. @@ -121,7 +121,7 @@ def validate( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Raises: JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index c7cd269585..111cc51f29 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -55,7 +55,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default instance type to use for the model. @@ -109,7 +109,7 @@ def retrieve( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported instance types to use for the model. diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 1425bfba75..648c6e0cb4 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -52,7 +52,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The default metric definitions to use for the model or None. diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 7e7ef89cc4..91890be975 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -54,7 +54,7 @@ def retrieve( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The model artifact S3 URI for the corresponding model. diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 950258878f..9a1c4933d2 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -54,7 +54,7 @@ def retrieve( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The model script URI for the corresponding model. diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 4a8ae94344..60365d2621 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -62,7 +62,7 @@ def retrieve_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -112,7 +112,7 @@ def retrieve_default( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: Session()). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: SimpleBaseSerializer: The default serializer to use for the model. From 87da5bc78a176dea5c32f49ec2a53b9f72441308 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 8 Aug 2023 15:03:51 +0000 Subject: [PATCH 05/13] fix: docstring --- src/sagemaker/image_uris.py | 2 +- .../jumpstart/artifacts/environment_variables.py | 2 +- .../jumpstart/artifacts/hyperparameters.py | 2 +- src/sagemaker/jumpstart/artifacts/image_uris.py | 2 +- .../jumpstart/artifacts/incremental_training.py | 2 +- .../jumpstart/artifacts/instance_types.py | 4 ++-- src/sagemaker/jumpstart/artifacts/kwargs.py | 8 ++++---- .../jumpstart/artifacts/metric_definitions.py | 2 +- .../jumpstart/artifacts/model_packages.py | 4 ++-- src/sagemaker/jumpstart/artifacts/model_uris.py | 4 ++-- src/sagemaker/jumpstart/artifacts/predictors.py | 16 ++++++++-------- .../jumpstart/artifacts/resource_names.py | 2 +- src/sagemaker/jumpstart/artifacts/script_uris.py | 4 ++-- src/sagemaker/jumpstart/utils.py | 2 +- src/sagemaker/jumpstart/validators.py | 2 +- src/sagemaker/predictor.py | 2 +- 16 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index dd356893a2..adc81d99f4 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -114,7 +114,7 @@ def retrieve( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 63e8a39a0d..b54f9aab8d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -58,7 +58,7 @@ def _retrieve_default_environment_variables( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the inference environment variables to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index a61440a3f2..6a167aa8ba 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -62,7 +62,7 @@ def _retrieve_default_hyperparameters( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the hyperparameters to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 4aeb1c4da4..0c08244ec6 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -94,7 +94,7 @@ def _retrieve_image_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the ECR URI for the corresponding SageMaker Docker image. diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index f8748439a3..753a911422 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -53,7 +53,7 @@ def _model_supports_incremental_training( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for incremental training. """ diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 30201f7c25..428a33708d 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -59,7 +59,7 @@ def _retrieve_default_instance_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default instance type to use for the model or None. @@ -125,7 +125,7 @@ def _retrieve_instance_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported instance types to use for the model or None. diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 97eab454fa..7acad9b793 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -55,7 +55,7 @@ def _retrieve_model_init_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -111,7 +111,7 @@ def _retrieve_model_deploy_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. @@ -166,7 +166,7 @@ def _retrieve_estimator_init_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -222,7 +222,7 @@ def _retrieve_estimator_fit_kwargs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 79b8c8063b..0a9cfa00ae 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -54,7 +54,7 @@ def _retrieve_default_training_metric_definitions( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the default training metric definitions to use for the model or None. """ diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index e9913ec7f3..56e3f34e91 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -54,7 +54,7 @@ def _retrieve_model_package_arn( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package arn to use for the model or None. @@ -115,7 +115,7 @@ def _retrieve_model_package_model_artifact_s3_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package artifact uri to use for the model or None. diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index d953f7e809..928e7652eb 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -61,7 +61,7 @@ def _retrieve_model_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model artifact S3 URI for the corresponding model. @@ -129,7 +129,7 @@ def _model_supports_training_model_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for model uri with training. """ diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 0920ed9ed2..8d599c89cc 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -95,7 +95,7 @@ def _retrieve_default_deserializer( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -139,7 +139,7 @@ def _retrieve_default_serializer( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -182,7 +182,7 @@ def _retrieve_deserializer_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -239,7 +239,7 @@ def _retrieve_serializer_options( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -296,7 +296,7 @@ def _retrieve_default_content_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default content type to use for the model. """ @@ -344,7 +344,7 @@ def _retrieve_default_accept_type( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default accept type to use for the model. """ @@ -393,7 +393,7 @@ def _retrieve_supported_accept_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported accept types to use for the model. """ @@ -442,7 +442,7 @@ def _retrieve_supported_content_types( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported content types to use for the model. """ diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 29cf072595..6b05f07b15 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -53,7 +53,7 @@ def _retrieve_resource_name_base( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default resource name. """ diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index 5c3dfa0408..c1b037ce61 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -61,7 +61,7 @@ def _retrieve_script_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model script URI for the corresponding model. @@ -127,7 +127,7 @@ def _model_supports_inference_script_uri( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for script uri with inference. """ diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 56386348af..164d2e7c79 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -403,7 +403,7 @@ def verify_model_region_and_return_specs( sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration - chain. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Raises: NotImplementedError: If the scope is not supported. diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 4c0465904a..3199e5fc2e 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -187,7 +187,7 @@ def validate_hyperparameters( region (str): Region for which to validate hyperparameters. (Default: JumpStart default region). sagemaker_session (Optional[Session]): Custom SageMaker Session to use. - (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index cc9fea8287..7b436a9dd8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -45,7 +45,7 @@ def retrieve_default( Args: endpoint_name (str): Endpoint name for which to create a predictor. sagemaker_session (Session): The SageMaker Session to attach to the Predictor. - (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). region (str): The AWS Region for which to retrieve the default predictor. (Default: None). model_id (str): The model ID of the model for which to From fd5f890d239cf34873115f68958e3e8d3b564ca5 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 8 Aug 2023 21:10:25 +0000 Subject: [PATCH 06/13] fix: cache kwargs --- src/sagemaker/jumpstart/accessors.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 4292130311..f63fd6f8e1 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -102,8 +102,6 @@ def _get_manifest( region (str): Optional. The region to use for the cache. """ - old_cache_kwargs = JumpStartModelsAccessor._cache_kwargs.copy() - additional_kwargs = {} if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) @@ -113,7 +111,6 @@ def _get_manifest( ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) manifest = JumpStartModelsAccessor._cache.get_manifest() # type: ignore - JumpStartModelsAccessor._cache_kwargs = old_cache_kwargs return manifest @staticmethod @@ -147,8 +144,6 @@ def get_model_specs( If not set, a default client will be made. """ - old_cache_kwargs = JumpStartModelsAccessor._cache_kwargs.copy() - additional_kwargs = {} if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) @@ -160,7 +155,6 @@ def get_model_specs( specs = JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) - JumpStartModelsAccessor._cache_kwargs = old_cache_kwargs return specs @staticmethod From 31f25687dd3c4fbb9b9f7e26c38d8e51cf9f0725 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 9 Aug 2023 12:40:55 +0000 Subject: [PATCH 07/13] chore: cleanup git diff --- src/sagemaker/jumpstart/accessors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index f63fd6f8e1..4a15e35172 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -110,8 +110,7 @@ def _get_manifest( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - manifest = JumpStartModelsAccessor._cache.get_manifest() # type: ignore - return manifest + return JumpStartModelsAccessor._cache.get_manifest() # type: ignore @staticmethod def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: @@ -152,10 +151,9 @@ def get_model_specs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) - specs = JumpStartModelsAccessor._cache.get_specs( # type: ignore + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) - return specs @staticmethod def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None: From 1125c4f6b7f606b2e0dbc50aa479acfae35d2712 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 9 Aug 2023 20:49:30 +0000 Subject: [PATCH 08/13] chore: improve docstring --- src/sagemaker/jumpstart/accessors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 4a15e35172..7f55450a21 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -100,6 +100,8 @@ def _get_manifest( Args: region (str): Optional. The region to use for the cache. + s3_client (boto3.client): Optional. Boto3 client to use for accessing JumpStart models s3 + cache. If not set, a default client will be made. """ additional_kwargs = {} From c6fb8bd07a4f27b903c24d2f833579feca1ec3f2 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Aug 2023 21:55:52 +0000 Subject: [PATCH 09/13] fix: line too long --- src/sagemaker/jumpstart/accessors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index c2dc8c80e1..8117606299 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -112,8 +112,8 @@ def _get_manifest( Args: region (str): Optional. The region to use for the cache. - s3_client (boto3.client): Optional. Boto3 client to use for accessing JumpStart models s3 - cache. If not set, a default client will be made. + s3_client (boto3.client): Optional. Boto3 client to use for accessing JumpStart models + s3 cache. If not set, a default client will be made. """ additional_kwargs = {} From 6ec4e8da77221eedd4928eb88e90b0cb57532264 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Aug 2023 22:12:15 +0000 Subject: [PATCH 10/13] chore: cleanup code from PR comment --- src/sagemaker/jumpstart/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index fe3e2224c8..25206645f2 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -111,7 +111,7 @@ def __init__( if s3_bucket_name is None else s3_bucket_name ) - self._s3_client = s3_client if s3_client else ( + self._s3_client = s3_client or ( boto3.client("s3", region_name=self._region, config=s3_client_config) if s3_client_config else boto3.client("s3", region_name=self._region) From 12678bc209596ccf05e91618313e46d832cd84a5 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 15 Aug 2023 15:38:53 +0000 Subject: [PATCH 11/13] fix: model package arn unit test --- tests/unit/sagemaker/jumpstart/model/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index e688a75159..9800d71668 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -608,7 +608,7 @@ def test_jumpstart_model_package_arn( }, ) - self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"]) + self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") From 7e96f0ef556c56cf162eb49e13e2407fbe02fdcb Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 15 Aug 2023 16:00:17 +0000 Subject: [PATCH 12/13] fix: failing test_new_session_created --- tests/unit/test_session.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 21c68d2f13..7a31de9237 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -89,6 +89,7 @@ def test_default_session(boto3_default_session): assert sess.boto_session is boto3_default_session +@patch("boto3.DEFAULT_SESSION", None) @patch("boto3.Session") def test_new_session_created(boto3_session): sess = Session() From b2d281c05f85e5da486b68ff014d7aeb8b95326f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 15 Aug 2023 20:35:31 +0000 Subject: [PATCH 13/13] fix: reset cache if content bucket changes --- src/sagemaker/jumpstart/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index c93b5f7311..f77e1ae231 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -100,6 +100,7 @@ def get_jumpstart_content_bucket( accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return) if bucket_to_return != old_content_bucket: + accessors.JumpStartModelsAccessor.reset_cache() for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) return bucket_to_return