diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 22f41812ca..bf081365ab 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -15,6 +15,8 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session def retrieve_options( @@ -23,6 +25,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported accept types for the model matching the given arguments. @@ -40,6 +43,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported accept types to use for the model. @@ -57,6 +64,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -66,6 +74,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default accept type for the model matching the given arguments. @@ -83,6 +92,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default accept type to use for the model. @@ -100,4 +113,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 55655c4065..e43e96be17 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -15,6 +15,8 @@ from typing import List, Optional from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session def retrieve_options( @@ -23,6 +25,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported content types for the model matching the given arguments. @@ -40,6 +43,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported content types to use for the model. @@ -57,6 +64,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -66,6 +74,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model matching the given arguments. @@ -83,6 +92,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default content type to use for the model. @@ -100,6 +113,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a43c98589..21174d9f77 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -33,6 +33,8 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session def retrieve_options( @@ -41,6 +43,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model matching the given arguments. @@ -58,6 +61,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseDeserializer]: The supported deserializers to use for the model. @@ -76,6 +83,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -85,6 +93,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseDeserializer: """Retrieves the default deserializer for the model matching the given arguments. @@ -102,6 +111,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseDeserializer: The default deserializer to use for the model. @@ -120,4 +133,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index 864bffc5bf..615fc92f16 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -19,6 +19,8 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the default container environment variables for the model matching the arguments. @@ -51,6 +54,10 @@ def retrieve_default( should be included. The `Model` class of the SageMaker Python SDK inserts environment variables that would be required when making the low-level AWS API call. (Default: True). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: The variables to use for the model. @@ -70,4 +77,5 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, include_aws_sdk_env_vars, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 7fa4a14414..f53d9e4e2b 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -19,8 +19,10 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import HyperparameterValidationMode from sagemaker.jumpstart.validators import validate_hyperparameters +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -32,6 +34,7 @@ def retrieve_default( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the default training hyperparameters for the model matching the given arguments. @@ -56,6 +59,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: The hyperparameters to use for the model. @@ -74,6 +81,7 @@ def retrieve_default( include_container_hyperparameters, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -83,6 +91,9 @@ def validate( model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Validates hyperparameters for models. @@ -100,6 +111,17 @@ def validate( If set to``VALIDATE_ALGORITHM``, all algorithm hyperparameters will be validated. If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. (Default: None). + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Raises: JumpStartHyperparametersError: If the hyperparameter is not formatted correctly, @@ -125,4 +147,7 @@ def validate( hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index fc28d91c9f..adc81d99f4 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -21,6 +21,7 @@ from packaging.version import Version from sagemaker import utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.utils import is_jumpstart_model_input from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts @@ -60,6 +61,7 @@ def retrieve( sdk_version=None, inference_tool=None, serverless_inference_config=None, + sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the ECR URI for the Docker image matching the given arguments. @@ -109,6 +111,10 @@ def retrieve( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to determine processor type. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The ECR URI for the corresponding SageMaker Docker image. @@ -147,6 +153,7 @@ def retrieve( training_compiler_config, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 436ad7bb3f..111cc51f29 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -19,6 +19,8 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ def retrieve_default( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -49,6 +52,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The default instance type to use for the model. @@ -70,6 +77,7 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -80,6 +88,7 @@ def retrieve( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported training instance types for the model matching the given arguments. @@ -97,6 +106,10 @@ def retrieve( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The supported instance types to use for the model. @@ -118,4 +131,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 18fbf12fbe..8117606299 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -13,6 +13,7 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional +import boto3 from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs @@ -86,14 +87,24 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: region (str): region for which to retrieve header/spec. cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``. """ - if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region: + new_cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + cache_kwargs, region + ) + if ( + JumpStartModelsAccessor._cache is None + or region != JumpStartModelsAccessor._curr_region + or new_cache_kwargs != JumpStartModelsAccessor._cache_kwargs + ): JumpStartModelsAccessor._cache = cache.JumpStartModelsCache( region=region, **cache_kwargs ) JumpStartModelsAccessor._curr_region = region + JumpStartModelsAccessor._cache_kwargs = new_cache_kwargs @staticmethod - def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]: + def _get_manifest( + region: str = JUMPSTART_DEFAULT_REGION_NAME, s3_client: Optional[boto3.client] = None + ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest. Raises: @@ -101,9 +112,16 @@ def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStart Args: region (str): Optional. The region to use for the cache. + s3_client (boto3.client): Optional. Boto3 client to use for accessing JumpStart models + s3 cache. If not set, a default client will be made. """ + + additional_kwargs = {} + if s3_client is not None: + additional_kwargs.update({"s3_client": s3_client}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - JumpStartModelsAccessor._cache_kwargs, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}, region ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_manifest() # type: ignore @@ -126,16 +144,25 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel ) @staticmethod - def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs: + def get_model_specs( + region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. Args: region (str): region for which to retrieve header. model_id (str): model ID to retrieve. version (str): semantic version to retrieve for the model ID. + s3_client (boto3.client): boto3 client to use for accessing JumpStart models s3 cache. + If not set, a default client will be made. """ + + additional_kwargs = {} + if s3_client is not None: + additional_kwargs.update({"s3_client": s3_client}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( - JumpStartModelsAccessor._cache_kwargs, region + {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) return JumpStartModelsAccessor._cache.get_specs( # type: ignore diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index a6c2ba0f58..b54f9aab8d 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Dict, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -22,6 +23,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_environment_variables( @@ -31,6 +33,7 @@ def _retrieve_default_environment_variables( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Dict[str, str]: """Retrieves the inference environment variables for the model matching the given arguments. @@ -52,7 +55,10 @@ def _retrieve_default_environment_variables( should be included. The `Model` class of the SageMaker Python SDK inserts environment variables that would be required when making the low-level AWS API call. (Default: True). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the inference environment variables to use for the model. """ @@ -67,6 +73,7 @@ def _retrieve_default_environment_variables( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_environment_variables: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index fc3a93212c..6a167aa8ba 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Dict, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -23,6 +24,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_hyperparameters( @@ -32,6 +34,7 @@ def _retrieve_default_hyperparameters( include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the training hyperparameters for the model matching the given arguments. @@ -56,6 +59,10 @@ def _retrieve_default_hyperparameters( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the hyperparameters to use for the model. """ @@ -70,6 +77,7 @@ def _retrieve_default_hyperparameters( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_hyperparameters: Dict[str, str] = {} diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 42819b6eec..0c08244ec6 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -16,6 +16,7 @@ from typing import Optional from sagemaker import image_uris from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -25,6 +26,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_image_uri( @@ -43,6 +45,7 @@ def _retrieve_image_uri( training_compiler_config: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the container image URI for JumpStart models. @@ -88,7 +91,10 @@ def _retrieve_image_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the ECR URI for the corresponding SageMaker Docker image. @@ -108,6 +114,7 @@ def _retrieve_image_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if image_scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index fc78ef7b70..753a911422 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -22,6 +23,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _model_supports_incremental_training( @@ -30,6 +32,7 @@ def _model_supports_incremental_training( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports incremental training. @@ -47,6 +50,10 @@ def _model_supports_incremental_training( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for incremental training. """ @@ -61,6 +68,7 @@ def _model_supports_incremental_training( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.supports_incremental_training() diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 841125bbbe..428a33708d 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -17,6 +17,7 @@ from sagemaker.jumpstart.exceptions import NO_AVAILABLE_INSTANCES_ERROR_MSG from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -25,6 +26,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_instance_type( @@ -34,6 +36,7 @@ def _retrieve_default_instance_type( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default instance type for the model. @@ -53,6 +56,10 @@ def _retrieve_default_instance_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default instance type to use for the model or None. @@ -71,6 +78,7 @@ def _retrieve_default_instance_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: @@ -94,6 +102,7 @@ def _retrieve_instance_types( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -113,6 +122,10 @@ def _retrieve_instance_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported instance types to use for the model or None. @@ -131,6 +144,7 @@ def _retrieve_instance_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index b2db7f0f80..7acad9b793 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -14,8 +14,10 @@ from __future__ import absolute_import from copy import deepcopy from typing import Optional +from sagemaker.session import Session from sagemaker.utils import volume_size_supported from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -32,6 +34,7 @@ def _retrieve_model_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Model`. @@ -49,7 +52,10 @@ def _retrieve_model_init_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -64,6 +70,7 @@ def _retrieve_model_init_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) kwargs = deepcopy(model_specs.model_kwargs) @@ -81,6 +88,7 @@ def _retrieve_model_deploy_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Model.deploy`. @@ -100,6 +108,10 @@ def _retrieve_model_deploy_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. @@ -115,6 +127,7 @@ def _retrieve_model_deploy_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if volume_size_supported(instance_type) and model_specs.inference_volume_size is not None: @@ -130,6 +143,7 @@ def _retrieve_estimator_init_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Estimator`. @@ -149,7 +163,10 @@ def _retrieve_estimator_init_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. """ @@ -164,6 +181,7 @@ def _retrieve_estimator_init_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) kwargs = deepcopy(model_specs.estimator_kwargs) @@ -183,6 +201,7 @@ def _retrieve_estimator_fit_kwargs( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> dict: """Retrieves kwargs for `Estimator.fit`. @@ -200,6 +219,10 @@ def _retrieve_estimator_fit_kwargs( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: dict: the kwargs to use for the use case. @@ -215,6 +238,7 @@ def _retrieve_estimator_fit_kwargs( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.fit_kwargs diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 6921cb8473..0a9cfa00ae 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -15,6 +15,7 @@ from copy import deepcopy from typing import Dict, List, Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -23,6 +24,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_default_training_metric_definitions( @@ -31,6 +33,7 @@ def _retrieve_default_training_metric_definitions( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model. @@ -48,7 +51,10 @@ def _retrieve_default_training_metric_definitions( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the default training metric definitions to use for the model or None. """ @@ -63,6 +69,7 @@ def _retrieve_default_training_metric_definitions( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return deepcopy(model_specs.metrics) if model_specs.metrics else None diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index 0fbf1d74b3..56e3f34e91 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.utils import ( @@ -22,6 +23,7 @@ from sagemaker.jumpstart.enums import ( JumpStartScriptScope, ) +from sagemaker.session import Session def _retrieve_model_package_arn( @@ -31,6 +33,7 @@ def _retrieve_model_package_arn( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[str]: """Retrieves associated model pacakge arn for the model. @@ -48,6 +51,10 @@ def _retrieve_model_package_arn( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package arn to use for the model or None. @@ -63,6 +70,7 @@ def _retrieve_model_package_arn( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if scope == JumpStartScriptScope.INFERENCE: @@ -84,6 +92,7 @@ def _retrieve_model_package_model_artifact_s3_uri( scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[str]: """Retrieves s3 artifact uri associated with model package. @@ -103,7 +112,10 @@ def _retrieve_model_package_model_artifact_s3_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model package artifact uri to use for the model or None. @@ -123,6 +135,7 @@ def _retrieve_model_package_model_artifact_s3_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if model_specs.training_model_package_artifact_uris is None: diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 31505e0865..928e7652eb 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -16,6 +16,7 @@ from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -26,6 +27,7 @@ get_jumpstart_content_bucket, verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_model_uri( @@ -35,6 +37,7 @@ def _retrieve_model_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the model artifact S3 URI for the model matching the given arguments. @@ -55,6 +58,10 @@ def _retrieve_model_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model artifact S3 URI for the corresponding model. @@ -74,6 +81,7 @@ def _retrieve_model_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if model_scope == JumpStartScriptScope.INFERENCE: @@ -100,6 +108,7 @@ def _model_supports_training_model_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports training with model uri field. @@ -117,6 +126,10 @@ def _model_supports_training_model_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for model uri with training. """ @@ -131,6 +144,7 @@ def _model_supports_training_model_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.use_training_model_artifact() diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 473739de0d..8d599c89cc 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -18,6 +18,7 @@ from sagemaker.jumpstart.constants import ( ACCEPT_TYPE_TO_DESERIALIZER_TYPE_MAP, CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP, + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, DESERIALIZER_TYPE_TO_CLASS_MAP, JUMPSTART_DEFAULT_REGION_NAME, SERIALIZER_TYPE_TO_CLASS_MAP, @@ -29,6 +30,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_serializer_from_content_type( @@ -73,6 +75,7 @@ def _retrieve_default_deserializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseDeserializer: """Retrieves the default deserializer for the model. @@ -89,6 +92,10 @@ def _retrieve_default_deserializer( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseDeserializer: the default deserializer to use for the model. @@ -100,6 +107,7 @@ def _retrieve_default_deserializer( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return _retrieve_deserializer_from_accept_type(MIMEType.from_suffixed_type(default_accept_type)) @@ -111,6 +119,7 @@ def _retrieve_default_serializer( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model. @@ -127,7 +136,10 @@ def _retrieve_default_serializer( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: BaseSerializer: the default serializer to use for the model. """ @@ -138,6 +150,7 @@ def _retrieve_default_serializer( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return _retrieve_serializer_from_content_type(MIMEType.from_suffixed_type(default_content_type)) @@ -149,6 +162,7 @@ def _retrieve_deserializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseDeserializer]: """Retrieves the supported deserializers for the model. @@ -165,7 +179,10 @@ def _retrieve_deserializer_options( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseDeserializer]: the supported deserializers to use for the model. """ @@ -176,6 +193,7 @@ def _retrieve_deserializer_options( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) seen_classes: Set[Type] = set() @@ -201,6 +219,7 @@ def _retrieve_serializer_options( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model. @@ -217,7 +236,10 @@ def _retrieve_serializer_options( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[BaseSerializer]: the supported serializers to use for the model. """ @@ -228,6 +250,7 @@ def _retrieve_serializer_options( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) seen_classes: Set[Type] = set() @@ -253,6 +276,7 @@ def _retrieve_default_content_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default content type for the model. @@ -269,7 +293,10 @@ def _retrieve_default_content_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default content type to use for the model. """ @@ -284,6 +311,7 @@ def _retrieve_default_content_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_content_type = model_specs.predictor_specs.default_content_type @@ -296,6 +324,7 @@ def _retrieve_default_accept_type( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the default accept type for the model. @@ -312,7 +341,10 @@ def _retrieve_default_accept_type( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default accept type to use for the model. """ @@ -327,6 +359,7 @@ def _retrieve_default_accept_type( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) default_accept_type = model_specs.predictor_specs.default_accept_type @@ -340,6 +373,7 @@ def _retrieve_supported_accept_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported accept types for the model. @@ -356,7 +390,10 @@ def _retrieve_supported_accept_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported accept types to use for the model. """ @@ -371,6 +408,7 @@ def _retrieve_supported_accept_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) supported_accept_types = model_specs.predictor_specs.supported_accept_types @@ -384,6 +422,7 @@ def _retrieve_supported_content_types( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[str]: """Retrieves the supported content types for the model. @@ -400,7 +439,10 @@ def _retrieve_supported_content_types( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: the supported content types to use for the model. """ @@ -415,6 +457,7 @@ def _retrieve_supported_content_types( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) supported_content_types = model_specs.predictor_specs.supported_content_types diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index ca20cabdda..6b05f07b15 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -22,6 +23,7 @@ from sagemaker.jumpstart.utils import ( verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_resource_name_base( @@ -30,6 +32,7 @@ def _retrieve_resource_name_base( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns default resource name. @@ -47,6 +50,10 @@ def _retrieve_resource_name_base( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the default resource name. """ @@ -61,6 +68,7 @@ def _retrieve_resource_name_base( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.resource_name_base diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f1bc4f31dd..c1b037ce61 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -15,6 +15,7 @@ import os from typing import Optional from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -25,6 +26,7 @@ get_jumpstart_content_bucket, verify_model_region_and_return_specs, ) +from sagemaker.session import Session def _retrieve_script_uri( @@ -34,6 +36,7 @@ def _retrieve_script_uri( region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ): """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -55,6 +58,10 @@ def _retrieve_script_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: the model script URI for the corresponding model. @@ -74,6 +81,7 @@ def _retrieve_script_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) if script_scope == JumpStartScriptScope.INFERENCE: @@ -98,6 +106,7 @@ def _model_supports_inference_script_uri( region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model supports inference with script uri field. @@ -115,6 +124,10 @@ def _model_supports_inference_script_uri( tolerate_deprecated_model (bool): True if deprecated versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: bool: the support status for script uri with inference. """ @@ -129,6 +142,7 @@ def _model_supports_inference_script_uri( region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) return model_specs.use_inference_script_uri() diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 0a7a29a8cd..25206645f2 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -68,6 +68,7 @@ def __init__( JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, + s3_client: Optional[boto3.client] = None, ) -> None: # fmt: on """Initialize a ``JumpStartModelsCache`` instance. @@ -88,6 +89,7 @@ def __init__( Default: JumpStart-hosted content bucket for region. s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). + s3_client (Optional[boto3.client]): s3 client to use. Default: None. """ self._region = region @@ -109,7 +111,7 @@ def __init__( if s3_bucket_name is None else s3_bucket_name ) - self._s3_client = ( + self._s3_client = s3_client or ( boto3.client("s3", region_name=self._region, config=s3_client_config) if s3_client_config else boto3.client("s3", region_name=self._region) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index dc22b57fce..ad764317c6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -30,6 +30,7 @@ IdentitySerializer, JSONSerializer, ) +from sagemaker.session import Session JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( @@ -177,3 +178,7 @@ MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") + +DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session( + boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) +) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 129692262b..3948bf5775 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -502,6 +502,7 @@ def _is_valid_model_id_hook(): model_version=model_version, region=region, script=JumpStartScriptScope.TRAINING, + sagemaker_session=sagemaker_session, ) if not _is_valid_model_id_hook(): @@ -649,6 +650,7 @@ def fit( experiment_config=experiment_config, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, + sagemaker_session=self.sagemaker_session, ) return super(JumpStartEstimator, self).fit(**estimator_fit_kwargs.to_kwargs_dict()) @@ -991,6 +993,7 @@ def deploy( region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 92fccb39b0..3071411183 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -43,6 +43,7 @@ _model_supports_training_model_uri, ) from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, @@ -208,6 +209,7 @@ def get_fit_kwargs( experiment_config: Optional[Dict[str, str]] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> JumpStartEstimatorFitKwargs: """Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -222,6 +224,7 @@ def get_fit_kwargs( experiment_config=experiment_config, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) estimator_fit_kwargs = _add_model_version_to_kwargs(estimator_fit_kwargs) @@ -298,6 +301,7 @@ def get_deploy_kwargs( explainer_config=explainer_config, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) model_init_kwargs: JumpStartModelInitKwargs = model.get_init_kwargs( @@ -314,7 +318,7 @@ def get_deploy_kwargs( role=role, name=model_name, vpc_config=vpc_config, - sagemaker_session=sagemaker_session, + sagemaker_session=model_deploy_kwargs.sagemaker_session, enable_network_isolation=enable_network_isolation, model_kms_key=model_kms_key, image_config=image_config, @@ -381,7 +385,7 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or Session() + kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION return kwargs @@ -420,6 +424,7 @@ def _add_instance_type_and_count_to_kwargs( scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.instance_count = kwargs.instance_count or 1 @@ -444,6 +449,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -458,6 +464,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): default_model_uri = model_uris.retrieve( model_scope=JumpStartScriptScope.TRAINING, @@ -465,6 +472,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if ( @@ -476,6 +484,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) ): JUMPSTART_LOGGER.warning( @@ -509,6 +518,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -526,6 +536,7 @@ def _add_env_to_kwargs( scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if model_package_artifact_uri: @@ -560,6 +571,7 @@ def _add_training_job_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.job_name = kwargs.job_name or ( @@ -584,6 +596,7 @@ def _add_hyperparameters_to_kwargs( model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in default_hyperparameters.items(): @@ -615,6 +628,7 @@ def _add_metric_definitions_to_kwargs( model_version=kwargs.model_version, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) or [] ) @@ -643,6 +657,7 @@ def _add_estimator_extra_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in estimator_kwargs_to_add.items(): @@ -666,6 +681,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in fit_kwargs_to_add.items(): diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index ac5640ff24..967dd7f8ec 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -28,6 +28,7 @@ ) from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_LOGGER, @@ -59,6 +60,7 @@ def get_default_predictor( region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, + sagemaker_session: Session, ) -> Predictor: """Converts predictor returned from ``Model.deploy()`` into a JumpStart-specific one. @@ -79,6 +81,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.deserializer = deserializers.retrieve_default( model_id=model_id, @@ -86,6 +89,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.accept = accept_types.retrieve_default( model_id=model_id, @@ -93,6 +97,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) predictor.content_type = content_types.retrieve_default( model_id=model_id, @@ -100,6 +105,7 @@ def get_default_predictor( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) return predictor @@ -113,9 +119,11 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni return kwargs -def _add_sagemaker_session_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: +def _add_sagemaker_session_to_kwargs( + kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs] +) -> JumpStartModelInitKwargs: """Sets session in kwargs based on default or override, returns full kwargs.""" - kwargs.sagemaker_session = kwargs.sagemaker_session or Session() + kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION return kwargs @@ -165,6 +173,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) if orig_instance_type is None: @@ -188,6 +197,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -205,6 +215,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) return kwargs @@ -221,6 +232,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): source_dir = source_dir or script_uris.retrieve( script_scope=JumpStartScriptScope.INFERENCE, @@ -229,6 +241,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.source_dir = source_dir @@ -247,6 +260,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ): entry_point = entry_point or INFERENCE_ENTRY_POINT_SCRIPT_NAME @@ -271,6 +285,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in extra_env_vars.items(): @@ -298,6 +313,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.model_package_arn = model_package_arn @@ -313,6 +329,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in model_kwargs_to_add.items(): @@ -347,6 +364,7 @@ def _add_endpoint_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.endpoint_name = kwargs.endpoint_name or ( @@ -367,6 +385,7 @@ def _add_model_name_to_kwargs( region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) kwargs.name = kwargs.name or ( @@ -386,6 +405,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + sagemaker_session=kwargs.sagemaker_session, ) for key, value in deploy_kwargs_to_add.items(): @@ -418,6 +438,7 @@ def get_deploy_kwargs( explainer_config: Optional[ExplainerConfig] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> JumpStartModelDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object.""" @@ -444,8 +465,11 @@ def get_deploy_kwargs( explainer_config=explainer_config, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) + deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) + deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 71bf0058ca..00ba8ce13e 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -261,6 +261,7 @@ def _is_valid_model_id_hook(): model_version=model_version, region=region, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=sagemaker_session, ) if not _is_valid_model_id_hook(): @@ -307,6 +308,7 @@ def _is_valid_model_id_hook(): self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.model_package_arn = model_init_kwargs.model_package_arn + self.sagemaker_session = model_init_kwargs.sagemaker_session super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) @@ -515,6 +517,7 @@ def deploy( region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, + sagemaker_session=self.sagemaker_session, ) # If a predictor class was passed, do not mutate predictor diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 839dff7474..e643a2cda1 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -16,6 +16,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Set, Union +from sagemaker.session import Session + class JumpStartDataHolderType: """Base class for many JumpStart types. @@ -698,6 +700,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "explainer_config", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "sagemaker_session", ] SERIALIZATION_EXCLUSION_SET = { @@ -706,6 +709,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", + "sagemaker_session", } def __init__( @@ -732,6 +736,7 @@ def __init__( explainer_config: Optional[Any] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -757,6 +762,7 @@ def __init__( self.explainer_config = explainer_config self.tolerate_vulnerable_model = tolerate_vulnerable_model self.tolerate_deprecated_model = tolerate_deprecated_model + self.sagemaker_session = sagemaker_session class JumpStartEstimatorInitKwargs(JumpStartKwargs): @@ -949,6 +955,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "experiment_config", "tolerate_deprecated_model", "tolerate_vulnerable_model", + "sagemaker_session", ] SERIALIZATION_EXCLUSION_SET = { @@ -972,6 +979,7 @@ def __init__( experiment_config: Optional[Dict[str, str]] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, + sagemaker_session: Optional[Session] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -985,6 +993,7 @@ def __init__( self.experiment_config = experiment_config self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model + self.sagemaker_session = sagemaker_session class JumpStartEstimatorDeployKwargs(JumpStartKwargs): diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index c33b411fce..f77e1ae231 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -100,6 +100,7 @@ def get_jumpstart_content_bucket( accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return) if bucket_to_return != old_content_bucket: + accessors.JumpStartModelsAccessor.reset_cache() for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) return bucket_to_return @@ -397,6 +398,7 @@ def verify_model_region_and_return_specs( region: str, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Verifies that an acceptable model_id, version, scope, and region combination is provided. @@ -415,7 +417,10 @@ def verify_model_region_and_return_specs( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). - + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Raises: NotImplementedError: If the scope is not supported. @@ -437,8 +442,11 @@ def verify_model_region_and_return_specs( f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}." ) - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, model_id=model_id, version=version # type: ignore + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore + region=region, + model_id=model_id, + version=version, + s3_client=sagemaker_session.s3_client, ) if ( @@ -591,6 +599,7 @@ def is_valid_model_id( region: Optional[str] = None, model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, + sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> bool: """Returns True if the model ID is supported for the given script. @@ -602,10 +611,13 @@ def is_valid_model_id( if not isinstance(model_id, str): return False + s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME model_version = model_version or "*" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest( + region=region, s3_client=s3_client + ) model_id_set = {model.model_id for model in models_manifest_list} if script == enums.JumpStartScriptScope.INFERENCE: return model_id in model_id_set @@ -616,6 +628,7 @@ def is_valid_model_id( region=region, model_id=model_id, version=model_version, + s3_client=s3_client, ).training_supported ) raise ValueError(f"Unsupported script: {script}") diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index fecd02a2eb..3199e5fc2e 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -13,6 +13,7 @@ """This module contains validators related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional +from sagemaker import session from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.enums import ( @@ -168,6 +169,9 @@ def validate_hyperparameters( hyperparameters: Dict[str, Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, + sagemaker_session: Optional[session.Session] = None, + tolerate_vulnerable_model: bool = False, + tolerate_deprecated_model: bool = False, ) -> None: """Validate hyperparameters for JumpStart models. @@ -182,6 +186,15 @@ def validate_hyperparameters( If set to ``VALIDATE_ALL``, all hyperparameters for the model will be validated. region (str): Region for which to validate hyperparameters. (Default: JumpStart default region). + sagemaker_session (Optional[Session]): Custom SageMaker Session to use. + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + tolerate_vulnerable_model (bool): True if vulnerable versions of model + specifications should be tolerated (exception not raised). If False, raises an + exception if the script used by this version of the model has dependencies with known + security vulnerabilities. (Default: False). + tolerate_deprecated_model (bool): True if deprecated models should be tolerated + (exception not raised). False if these models should raise an exception. + (Default: False). Raises: JumpStartHyperparametersError: If the hyperparameters are not formatted correctly, @@ -200,6 +213,9 @@ def validate_hyperparameters( version=model_version, region=region, scope=JumpStartScriptScope.TRAINING, + sagemaker_session=sagemaker_session, + tolerate_deprecated_model=tolerate_deprecated_model, + tolerate_vulnerable_model=tolerate_vulnerable_model, ) hyperparameters_specs = model_specs.hyperparameters diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index ec6ae387d5..648c6e0cb4 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -19,6 +19,8 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -29,6 +31,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> Optional[List[Dict[str, str]]]: """Retrieves the default training metric definitions for the model matching the given arguments. @@ -46,6 +49,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: list: The default metric definitions to use for the model or None. @@ -59,5 +66,10 @@ def retrieve_default( ) return artifacts._retrieve_default_training_metric_definitions( - model_id, model_version, region, tolerate_vulnerable_model, tolerate_deprecated_model + model_id, + model_version, + region, + tolerate_vulnerable_model, + tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 982f3ca908..91890be975 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -18,6 +18,8 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ def retrieve( model_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the model artifact Amazon S3 URI for the model matching the given arguments. @@ -48,6 +51,10 @@ def retrieve( tolerate_deprecated_model (bool): ``True`` if deprecated versions of model specifications should be tolerated without raising an exception. If ``False``, raises an exception if the version of the model is deprecated. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The model artifact S3 URI for the corresponding model. @@ -70,4 +77,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index ac4cb78df1..7b436a9dd8 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -14,6 +14,7 @@ from __future__ import print_function, absolute_import from typing import Optional +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.factory.model import get_default_predictor from sagemaker.jumpstart.utils import is_jumpstart_model_input @@ -32,7 +33,7 @@ def retrieve_default( endpoint_name: str, - sagemaker_session: Optional[Session] = None, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, @@ -44,7 +45,7 @@ def retrieve_default( Args: endpoint_name (str): Endpoint name for which to create a predictor. sagemaker_session (Session): The SageMaker Session to attach to the Predictor. - (Default: None). + (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). region (str): The AWS Region for which to retrieve the default predictor. (Default: None). model_id (str): The model ID of the model for which to @@ -80,4 +81,5 @@ def retrieve_default( region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 4bdf64060d..9a1c4933d2 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -19,6 +19,8 @@ from sagemaker.jumpstart import utils as jumpstart_utils from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session logger = logging.getLogger(__name__) @@ -30,6 +32,7 @@ def retrieve( script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> str: """Retrieves the script S3 URI associated with the model matching the given arguments. @@ -48,6 +51,10 @@ def retrieve( tolerate_deprecated_model (bool): ``True`` if deprecated models should be tolerated without raising an exception. ``False`` if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: str: The model script URI for the corresponding model. @@ -70,4 +77,5 @@ def retrieve( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 54091e71a6..60365d2621 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -31,6 +31,8 @@ ) from sagemaker.jumpstart import artifacts, utils as jumpstart_utils +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.session import Session def retrieve_options( @@ -39,6 +41,7 @@ def retrieve_options( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> List[BaseSerializer]: """Retrieves the supported serializers for the model matching the given arguments. @@ -56,6 +59,10 @@ def retrieve_options( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: List[SimpleBaseSerializer]: The supported serializers to use for the model. @@ -74,6 +81,7 @@ def retrieve_options( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) @@ -83,6 +91,7 @@ def retrieve_default( model_version: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, + sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> BaseSerializer: """Retrieves the default serializer for the model matching the given arguments. @@ -100,6 +109,10 @@ def retrieve_default( tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception not raised). False if these models should raise an exception. (Default: False). + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. If not + specified, one is created using the default AWS configuration + chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). Returns: SimpleBaseSerializer: The default serializer to use for the model. @@ -118,4 +131,5 @@ def retrieve_default( region, tolerate_vulnerable_model, tolerate_deprecated_model, + sagemaker_session=sagemaker_session, ) diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index f5a4798374..28211d06f1 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -12,14 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import - -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -37,11 +40,12 @@ def test_jumpstart_default_accept_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,6 +65,7 @@ def test_jumpstart_supported_accept_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert supported_accept_types == [ "application/json;verbose", @@ -68,5 +73,5 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index a0076e8a78..4b2db7d7f4 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -12,14 +12,17 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import - -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import content_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -37,11 +40,12 @@ def test_jumpstart_default_content_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,11 +65,12 @@ def test_jumpstart_supported_content_types( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert supported_content_types == [ "application/x-text", ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 60784cd19e..9d6e2f21de 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -13,7 +13,8 @@ from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock from sagemaker import base_deserializers, deserializers from sagemaker.jumpstart.utils import verify_model_region_and_return_specs @@ -21,6 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + + @patch("sagemaker.jumpstart.artifacts.predictors.verify_model_region_and_return_specs") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_deserializers( @@ -37,11 +42,12 @@ def test_jumpstart_default_deserializers( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) @@ -61,6 +67,7 @@ def test_jumpstart_deserializer_options( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert len(deserializer_options) == 1 @@ -72,5 +79,5 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 169bd74758..da9d1559b7 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -13,13 +13,17 @@ from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock import pytest from sagemaker import environment_variables from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_environment_variables(patched_get_model_specs): @@ -30,9 +34,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): region = "us-west-2" vars = environment_variables.retrieve_default( - region=region, - model_id=model_id, - model_version="*", + region=region, model_id=model_id, model_version="*", sagemaker_session=mock_session ) assert vars == { "MODEL_CACHE_ROOT": "/opt/ml/model", @@ -45,14 +47,14 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() vars = environment_variables.retrieve_default( - region=region, - model_id=model_id, - model_version="1.*", + region=region, model_id=model_id, model_version="1.*", sagemaker_session=mock_session ) assert vars == { "MODEL_CACHE_ROOT": "/opt/ml/model", @@ -65,7 +67,9 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -107,6 +111,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): model_id=model_id, model_version="*", include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, ) assert vars == { "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -116,7 +121,9 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): "SAGEMAKER_PROGRAM": "inference.py", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -125,6 +132,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): model_id=model_id, model_version="1.*", include_aws_sdk_env_vars=False, + sagemaker_session=mock_session, ) assert vars == { "ENDPOINT_SERVER_TIMEOUT": "3600", @@ -134,7 +142,9 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): "SAGEMAKER_PROGRAM": "inference.py", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index b7776b02a9..2d8f7d8166 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -12,8 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import boto3 -from mock.mock import patch +from mock.mock import patch, Mock import pytest from sagemaker import hyperparameters @@ -21,6 +22,10 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_default_hyperparameters(patched_get_model_specs): @@ -33,10 +38,16 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, model_version="*", + sagemaker_session=mock_session, ) assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() @@ -44,10 +55,16 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): region=region, model_id=model_id, model_version="1.*", + sagemaker_session=mock_session, ) assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() @@ -56,6 +73,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): model_id=model_id, model_version="1.*", include_container_hyperparameters=True, + sagemaker_session=mock_session, ) assert params == { "adam-learning-rate": "0.05", @@ -66,7 +84,12 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz", } - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, + model_id=model_id, + version="1.*", + s3_client=mock_client, + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 83092f74e5..0054ed9dbd 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -13,9 +13,9 @@ from __future__ import absolute_import -from mock.mock import patch +from mock.mock import patch, Mock import pytest - +import boto3 from sagemaker import hyperparameters from sagemaker.jumpstart.enums import HyperparameterValidationMode from sagemaker.jumpstart.exceptions import JumpStartHyperparametersError @@ -23,6 +23,9 @@ from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec +mock_client = boto3.client("s3") +mock_session = Mock(s3_client=mock_client) + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_validate_provided_hyperparameters(patched_get_model_specs): @@ -129,10 +132,14 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -144,6 +151,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_id=model_id, model_version=model_version, hyperparameters=hyperparameter_to_test, + sagemaker_session=mock_session, ) hyperparameter_to_test["batch-size"] = "0" @@ -425,10 +433,11 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() @@ -440,6 +449,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALGORITHM, + sagemaker_session=mock_session, ) del hyperparameter_to_test["adam-learning-rate"] @@ -477,10 +487,11 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): model_version=model_version, hyperparameters=hyperparameter_to_test, validation_mode=HyperparameterValidationMode.VALIDATE_ALL, + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index a16145585b..8a41891280 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -12,7 +12,8 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -from mock.mock import patch +import boto3 +from mock.mock import patch, Mock import pytest from sagemaker import image_uris @@ -31,6 +32,9 @@ def test_jumpstart_common_image_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + image_uris.retrieve( framework=None, region="us-west-2", @@ -38,9 +42,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -54,9 +62,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="1.*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -70,11 +82,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -88,11 +102,13 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", model_version="1.*", instance_type="ml.p2.xlarge", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index b21025968b..f13121aa94 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -11,8 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock - +import boto3 from mock.mock import patch import pytest @@ -29,30 +30,51 @@ def test_jumpstart_instance_types(patched_get_model_specs): model_id, model_version = "huggingface-eqa-bert-base-cased", "*" region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + default_training_instance_types = instance_types.retrieve_default( - region=region, model_id=model_id, model_version=model_version, scope="training" + region=region, + model_id=model_id, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, ) assert default_training_instance_types == "ml.p3.2xlarge" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_inference_instance_types = instance_types.retrieve_default( - region=region, model_id=model_id, model_version=model_version, scope="inference" + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, ) assert default_inference_instance_types == "ml.p2.xlarge" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_training_instance_types = instance_types.retrieve( - region=region, model_id=model_id, model_version=model_version, scope="training" + region=region, + model_id=model_id, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, ) assert default_training_instance_types == [ "ml.p3.2xlarge", @@ -63,13 +85,20 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() default_inference_instance_types = instance_types.retrieve( - region=region, model_id=model_id, model_version=model_version, scope="inference" + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, ) assert default_inference_instance_types == [ "ml.p2.xlarge", @@ -82,7 +111,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, model_id=model_id, version=model_version, s3_client=mock_client ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 198839475b..6cfcc3420a 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -23,6 +23,7 @@ from sagemaker.debugger.profiler_config import ProfilerConfig from sagemaker.estimator import Estimator from sagemaker.instance_group import InstanceGroup +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.estimator import JumpStartEstimator @@ -31,7 +32,6 @@ from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model from sagemaker.predictor import Predictor -from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import ( get_special_model_spec, overwrite_dictionary, @@ -40,7 +40,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("eiifccreeeiuchhnehtlbdecgeeelgjccjvvbbcncnhv", sagemaker_session) default_predictor_with_presets = Predictor( @@ -434,7 +434,7 @@ def test_estimator_use_kwargs(self): "output_path": "Optional[Union[str, PipelineVariable]] = None", "output_kms_key": "Optional[Union[str, PipelineVariable]] = None", "base_job_name": "Optional[str] = None", - "sagemaker_session": Session(), + "sagemaker_session": DEFAULT_JUMPSTART_SAGEMAKER_SESSION, "hyperparameters": {"hyp1": "val1"}, "tags": [{"1": "hum"}], "subnets": ["1", "2"], @@ -747,6 +747,7 @@ def test_no_predictor_returns_default_predictor( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=estimator.sagemaker_session, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -903,6 +904,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=sagemaker_session, ) @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @@ -953,6 +955,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=sagemaker_session, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ -1030,8 +1033,13 @@ def test_training_passes_role_to_deploy( @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") - @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch( + "sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", + sagemaker_session, + ) @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") @@ -1044,8 +1052,6 @@ def test_training_passes_session_to_deploy( mock_estimator_fit: mock.Mock, mock_estimator_init: mock.Mock, mock_get_model_specs: mock.Mock, - mock_session_estimator: mock.Mock, - mock_session_model: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, ): @@ -1059,11 +1065,8 @@ def test_training_passes_session_to_deploy( mock_get_model_specs.side_effect = get_special_model_spec - mock_session_estimator.return_value = sagemaker_session - mock_session_model.return_value = sagemaker_session - mock_role = f"dsfsdfsd{time.time()}" - mock_sagemaker_session = Session() + mock_sagemaker_session = mock.MagicMock(sagemaker_config={}) mock_sagemaker_session.get_caller_identity_arn = lambda: mock_role estimator = JumpStartEstimator( @@ -1146,12 +1149,14 @@ def test_model_id_not_found_refeshes_cache_training( model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), ] ) @@ -1172,12 +1177,14 @@ def test_model_id_not_found_refeshes_cache_training( model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.TRAINING, + sagemaker_session=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py index 1f2b11d6a6..d22e910a00 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_sagemaker_config.py @@ -22,6 +22,7 @@ TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH, TRAINING_JOB_ROLE_ARN_PATH, ) +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.session import Session @@ -32,7 +33,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("eiifccreeeiujigjjdfgiujrcibigckbtregvkjeurru", sagemaker_session) diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index fb7698741e..9800d71668 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -18,19 +18,19 @@ from mock import MagicMock import pytest from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.model import JumpStartModel from sagemaker.model import Model from sagemaker.predictor import Predictor -from sagemaker.session import Session from tests.unit.sagemaker.jumpstart.utils import get_special_model_spec, overwrite_dictionary execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role default_predictor = Predictor("blah", sagemaker_session) default_predictor_with_presets = Predictor( @@ -39,6 +39,9 @@ class ModelTest(unittest.TestCase): + + mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -416,6 +419,7 @@ def test_no_predictor_returns_default_predictor( region=region, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, + sagemaker_session=model.sagemaker_session, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -532,12 +536,14 @@ def test_model_id_not_found_refeshes_cach_inference( model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), ] ) @@ -558,24 +564,24 @@ def test_model_id_not_found_refeshes_cach_inference( model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), mock.call( model_id="js-trainable-model", model_version=None, region=None, script=JumpStartScriptScope.INFERENCE, + sagemaker_session=None, ), ] ) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn( self, mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, ): @@ -585,9 +591,9 @@ def test_jumpstart_model_package_arn( mock_get_model_specs.side_effect = get_special_model_spec - mock_session.return_value = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}) - model = JumpStartModel(model_id=model_id) + model = JumpStartModel(model_id=model_id, sagemaker_session=mock_session) tag = {"Key": "foo", "Value": "bar"} tags = [tag] @@ -595,23 +601,21 @@ def test_jumpstart_model_package_arn( model.deploy(tags=tags) self.assertEqual( - mock_session.return_value.create_model.call_args[0][2], + mock_session.create_model.call_args[0][2], { "ModelPackageName": "arn:aws:sagemaker:us-west-2:594846645681:model-package" "/llama2-7b-f-e46eb8a833643ed58aaccd81498972c3" }, ) - self.assertIn(tag, mock_session.return_value.create_model.call_args[1]["tags"]) + self.assertIn(tag, mock_session.create_model.call_args[1]["tags"]) @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") - @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_jumpstart_model_package_arn_override( self, mock_get_model_specs: mock.Mock, - mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, ): @@ -622,18 +626,20 @@ def test_jumpstart_model_package_arn_override( mock_get_model_specs.side_effect = get_special_model_spec - mock_session.return_value = MagicMock(sagemaker_config={}) + mock_session = MagicMock(sagemaker_config={}) model_package_arn = ( "arn:aws:sagemaker:us-west-2:867530986753:model-package/" "llama2-ynnej-f-e46eb8a833643ed58aaccd81498972c3" ) - model = JumpStartModel(model_id=model_id, model_package_arn=model_package_arn) + model = JumpStartModel( + model_id=model_id, model_package_arn=model_package_arn, sagemaker_session=mock_session + ) model.deploy() self.assertEqual( - mock_session.return_value.create_model.call_args[0][2], + mock_session.create_model.call_args[0][2], { "ModelPackageName": model_package_arn, "Environment": { diff --git a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py index 00563ad106..727f3060b3 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py +++ b/tests/unit/sagemaker/jumpstart/model/test_sagemaker_config.py @@ -18,6 +18,7 @@ MODEL_ENABLE_NETWORK_ISOLATION_PATH, MODEL_EXECUTION_ROLE_ARN_PATH, ) +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.jumpstart.model import JumpStartModel from sagemaker.session import Session @@ -28,7 +29,7 @@ execution_role = "fake role! do not use!" region = "us-west-2" -sagemaker_session = Session() +sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION sagemaker_session.get_caller_identity_arn = lambda: execution_role override_role = "fdsfsdfs" @@ -56,7 +57,7 @@ class IntelligentDefaultsModelTest(unittest.TestCase): execution_role = "fake role! do not use!" region = "us-west-2" - sagemaker_session = Session() + sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.model.Model.__init__") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 2627b2f8bc..3ddb1b10e8 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -18,6 +18,7 @@ import random from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, JUMPSTART_BUCKET_NAME_SET, JUMPSTART_DEFAULT_REGION_NAME, @@ -25,6 +26,8 @@ JUMPSTART_RESOURCE_BASE_NAME, JumpStartScriptScope, ) + +from functools import partial from sagemaker.jumpstart.enums import JumpStartTag, MIMEType from sagemaker.jumpstart.exceptions import ( DeprecatedJumpStartModelError, @@ -972,25 +975,39 @@ def test_is_valid_model_id_true( Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertTrue(utils.is_valid_model_id("bee")) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_not_called() - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() - - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] - - mock_get_model_specs.return_value = Mock(training_supported=True) - self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, model_id="bee", version="*" - ) + mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + mock_s3_client_value = mock_session_value.s3_client + + patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + + with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + self.assertTrue(utils.is_valid_model_id("bee")) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_not_called() + + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() + + mock_get_manifest.return_value = [ + Mock(model_id="ay"), + Mock(model_id="bee"), + Mock(model_id="see"), + ] + + mock_get_model_specs.return_value = Mock(training_supported=True) + self.assertTrue(utils.is_valid_model_id("bee", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + model_id="bee", + version="*", + s3_client=mock_s3_client_value, + ) @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1000,41 +1017,61 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani Mock(model_id="bee"), Mock(model_id="see"), ] - self.assertFalse(utils.is_valid_model_id("dee")) - self.assertFalse(utils.is_valid_model_id("")) - self.assertFalse(utils.is_valid_model_id(None)) - self.assertFalse(utils.is_valid_model_id(set())) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - - mock_get_model_specs.assert_not_called() - - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() - - mock_get_manifest.return_value = [ - Mock(model_id="ay"), - Mock(model_id="bee"), - Mock(model_id="see"), - ] - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_manifest.reset_mock() - - self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) - self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) - - mock_get_model_specs.assert_not_called() - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - - mock_get_manifest.reset_mock() - mock_get_model_specs.reset_mock() - - mock_get_model_specs.return_value = Mock(training_supported=False) - self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) - mock_get_manifest.assert_called_once_with(region=JUMPSTART_DEFAULT_REGION_NAME) - mock_get_model_specs.assert_called_once_with( - region=JUMPSTART_DEFAULT_REGION_NAME, model_id="ay", version="*" - ) + mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION + mock_s3_client_value = mock_session_value.s3_client + + patched = partial(utils.is_valid_model_id, sagemaker_session=mock_session_value) + + with patch("sagemaker.jumpstart.utils.is_valid_model_id", patched): + + self.assertFalse(utils.is_valid_model_id("dee")) + self.assertFalse(utils.is_valid_model_id("")) + self.assertFalse(utils.is_valid_model_id(None)) + self.assertFalse(utils.is_valid_model_id(set())) + + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + + mock_get_model_specs.assert_not_called() + + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() + + mock_get_manifest.return_value = [ + Mock(model_id="ay"), + Mock(model_id="bee"), + Mock(model_id="see"), + ] + self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + + mock_get_manifest.reset_mock() + + self.assertFalse(utils.is_valid_model_id("dee", script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id("", script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id(None, script=JumpStartScriptScope.TRAINING)) + self.assertFalse(utils.is_valid_model_id(set(), script=JumpStartScriptScope.TRAINING)) + + mock_get_model_specs.assert_not_called() + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + + mock_get_manifest.reset_mock() + mock_get_model_specs.reset_mock() + + mock_get_model_specs.return_value = Mock(training_supported=False) + self.assertFalse(utils.is_valid_model_id("ay", script=JumpStartScriptScope.TRAINING)) + mock_get_manifest.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value + ) + mock_get_model_specs.assert_called_once_with( + region=JUMPSTART_DEFAULT_REGION_NAME, + model_id="ay", + version="*", + s3_client=mock_s3_client_value, + ) diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 1c7b61c554..6be6f8251a 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import copy from typing import List +import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_REGION_NAME_SET @@ -83,7 +84,10 @@ def get_prototype_manifest( def get_prototype_model_spec( - region: str = None, model_id: str = None, version: str = None + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -94,7 +98,10 @@ def get_prototype_model_spec( def get_special_model_spec( - region: str = None, model_id: str = None, version: str = None + region: str = None, + model_id: str = None, + version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -111,6 +118,7 @@ def get_spec_from_base_spec( model_id: str = None, semantic_version_str: str = None, version: str = None, + s3_client: boto3.client = None, ) -> JumpStartModelSpecs: if version and semantic_version_str: diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ee1dd3aed4..bea68dd713 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -11,8 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock - +import boto3 from mock.mock import patch import pytest @@ -26,6 +27,9 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_id = "pytorch-ic-mobilenet-v2" region = "us-west-2" @@ -33,12 +37,15 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): region=region, model_id=model_id, model_version="*", + sagemaker_session=mock_session, ) assert definitions == [ {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} ] - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() @@ -46,12 +53,15 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): region=region, model_id=model_id, model_version="1.*", + sagemaker_session=mock_session, ) assert definitions == [ {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} ] - patched_get_model_specs.assert_called_once_with(region=region, model_id=model_id, version="1.*") + patched_get_model_specs.assert_called_once_with( + region=region, model_id=model_id, version="1.*", s3_client=mock_client + ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 6efdef1830..000540e12e 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -11,7 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock +import boto3 from mock.mock import patch import pytest @@ -31,15 +33,20 @@ def test_jumpstart_common_model_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_uris.retrieve( model_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -50,11 +57,13 @@ def test_jumpstart_common_model_uri( model_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,9 +75,13 @@ def test_jumpstart_common_model_uri( model_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -80,9 +93,13 @@ def test_jumpstart_common_model_uri( model_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index ef33eff974..3f38326608 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -11,7 +11,9 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock import pytest +import boto3 from mock.mock import patch @@ -31,15 +33,20 @@ def test_jumpstart_common_script_uri( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_spec_from_base_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + script_uris.retrieve( script_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -50,11 +57,13 @@ def test_jumpstart_common_script_uri( script_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -66,9 +75,10 @@ def test_jumpstart_common_script_uri( script_scope="training", model_id="pytorch-ic-mobilenet-v2", model_version="*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -80,9 +90,13 @@ def test_jumpstart_common_script_uri( script_scope="inference", model_id="pytorch-ic-mobilenet-v2", model_version="1.*", + sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="1.*" + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="1.*", + s3_client=mock_client, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index dcca3a7b4d..b22b61dc40 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -11,6 +11,8 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import Mock +import boto3 from mock.mock import patch @@ -32,16 +34,22 @@ def test_jumpstart_default_serializers( model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) default_serializer = serializers.retrieve_default( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert isinstance(default_serializer, base_serializers.IdentitySerializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -56,6 +64,9 @@ def test_jumpstart_serializer_options( patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs patched_get_model_specs.side_effect = get_special_model_spec + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_id, model_version = "predictor-specs-model", "*" region = "us-west-2" @@ -63,6 +74,7 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, model_version=model_version, + sagemaker_session=mock_session, ) assert len(serializer_options) == 1 assert all( @@ -73,5 +85,8 @@ def test_jumpstart_serializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 21c68d2f13..7a31de9237 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -89,6 +89,7 @@ def test_default_session(boto3_default_session): assert sess.boto_session is boto3_default_session +@patch("boto3.DEFAULT_SESSION", None) @patch("boto3.Session") def test_new_session_created(boto3_session): sess = Session()