diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 74aa229d21..bafcfde3a8 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -265,7 +265,11 @@ def register( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Return a container definition with framework configuration set in model environment. @@ -278,6 +282,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: dict[str, str]: A container definition object usable with the @@ -307,7 +316,12 @@ def prepare_container_def( deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( self.model_server_workers ) - return sagemaker.container_def(deploy_image, self.model_data, deploy_env) + return sagemaker.container_def( + deploy_image, + self.model_data, + deploy_env, + accept_eula=accept_eula, + ) def serving_image_uri( self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index c77b24fef5..118a4af5a0 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -733,6 +733,7 @@ def prepare_container_def( instance_type=None, accelerator_type=None, serverless_inference_config=None, + accept_eula=None, ): # pylint: disable=unused-argument """A container definition with framework configuration set in model environment variables. diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index bbbe3782bd..da294c89e2 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -465,6 +465,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, inference_tool=None, + accept_eula=None, ): """A container definition with framework configuration set in model environment variables. @@ -479,6 +480,11 @@ def prepare_container_def( not provided in serverless inference. So this is used to find image URIs. inference_tool (str): the tool that will be used to aid in the inference. Valid values: "neuron, neuronx, None" (default: None). + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
Returns: dict[str, str]: A container definition object usable with the @@ -510,7 +516,10 @@ def prepare_container_def( self.model_server_workers ) return sagemaker.container_def( - deploy_image, self.repacked_model_data or self.model_data, deploy_env + deploy_image, + self.repacked_model_data or self.model_data, + deploy_env, + accept_eula=accept_eula, ) def serving_image_uri( diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index e660cd65cc..daa9e0e30a 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -13,6 +13,7 @@ """This module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import import logging +import os from typing import Dict, Set, Type import boto3 from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer @@ -33,6 +34,8 @@ from sagemaker.session import Session +ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING" + JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set( [ JumpStartLaunchedRegionInfo( @@ -209,6 +212,19 @@ JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") +# disable logging if env var is set +JUMPSTART_LOGGER.addHandler( + type( + "", + (logging.StreamHandler,), + { + "emit": lambda self, *args, **kwargs: logging.StreamHandler.emit(self, *args, **kwargs) + if not os.environ.get(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING) + else None + }, + )() +) + try: DEFAULT_JUMPSTART_SAGEMAKER_SESSION = Session( boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME) diff --git a/src/sagemaker/jumpstart/exceptions.py b/src/sagemaker/jumpstart/exceptions.py index 00ab1d3e52..c55c9081cb 100644 --- a/src/sagemaker/jumpstart/exceptions.py +++ b/src/sagemaker/jumpstart/exceptions.py @@ -39,6 +39,12 @@ "Note that models may have different input/output signatures after a major version upgrade." ) +_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION = ( + "We recommend that you specify a more recent " + "model version or choose a different model. To access the latest models " + "and model versions, be sure to upgrade to the latest version of the SageMaker Python SDK." +) + def get_wildcard_model_version_msg( model_id: str, wildcard_model_version: str, full_model_version: str @@ -115,16 +121,16 @@ def __init__( self.message = ( f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore "has at least 1 vulnerable dependency in the inference script. " - "Please try targeting a higher version of the model or using a " - "different model. List of vulnerabilities: " + f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} " + "List of vulnerabilities: " f"{', '.join(vulnerabilities)}" # type: ignore ) elif scope == JumpStartScriptScope.TRAINING: self.message = ( f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore "has at least 1 vulnerable dependency in the training script. " - "Please try targeting a higher version of the model or using a " - "different model. List of vulnerabilities: " + f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION} " + "List of vulnerabilities: " f"{', '.join(vulnerabilities)}" # type: ignore ) else: @@ -159,8 +165,7 @@ def __init__( raise RuntimeError("Must specify `model_id` and `version` arguments.") self.message = ( f"Version '{version}' of JumpStart model '{model_id}' is deprecated. " - "Please try targeting a higher version of the model or using a " - "different model." 
+ f"{_VULNERABLE_DEPRECATED_ERROR_RECOMMENDATION}" ) super().__init__(self.message) diff --git a/src/sagemaker/jumpstart/filters.py b/src/sagemaker/jumpstart/filters.py index 5c79e717a0..56ef12a148 100644 --- a/src/sagemaker/jumpstart/filters.py +++ b/src/sagemaker/jumpstart/filters.py @@ -45,7 +45,6 @@ class SpecialSupportedFilterKeys(str, Enum): TASK = "task" FRAMEWORK = "framework" - SUPPORTED_MODEL = "supported_model" FILTER_OPERATOR_STRING_MAPPINGS = { @@ -74,7 +73,6 @@ class SpecialSupportedFilterKeys(str, Enum): [ SpecialSupportedFilterKeys.TASK, SpecialSupportedFilterKeys.FRAMEWORK, - SpecialSupportedFilterKeys.SUPPORTED_MODEL, ] ) diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 773ea9df41..732dbf4b83 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -15,10 +15,14 @@ import copy from functools import cmp_to_key +import os from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict from packaging.version import Version from sagemaker.jumpstart import accessors -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, + JUMPSTART_DEFAULT_REGION_NAME, +) from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.filters import ( SPECIAL_SUPPORTED_FILTER_KEYS, @@ -281,126 +285,160 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin results. (Default: False). """ - if isinstance(filter, str): - filter = Identity(filter) + class _ModelSearchContext: + """Context manager for conducting model searches.""" - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) - manifest_keys = set(models_manifest_list[0].__slots__) + def __init__(self): + """Initialize context manager.""" - all_keys: Set[str] = set() + self.old_disable_js_logging_env_var_value = os.environ.get( + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING + ) - model_filters: Set[ModelFilter] = set() + def __enter__(self, *args, **kwargs): + """Enter context. - for operator in _model_filter_in_operator_generator(filter): - model_filter = operator.unresolved_value - key = model_filter.key - all_keys.add(key) - model_filters.add(model_filter) + Disable JumpStart logs to avoid excessive logging. + """ - for key in all_keys: - if "." in key: - raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').") + os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true" - metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS + def __exit__(self, *args, **kwargs): + """Exit context. - required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) - possible_spec_keys = metadata_filter_keys - manifest_keys + Restore JumpStart logging settings, and reset cache so + new logs would appear for models previously searched. 
+ """ - unrecognized_keys: Set[str] = set() + if self.old_disable_js_logging_env_var_value: + os.environ[ + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING + ] = self.old_disable_js_logging_env_var_value + else: + os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None) + accessors.JumpStartModelsAccessor.reset_cache() - is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys - is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys - is_supported_model_filter = SpecialSupportedFilterKeys.SUPPORTED_MODEL in all_keys + with _ModelSearchContext(): - for model_manifest in models_manifest_list: + if isinstance(filter, str): + filter = Identity(filter) - copied_filter = copy.deepcopy(filter) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) + manifest_keys = set(models_manifest_list[0].__slots__) - manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} + all_keys: Set[str] = set() - model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} + model_filters: Set[ModelFilter] = set() - for val in required_manifest_keys: - manifest_specs_cached_values[val] = getattr(model_manifest, val) + for operator in _model_filter_in_operator_generator(filter): + model_filter = operator.unresolved_value + key = model_filter.key + all_keys.add(key) + model_filters.add(model_filter) - if is_task_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.TASK - ] = extract_framework_task_model(model_manifest.model_id)[1] + for key in all_keys: + if "." in key: + raise NotImplementedError( + f"No support for multiple level metadata indexing ('{key}')." + ) - if is_framework_filter: - manifest_specs_cached_values[ - SpecialSupportedFilterKeys.FRAMEWORK - ] = extract_framework_task_model(model_manifest.model_id)[0] + metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS - if is_supported_model_filter: - manifest_specs_cached_values[SpecialSupportedFilterKeys.SUPPORTED_MODEL] = Version( - model_manifest.min_version - ) <= Version(get_sagemaker_version()) + required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) + possible_spec_keys = metadata_filter_keys - manifest_keys - _populate_model_filters_to_resolved_values( - manifest_specs_cached_values, - model_filters_to_resolved_values, - model_filters, - ) + unrecognized_keys: Set[str] = set() - _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) + is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys + is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys - copied_filter.eval() + for model_manifest in models_manifest_list: - if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: - if copied_filter.resolved_value == BooleanValues.TRUE: - yield (model_manifest.model_id, model_manifest.version) - continue + copied_filter = copy.deepcopy(filter) - if copied_filter.resolved_value == BooleanValues.UNEVALUATED: - raise RuntimeError( - "Filter expression in unevaluated state after using values from model manifest. " - "Model ID and version that is failing: " - f"{(model_manifest.model_id, model_manifest.version)}." 
+ manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} + + model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} + + for val in required_manifest_keys: + manifest_specs_cached_values[val] = getattr(model_manifest, val) + + if is_task_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.TASK + ] = extract_framework_task_model(model_manifest.model_id)[1] + + if is_framework_filter: + manifest_specs_cached_values[ + SpecialSupportedFilterKeys.FRAMEWORK + ] = extract_framework_task_model(model_manifest.model_id)[0] + + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): + continue + + _populate_model_filters_to_resolved_values( + manifest_specs_cached_values, + model_filters_to_resolved_values, + model_filters, ) - copied_filter_2 = copy.deepcopy(filter) - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( - region=region, - model_id=model_manifest.model_id, - version=model_manifest.version, - ) + _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) - model_specs_keys = set(model_specs.__slots__) + copied_filter.eval() - unrecognized_keys -= model_specs_keys - unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys - unrecognized_keys.update(unrecognized_keys_for_single_spec) + if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: + if copied_filter.resolved_value == BooleanValues.TRUE: + yield (model_manifest.model_id, model_manifest.version) + continue - for val in possible_spec_keys: - if hasattr(model_specs, val): - manifest_specs_cached_values[val] = getattr(model_specs, val) + if copied_filter.resolved_value == BooleanValues.UNEVALUATED: + raise RuntimeError( + "Filter expression in unevaluated state after using " + "values from model manifest. Model ID and version that " + f"is failing: {(model_manifest.model_id, model_manifest.version)}." + ) + copied_filter_2 = copy.deepcopy(filter) - _populate_model_filters_to_resolved_values( - manifest_specs_cached_values, - model_filters_to_resolved_values, - model_filters, - ) - _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_manifest.model_id, + version=model_manifest.version, + ) - copied_filter_2.eval() + model_specs_keys = set(model_specs.__slots__) - if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: - if copied_filter_2.resolved_value == BooleanValues.TRUE or ( - BooleanValues.UNKNOWN and list_incomplete_models - ): - yield (model_manifest.model_id, model_manifest.version) - continue + unrecognized_keys -= model_specs_keys + unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys + unrecognized_keys.update(unrecognized_keys_for_single_spec) - raise RuntimeError( - "Filter expression in unevaluated state after using values from model specs. " - "Model ID and version that is failing: " - f"{(model_manifest.model_id, model_manifest.version)}." 
- ) + for val in possible_spec_keys: + if hasattr(model_specs, val): + manifest_specs_cached_values[val] = getattr(model_specs, val) + + _populate_model_filters_to_resolved_values( + manifest_specs_cached_values, + model_filters_to_resolved_values, + model_filters, + ) + _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) + + copied_filter_2.eval() + + if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: + if copied_filter_2.resolved_value == BooleanValues.TRUE or ( + copied_filter_2.resolved_value == BooleanValues.UNKNOWN and list_incomplete_models + ): + yield (model_manifest.model_id, model_manifest.version) + continue + + raise RuntimeError( + "Filter expression in unevaluated state after using values from model specs. " + "Model ID and version that is failing: " + f"{(model_manifest.model_id, model_manifest.version)}." + ) - if len(unrecognized_keys) > 0: - raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") + if len(unrecognized_keys) > 0: + raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") def get_model_url( diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 1c40e2fabd..242118c56e 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -115,80 +115,6 @@ def _construct_payload( return payload_to_use -def _extract_generated_text_from_response( - response: dict, - model_id: str, - model_version: str, - region: Optional[str] = None, - tolerate_vulnerable_model: bool = False, - tolerate_deprecated_model: bool = False, - sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, - accept_type: Optional[str] = None, -) -> str: - """Returns generated text extracted from full response payload. - - Args: - response (dict): Dictionary-valued response from which to extract - generated text. - model_id (str): JumpStart model ID of the JumpStart model from which to extract - generated text. - model_version (str): Version of the JumpStart model for which to extract generated - text. - region (Optional[str]): Region for which to extract generated - text. (Default: None). - tolerate_vulnerable_model (bool): True if vulnerable versions of model - specifications should be tolerated (exception not raised). If False, raises an - exception if the script used by this version of the model has dependencies with known - security vulnerabilities. (Default: False). - tolerate_deprecated_model (bool): True if deprecated versions of model - specifications should be tolerated (exception not raised). If False, raises - an exception if the version of the model is deprecated. (Default: False). - sagemaker_session (sagemaker.session.Session): A SageMaker Session - object, used for SageMaker interactions. If not - specified, one is created using the default AWS configuration - chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). - accept_type (Optional[str]): The accept type to optionally specify for the response. - (Default: None). - - Returns: - str: extracted generated text from the endpoint response payload. - - Raises: - ValueError: If the model is invalid, the model does not support generated text extraction, - or if the response is malformed. - """ - - if not isinstance(response, dict): - raise ValueError(f"Response must be dictionary. 
Instead, got: {type(response)}") - - payloads: Optional[Dict[str, JumpStartSerializablePayload]] = _retrieve_example_payloads( - model_id=model_id, - model_version=model_version, - region=region, - tolerate_vulnerable_model=tolerate_vulnerable_model, - tolerate_deprecated_model=tolerate_deprecated_model, - sagemaker_session=sagemaker_session, - ) - if payloads is None or len(payloads) == 0: - raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.") - - for payload in payloads.values(): - if accept_type is None or payload.accept == accept_type: - generated_text_response_key: Optional[str] = payload.generated_text_response_key - if generated_text_response_key is None: - raise ValueError( - f"Model ID '{model_id}' does not support generated text extraction." - ) - - generated_text_response_key_split = generated_text_response_key.split(".") - try: - return _extract_field_from_json(response, generated_text_response_key_split) - except KeyError: - raise ValueError(f"Response is malformed: {response}") - - raise ValueError(f"Model ID '{model_id}' does not support generated text extraction.") - - class PayloadSerializer: """Utility class for serializing payloads associated with JumpStart models. diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 4cd07ff033..de9e2c10a3 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -341,7 +341,6 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): "content_type", "accept", "body", - "generated_text_response_key", "prompt_key", ] @@ -373,7 +372,6 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: self.content_type = json_obj["content_type"] self.body = json_obj["body"] accept = json_obj.get("accept") - self.generated_text_response_key = json_obj.get("generated_text_response_key") self.prompt_key = json_obj.get("prompt_key") if accept: self.accept = accept @@ -732,6 +730,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType): "training_dependencies", "training_vulnerabilities", "deprecated", + "usage_info_message", "deprecated_message", "deprecate_warn_message", "default_inference_instance_type", @@ -803,6 +802,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.deprecated: bool = bool(json_obj["deprecated"]) self.deprecated_message: Optional[str] = json_obj.get("deprecated_message") self.deprecate_warn_message: Optional[str] = json_obj.get("deprecate_warn_message") + self.usage_info_message: Optional[str] = json_obj.get("usage_info_message") self.default_inference_instance_type: Optional[str] = json_obj.get( "default_inference_instance_type" ) diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index cd4ffcd702..0003081e99 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -109,7 +109,8 @@ def get_jumpstart_gated_content_bucket( accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return) if gated_bucket_to_return != old_gated_content_bucket: - accessors.JumpStartModelsAccessor.reset_cache() + if old_gated_content_bucket is not None: + accessors.JumpStartModelsAccessor.reset_cache() for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) @@ -153,7 +154,8 @@ def get_jumpstart_content_bucket( accessors.JumpStartModelsAccessor.set_jumpstart_content_bucket(bucket_to_return) if bucket_to_return != old_content_bucket: - accessors.JumpStartModelsAccessor.reset_cache() + if old_content_bucket is not None: + 
accessors.JumpStartModelsAccessor.reset_cache() for info_log in info_logs: constants.JUMPSTART_LOGGER.info(info_log) return bucket_to_return @@ -509,6 +511,9 @@ def emit_logs_based_on_model_specs( if model_specs.deprecate_warn_message: constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message) + if model_specs.usage_info_message: + constants.JUMPSTART_LOGGER.info(model_specs.usage_info_message) + if model_specs.inference_vulnerable or model_specs.training_vulnerable: constants.JUMPSTART_LOGGER.warning( "Using vulnerable JumpStart model '%s' and version '%s'.", diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 1cf6042182..7b741a1269 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -559,6 +559,7 @@ def create( accelerator_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + accept_eula: Optional[bool] = None, ): """Create a SageMaker Model Entity @@ -582,6 +583,11 @@ def create( For more information about tags, see `boto3 documentation `_ + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: None or pipeline step arguments in case the Model instance is built with @@ -593,6 +599,7 @@ def create( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, + accept_eula=accept_eula, ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -613,6 +620,7 @@ def prepare_container_def( instance_type=None, accelerator_type=None, serverless_inference_config=None, + accept_eula=None, ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. @@ -630,6 +638,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: dict: A container definition object usable with the CreateModel API. @@ -647,7 +660,9 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, image_config=self.image_config, - accept_eula=getattr(self, "accept_eula", None), + accept_eula=accept_eula + if accept_eula is not None + else getattr(self, "accept_eula", None), ) def is_repack(self) -> bool: @@ -789,6 +804,7 @@ def _create_sagemaker_model( accelerator_type=None, tags=None, serverless_inference_config=None, + accept_eula=None, ): """Create a SageMaker Model Entity @@ -808,6 +824,11 @@ def _create_sagemaker_model( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. 
+ accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). """ if self.model_package_arn is not None or self.algorithm_arn is not None: model_package = ModelPackage( @@ -838,6 +859,7 @@ def _create_sagemaker_model( instance_type, accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, + accept_eula=accept_eula, ) if not isinstance(self.sagemaker_session, PipelineSession): @@ -1459,7 +1481,12 @@ def deploy( "serverless_inference_config needs to be a ServerlessInferenceConfig object" ) - if instance_type and instance_type.startswith("ml.inf") and not self._is_compiled_model: + if ( + getattr(self, "model_id", None) in {"", None} + and instance_type + and instance_type.startswith("ml.inf") + and not self._is_compiled_model + ): LOGGER.warning( "Your model is not compiled. Please compile your model before using Inferentia." ) diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 93b73850ec..b656b4c671 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -121,7 +121,11 @@ def __init__( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Return a container definition set. @@ -149,6 +153,7 @@ def prepare_container_def( env=environment, model_data_url=self.model_data_prefix, container_mode=self.container_mode, + accept_eula=accept_eula, ) def deploy( diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 58ba94cb48..8cd0ac6b65 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -267,7 +267,11 @@ def register( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Return a container definition with framework configuration. @@ -282,6 +286,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
Returns: dict[str, str]: A container definition object usable with the @@ -312,7 +321,10 @@ def prepare_container_def( self.model_server_workers ) return sagemaker.container_def( - deploy_image, self.repacked_model_data or self.model_data, deploy_env + deploy_image, + self.repacked_model_data or self.model_data, + deploy_env, + accept_eula=accept_eula, ) def serving_image_uri( diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 286e7ecb28..fb731cabf4 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -269,7 +269,11 @@ def register( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """A container definition with framework configuration set in model environment variables. @@ -282,6 +286,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: dict[str, str]: A container definition object usable with the @@ -312,7 +321,10 @@ def prepare_container_def( self.model_server_workers ) return sagemaker.container_def( - deploy_image, self.repacked_model_data or self.model_data, deploy_env + deploy_image, + self.repacked_model_data or self.model_data, + deploy_env, + accept_eula=accept_eula, ) def serving_image_uri( diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 06076be312..195a6a3a57 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -262,7 +262,11 @@ def register( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Container definition with framework configuration set in model environment variables. @@ -276,6 +280,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
Returns: dict[str, str]: A container definition object usable with the @@ -302,7 +311,12 @@ def prepare_container_def( model_data_uri = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) - return sagemaker.container_def(deploy_image, model_data_uri, deploy_env) + return sagemaker.container_def( + deploy_image, + model_data_uri, + deploy_env, + accept_eula=accept_eula, + ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): """Create a URI for the serving image. diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 07af8fa09f..375a2ea7e5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -379,7 +379,11 @@ def _eia_supported(self): ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Prepare the container definition. @@ -389,6 +393,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: A container definition for deploying a ``Model`` to an ``Endpoint``. @@ -446,7 +455,12 @@ def prepare_container_def( else: model_data = self.model_data - return sagemaker.container_def(image_uri, model_data, env) + return sagemaker.container_def( + image_uri, + model_data, + env, + accept_eula=accept_eula, + ) def _get_container_env(self): """Placeholder docstring.""" diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index bbb885b0d2..74776f8f72 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -250,7 +250,11 @@ def register( ) def prepare_container_def( - self, instance_type=None, accelerator_type=None, serverless_inference_config=None + self, + instance_type=None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): """Return a container definition with framework configuration. @@ -264,6 +268,11 @@ def prepare_container_def( serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig): Specifies configuration related to serverless endpoint. Instance type is not provided in serverless inference. So this is used to find image URIs. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: dict[str, str]: A container definition object usable with the CreateModel API. 
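The accept_eula plumbing is identical across the Chainer, Hugging Face, MXNet, PyTorch, SKLearn, TensorFlow, and XGBoost models touched above: the new keyword is simply threaded from prepare_container_def() into sagemaker.container_def(). The one hunk with real logic is in the base sagemaker/model.py, which merges an explicitly passed value with a previously stored attribute. A minimal sketch of that resolution order follows; the _ModelStub class is hypothetical, for illustration only, and is not part of the SDK:

    class _ModelStub:
        """Stand-in for sagemaker.model.Model, showing only the EULA resolution."""

        def __init__(self, accept_eula=None):
            self.accept_eula = accept_eula  # value captured at construction time

        def resolve_accept_eula(self, accept_eula=None):
            # Mirrors the diff: a value passed at call time (including an explicit
            # False) wins; only None falls back to the stored attribute, because
            # the check is `is not None` rather than truthiness.
            return accept_eula if accept_eula is not None else getattr(self, "accept_eula", None)

    assert _ModelStub(accept_eula=True).resolve_accept_eula() is True
    assert _ModelStub(accept_eula=True).resolve_accept_eula(accept_eula=False) is False
    assert _ModelStub().resolve_accept_eula() is None  # no EULA accepted anywhere

Using `is not None` rather than `or` is what lets a caller explicitly decline the EULA at deploy time even when the model object was constructed with accept_eula=True.
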
@@ -288,7 +297,12 @@ def prepare_container_def( model_data = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) - return sagemaker.container_def(deploy_image, model_data, deploy_env) + return sagemaker.container_def( + deploy_image, + model_data, + deploy_env, + accept_eula=accept_eula, + ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): """Create a URI for the serving image. diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 529530e771..a3c4c747f7 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1023,6 +1023,602 @@ "training_enable_network_isolation": False, "resource_name_base": "dfsdfsds", }, + "gated_llama_neuron_model": { + "model_id": "meta-textgenerationneuron-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "1.0.0", + "min_sdk_version": "2.198.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "djl-neuronx", + "framework_version": "0.24.0", + "py_version": "py39", + }, + "hosting_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuron-llama-2-7b/artifac" + "ts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgenerationneuron/v1.0.0/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "meta-textgenerationneuron/meta-textgenerationneuro" + "n-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "max_input_length", + "type": "int", + "default": 2048, + "min": 128, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 6e-06, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "min_learning_rate", + "type": "float", + "default": 1e-06, + "min": 1e-12, + "max": 1, + "scope": "algorithm", + }, + {"name": "max_steps", "type": "int", "default": 20, "min": 2, "scope": "algorithm"}, + { + "name": "global_train_batch_size", + "type": "int", + "default": 256, + "min": 1, + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "layer_norm_epilson", + "type": "float", + "default": 1e-05, + "min": 1e-12, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.1, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "CosineAnnealing", + "options": ["CosineAnnealing"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 10, "min": 0, "scope": "algorithm"}, + {"name": 
"constant_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.95, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "mixed_precision", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "tensor_parallel_degree", + "type": "text", + "default": "8", + "options": ["8"], + "scope": "algorithm", + }, + { + "name": "pipeline_parallel_degree", + "type": "text", + "default": "1", + "options": ["1"], + "scope": "algorithm", + }, + { + "name": "append_eod", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/meta/transfer_learning/textgenerati" + "onneuron/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-tarballs/meta/tra" + "nsfer_learning/textgenerationneuron/prepack/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_version": "1.0.0", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "meta-training/train-meta-textgenerationneuron-llama-2-7b.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "meta-textgenerationneuron:train-loss", + "Regex": "reduced_train_loss=([0-9]+\\.[0-9]+)", + } + ], + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": [ + "ml.inf2.xlarge", + "ml.inf2.8xlarge", + "ml.inf2.24xlarge", + "ml.inf2.48xlarge", + ], + "default_training_instance_type": "ml.trn1.32xlarge", + "supported_training_instance_types": ["ml.trn1.32xlarge", "ml.trn1n.32xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + 
"disable_output_compression": True, + "max_run": 360000, + }, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/sec_amazon/", + "validation_supported": False, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgenerationneuron-llama-2-7b", + "default_payloads": { + "meaningOfLife": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "I believe the meaning of life is", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "theoryOfRelativity": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "teamMessage": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "A brief message congratulating the team on the launch:\n\nHi " + "everyone,\n\nI just ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "englishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "generated_text"}, + "body": { + "inputs": "Translate English to French:\nsea otter => loutre de mer\npep" + "permint => menthe poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "alias_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/djl-in" + "ference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-east-1": { + "alias_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/djl-in" + "ference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-northeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-northeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-south-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-southeast-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ap-southeast-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "ca-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "cn-north-1": { + "alias_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-central-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-north-1": { + "alias_ecr_uri_1": 
"763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-south-1": { + "alias_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "eu-west-3": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/d" + "jl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "me-south-1": { + "alias_ecr_uri_1": "217643126080.dkr.ecr.me-south-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "sa-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-east-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-east-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com" + "/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-west-1": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.co" + "m/djl-inference:0.24.0-neuronx-sdk2.14.1" + }, + "us-west-2": { + "alias_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-" + "inference:0.24.0-neuronx-sdk2.14.1" + }, + }, + "variants": { + "c4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c5n": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "c6i": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g4dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf1": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "inf2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m4": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "m5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "r5d": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t2": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "t3": {"regional_properties": {"image_uri": "$alias_ecr_uri_1"}}, + "ml.inf2.xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "1024", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "1", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.8xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + "OPTION_N_POSITIONS": "2048", 
+ "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.24xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "12", + "OPTION_N_POSITIONS": "4096", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + "ml.inf2.48xlarge": { + "properties": { + "environment_variables": { + "OPTION_TENSOR_PARALLEL_DEGREE": "24", + "OPTION_N_POSITIONS": "4096", + "OPTION_DTYPE": "fp16", + "OPTION_ROLLING_BATCH": "auto", + "OPTION_MAX_ROLLING_BATCH_SIZE": "4", + "OPTION_NEURON_OPTIMIZE_LEVEL": "2", + } + } + }, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": "626614931356.dkr.ecr.af-south-1.amazonaws.com/huggingface-pytorc" + "h-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-east-1": { + "gpu_ecr_uri_1": "871362719292.dkr.ecr.ap-east-1.amazonaws.com/huggingface-pytorch" + "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/huggingface-" + "pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/huggingfa" + "ce-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-south-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-south-1.amazonaws.com/huggingface-pytorch" + "-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "ca-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.ca-central-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "cn-north-1": { + "gpu_ecr_uri_1": "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-central-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-central-1.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-north-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-north-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-south-1": { + "gpu_ecr_uri_1": "692866216735.dkr.ecr.eu-south-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-1.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-2.amazonaws.com/hug" + "gingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "eu-west-3": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.eu-west-3.amazonaws.com/huggi" + "ngface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "me-south-1": { + "gpu_ecr_uri_1": 
"217643126080.dkr.ecr.me-south-1.amazonaws.com/huggin" + "gface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "sa-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.sa-east-1.amazonaws.com/hugg" + "ingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-east-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-1.amazonaws.com/hu" + "ggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-" + "training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + "us-east-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-py" + "torch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-trai" + "ning-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + "us-west-1": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-1.amazonaws.com/huggingface-pytor" + "ch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface" + "-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + "neuron_ecr_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch" + "-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04", + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "trn1": { + "regional_properties": {"image_uri": "$neuron_ecr_uri"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/trn1/v1.0." 
+ "0/train-meta-textgenerationneuron-llama-2-7b.tar.gz" + }, + }, + "trn1n": { + "regional_properties": {"image_uri": "$neuron_ecr_uri"}, + "properties": { + "gated_model_key_env_var_value": "meta-training/trn1n/v1.0.0" + "/train-meta-textgenerationneuron-llama-2-7b.tar.gz" + }, + }, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, + "dynamic_container_deployment_supported": True, + }, "gated_variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "gated_bucket": True, @@ -3533,7 +4129,6 @@ "Dog": { "content_type": "application/json", "prompt_key": "hello.prompt", - "generated_text_response_key": "key1.key2.generated_text", "body": { "hello": {"prompt": "a dog"}, "seed": 43, @@ -5626,6 +6221,7 @@ "ml.c5.2xlarge", ], "hosting_use_script_uri": True, + "usage_info_message": None, "metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], "model_kwargs": {"some-model-kwarg-key": "some-model-kwarg-value"}, "deploy_kwargs": {"some-model-deploy-kwarg-key": "some-model-deploy-kwarg-value"}, diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 68bd061a85..ff3e670e53 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -332,6 +332,43 @@ def test_prepacked( endpoint_logging=False, ) + @mock.patch("sagemaker.model.LOGGER.warning") + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") + @mock.patch("sagemaker.session.Session.create_model") + @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_no_compiled_model_warning_log_js_models( + self, + mock_get_model_specs: mock.Mock, + mock_session: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_create_model: mock.Mock, + mock_endpoint_from_production_variants: mock.Mock, + mock_timestamp: mock.Mock, + mock_warning: mock.Mock(), + ): + + mock_timestamp.return_value = "1234" + + mock_is_valid_model_id.return_value = True + + model_id, _ = "gated_llama_neuron_model", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session.return_value = sagemaker_session + + model = JumpStartModel( + model_id=model_id, + ) + + model.deploy(accept_eula=True) + + mock_warning.assert_not_called() + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.session.Session.endpoint_from_production_variants") @mock.patch("sagemaker.session.Session.create_model") diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 575167420d..181310a507 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -383,8 +383,14 @@ def test_list_jumpstart_models_region( @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils.get_sagemaker_version") - def test_list_jumpstart_models_unsupported_models( + 
@patch("sagemaker.jumpstart.notebook_utils.accessors.JumpStartModelsAccessor.reset_cache") + @patch.dict("os.environ", {}) + @patch("logging.StreamHandler.emit") + @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False) + def test_list_jumpstart_models_disables_logging_resets_cache( self, + patched_emit: Mock, + patched_reset_cache: Mock, patched_get_sagemaker_version: Mock, patched_get_model_specs: Mock, patched_get_manifest: Mock, @@ -392,25 +398,12 @@ def test_list_jumpstart_models_unsupported_models( patched_get_model_specs.side_effect = get_prototype_model_spec patched_get_manifest.side_effect = get_prototype_manifest - patched_get_sagemaker_version.return_value = "0.0.0" + patched_get_sagemaker_version.return_value = "3.0.0" - assert [] == list_jumpstart_models("supported_model == True") - patched_get_model_specs.assert_not_called() - assert [] == list_jumpstart_models( - And("supported_model == True", "training_supported in [False, True]") - ) - patched_get_model_specs.assert_not_called() - - assert [] != list_jumpstart_models("supported_model == False") - - patched_get_sagemaker_version.return_value = "999999.0.0" - - assert [] != list_jumpstart_models("supported_model == True") - - patched_get_model_specs.reset_mock() + list_jumpstart_models("deprecate_warn_message is blah") - assert [] != list_jumpstart_models("training_supported in [False, True]") - patched_get_model_specs.assert_called() + patched_emit.assert_not_called() + patched_reset_cache.assert_called_once() @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/test_payload_utils.py b/tests/unit/sagemaker/jumpstart/test_payload_utils.py index 2172998b1a..afc955e2f3 100644 --- a/tests/unit/sagemaker/jumpstart/test_payload_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_payload_utils.py @@ -14,11 +14,9 @@ import base64 from unittest import TestCase from mock.mock import patch -import pytest from sagemaker.jumpstart.payload_utils import ( PayloadSerializer, - _extract_generated_text_from_response, _construct_payload, ) from sagemaker.jumpstart.types import JumpStartSerializablePayload @@ -59,77 +57,6 @@ def test_construct_payload(self, patched_get_model_specs): ) -class TestResponseExtraction(TestCase): - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - def test_extract_generated_text(self, patched_get_model_specs): - patched_get_model_specs.side_effect = get_special_model_spec - - model_id = "response-keys" - region = "us-west-2" - generated_text = _extract_generated_text_from_response( - response={"key1": {"key2": {"generated_text": "top secret"}}}, - model_id=model_id, - model_version="*", - region=region, - ) - - self.assertEqual( - _extract_generated_text_from_response( - response={"key1": {"key2": {"generated_text": "top secret"}}}, - model_id=model_id, - model_version="*", - region=region, - accept_type="application/json", - ), - generated_text, - ) - - self.assertEqual( - generated_text, - "top secret", - ) - - with pytest.raises(ValueError): - _extract_generated_text_from_response( - response={"key1": {"key2": {"generated_texts": "top secret"}}}, - model_id=model_id, - model_version="*", - region=region, - ) - - with pytest.raises(ValueError): - _extract_generated_text_from_response( - response={"key1": {"key2": {"generated_text": "top secret"}}}, - model_id=model_id, - model_version="*", - region=region, - 
accept_type="blah/blah", - ) - - with pytest.raises(ValueError): - _extract_generated_text_from_response( - response={"key1": {"key2": {"generated_text": "top secret"}}}, - model_id="env-var-variant-model", # some model without the required metadata - model_version="*", - region=region, - ) - with pytest.raises(ValueError): - _extract_generated_text_from_response( - response={"key1": {"generated_texts": "top secret"}}, - model_id=model_id, - model_version="*", - region=region, - ) - - with pytest.raises(ValueError): - _extract_generated_text_from_response( - response="blah", - model_id=model_id, - model_version="*", - region=region, - ) - - class TestPayloadSerializer(TestCase): payload_serializer = PayloadSerializer() diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 4cdb51a8cd..4c2cd5b123 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -76,7 +76,7 @@ def test_jumpstart_serializable_payload_with_predictor( "JumpStartSerializablePayload: {'content_type': 'application/json', 'accept': 'application/json'" ", 'body': {'prompt': 'a dog', 'num_images_per_prompt': 2, 'num_inference_steps':" " 20, 'guidance_scale': 7.5, 'seed': 43, 'eta': 0.7, 'image':" - " '$s3_b64'}, 'generated_text_response_key': None}" + " '$s3_b64'}}" ) js_predictor.predict(default_payload) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index c42012536e..2a4b1fc312 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -19,15 +19,16 @@ from sagemaker.jumpstart import utils from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE, ENV_VARIABLE_JUMPSTART_GATED_CONTENT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_GATED_AND_PUBLIC_BUCKET_NAME_SET, + JUMPSTART_LOGGER, JUMPSTART_REGION_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME, JumpStartScriptScope, ) - from functools import partial from sagemaker.jumpstart.enums import JumpStartTag, MIMEType from sagemaker.jumpstart.exceptions import ( @@ -1032,6 +1033,23 @@ def make_deprecated_warning_message_spec(*largs, **kwargs): ) +@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") +def test_jumpstart_usage_info_message(mock_get_manifest): + mock_get_manifest.return_value = [] + + usage_info_message = "This model might change your life." + + def make_info_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.usage_info_message = usage_info_message + return spec + + with patch("logging.Logger.info") as mocked_info_log: + utils.emit_logs_based_on_model_specs(make_info_spec(), "us-west-2", MOCK_CLIENT) + + mocked_info_log.assert_called_with(usage_info_message) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_vulnerable_model_errors(patched_get_model_specs): def make_vulnerable_inference_spec(*largs, **kwargs): @@ -1052,7 +1070,10 @@ def make_vulnerable_inference_spec(*largs, **kwargs): assert ( "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " "vulnerable dependency in the inference script. " - "Please try targeting a higher version of the model or using a different model. 
" + "We recommend that you specify a more recent model version or " + "choose a different model. To access the " + "latest models and model versions, be sure to upgrade " + "to the latest version of the SageMaker Python SDK. " "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) @@ -1074,7 +1095,10 @@ def make_vulnerable_training_spec(*largs, **kwargs): assert ( "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 " "vulnerable dependency in the training script. " - "Please try targeting a higher version of the model or using a different model. " + "We recommend that you specify a more recent model version or " + "choose a different model. To access the " + "latest models and model versions, be sure to upgrade " + "to the latest version of the SageMaker Python SDK. " "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) @@ -1252,3 +1276,23 @@ def test_is_valid_model_id_false(self, mock_get_model_specs: Mock, mock_get_mani mock_get_manifest.assert_called_once_with( region=JUMPSTART_DEFAULT_REGION_NAME, s3_client=mock_s3_client_value ) + + +class TestJumpStartLogger(TestCase): + @patch.dict("os.environ", {}) + @patch("logging.StreamHandler.emit") + @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False) + def test_logger_normal_mode(self, mocked_emit: Mock): + + JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...") + + mocked_emit.assert_called_once() + + @patch.dict("os.environ", {ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING: "true"}) + @patch("logging.StreamHandler.emit") + @patch("sagemaker.jumpstart.constants.JUMPSTART_LOGGER.propagate", False) + def test_logger_disabled(self, mocked_emit: Mock): + + JUMPSTART_LOGGER.warning("Self destruct in 3...2...1...") + + mocked_emit.assert_not_called() diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index c321e2bd27..146c6fd1f7 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -16,7 +16,11 @@ import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME, JUMPSTART_REGION_NAME_SET +from sagemaker.jumpstart.constants import ( + JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, + JUMPSTART_REGION_NAME_SET, +) from sagemaker.jumpstart.types import ( JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue, @@ -93,6 +97,7 @@ def get_prototype_model_spec( we only retrieve model specs based on the model ID. 
""" + JUMPSTART_LOGGER.warning("some-logging-msg") specs = JumpStartModelSpecs(PROTOTYPICAL_MODEL_SPECS_DICT[model_id]) return specs diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 893688af5e..953cbe775c 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -114,7 +114,7 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None + INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None ) production_variant.assert_called_with( MODEL_NAME, @@ -927,7 +927,7 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None + INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None ) production_variant.assert_called_with( MODEL_NAME, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index ad985792d5..4d4248e0d6 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,7 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, accelerator_type=None, serverless_inference_config=None + None, accelerator_type=None, serverless_inference_config=None, accept_eula=None ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -305,7 +305,7 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None + INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None ) @@ -317,7 +317,40 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake model._create_sagemaker_model(INSTANCE_TYPE, accelerator_type=accelerator_type) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=accelerator_type, serverless_inference_config=None + INSTANCE_TYPE, + accelerator_type=accelerator_type, + serverless_inference_config=None, + accept_eula=None, + ) + + +@patch("sagemaker.model.Model.prepare_container_def") +def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session) + + accelerator_type = "ml.eia.medium" + model.create(INSTANCE_TYPE, accelerator_type=accelerator_type, accept_eula=True) + + prepare_container_def.assert_called_with( + INSTANCE_TYPE, + accelerator_type=accelerator_type, + serverless_inference_config=None, + accept_eula=True, + ) + + +@patch("sagemaker.model.Model.prepare_container_def") +def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, name=MODEL_NAME, sagemaker_session=sagemaker_session) + + accelerator_type = "ml.eia.medium" + model.create(INSTANCE_TYPE, accelerator_type=accelerator_type, accept_eula=False) + + prepare_container_def.assert_called_with( + INSTANCE_TYPE, + accelerator_type=accelerator_type, + serverless_inference_config=None, + 
accept_eula=False, ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index cc5c6dfcd9..ae75256794 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -259,7 +259,11 @@ def create_predictor(self, endpoint_name): return None def prepare_container_def( - self, instance_type, accelerator_type=None, serverless_inference_config=None + self, + instance_type, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, ): return MODEL_CONTAINER_DEF