diff --git a/doc/api/inference/explainer.rst b/doc/api/inference/explainer.rst new file mode 100644 index 0000000000..d522e6c1dc --- /dev/null +++ b/doc/api/inference/explainer.rst @@ -0,0 +1,16 @@ +Online Explainability +--------------------- + +This module contains classes related to Amazon Sagemaker Clarify Online Explainability + +.. automodule:: sagemaker.explainer.explainer_config + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: sagemaker.explainer.clarify_explainer_config + :members: + :undoc-members: + :show-inheritance: + + diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 922150b901..0851efc6e6 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1378,6 +1378,7 @@ def deploy( model_data_download_timeout=None, container_startup_health_check_timeout=None, inference_recommendation_id=None, + explainer_config=None, **kwargs, ): """Deploy the trained model to an Amazon SageMaker endpoint. @@ -1458,6 +1459,8 @@ def deploy( inference_recommendation_id (str): The recommendation id which specifies the recommendation you picked from inference recommendation job results and would like to deploy the model and endpoint with recommended parameters. + explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability + configuration for use with Amazon SageMaker Clarify. (default: None) **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. 
@@ -1516,6 +1519,7 @@ def deploy( data_capture_config=data_capture_config, serverless_inference_config=serverless_inference_config, async_inference_config=async_inference_config, + explainer_config=explainer_config, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, diff --git a/src/sagemaker/explainer/__init__.py b/src/sagemaker/explainer/__init__.py new file mode 100644 index 0000000000..a9aac5bd3f --- /dev/null +++ b/src/sagemaker/explainer/__init__.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Imports the classes in this module to simplify customer imports""" + +from __future__ import absolute_import + +from sagemaker.explainer.explainer_config import ExplainerConfig # noqa: F401 +from sagemaker.explainer.clarify_explainer_config import ( # noqa: F401 + ClarifyExplainerConfig, + ClarifyInferenceConfig, + ClarifyShapConfig, + ClarifyShapBaselineConfig, + ClarifyTextConfig, +) diff --git a/src/sagemaker/explainer/clarify_explainer_config.py b/src/sagemaker/explainer/clarify_explainer_config.py new file mode 100644 index 0000000000..b3fc18ebb3 --- /dev/null +++ b/src/sagemaker/explainer/clarify_explainer_config.py @@ -0,0 +1,298 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A member of :class:`~sagemaker.explainer.explainer_config.ExplainerConfig` that contains configuration parameters for the SageMaker Clarify explainer.""" # noqa E501 # pylint: disable=line-too-long + +from __future__ import print_function, absolute_import +from typing import List, Optional + + +class ClarifyTextConfig(object): + """A parameter used to configure the SageMaker Clarify explainer to treat text features as text so that explanations are provided for individual units of text. Required only for NLP explainability.""" # noqa E501 # pylint: disable=line-too-long + + def __init__( + self, + language: str, + granularity: str, + ): + """Initialize a config object for text explainability. + + Args: + language (str): Specifies the language of the text features in `ISO 639-1 + `__ or `ISO 639-3 + `__ code of a supported + language. See valid values `here + `__. + granularity (str): The unit of granularity for the analysis of text features. For + example, if the unit is ``"token"``, then each token (like a word in English) of the + text is treated as a feature. SHAP values are computed for each unit/feature. + Accepted values are ``"token"``, ``"sentence"``, or ``"paragraph"``. 
+ """ # noqa E501 # pylint: disable=line-too-long + self.language = language + self.granularity = granularity + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = {"Language": self.language, "Granularity": self.granularity} + return request_dict + + +class ClarifyShapBaselineConfig(object): + """Configuration for the SHAP baseline of the Kernel SHAP algorithm.""" + + def __init__( + self, + mime_type: Optional[str] = "text/csv", + shap_baseline: Optional[str] = None, + shap_baseline_uri: Optional[str] = None, + ): + """Initialize a config object for SHAP baseline. + + Args: + mime_type (str): Optional. The MIME type of the baseline data. Choose + from ``"text/csv"`` or ``"application/jsonlines"``. (Default: ``"text/csv"``) + shap_baseline (str): Optional. The inline SHAP baseline data in string format. + ShapBaseline can have one or multiple records to be used as the baseline dataset. + The format of the SHAP baseline file should be the same format as the training + dataset. For example, if the training dataset is in CSV format and each record + contains four features, and all features are numerical, then the format of the + baseline data should also share these characteristics. For NLP of text columns, the + baseline value should be the value used to replace the unit of text specified by + the ``granularity`` of the + :class:`~sagemaker.explainer.clarify_explainer_config.ClarifyTextConfig` + parameter. The size limit for ``shap_baseline`` is 4 KB. Use the + ``shap_baseline_uri`` parameter if you want to provide more than 4 KB of baseline + data. + shap_baseline_uri (str): Optional. The S3 URI where the SHAP baseline file is stored. + The format of the SHAP baseline file should be the same format as the format of + the training dataset.
For example, if the training dataset is in CSV format, + and each record in the training dataset has four features, and all features are + numerical, then the baseline file should also have this same format. Each record + should contain only the features. If you are using a virtual private cloud (VPC), + the ``shap_baseline_uri`` should be accessible to the VPC. + """ + self.mime_type = mime_type + self.shap_baseline = shap_baseline + self.shap_baseline_uri = shap_baseline_uri + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = {"MimeType": self.mime_type} + if self.shap_baseline is not None: + request_dict["ShapBaseline"] = self.shap_baseline + if self.shap_baseline_uri is not None: + request_dict["ShapBaselineUri"] = self.shap_baseline_uri + + return request_dict + + +class ClarifyShapConfig(object): + """Configuration for SHAP analysis using SageMaker Clarify Explainer.""" + + def __init__( + self, + shap_baseline_config: ClarifyShapBaselineConfig, + number_of_samples: Optional[int] = None, + seed: Optional[int] = None, + use_logit: Optional[bool] = False, + text_config: Optional[ClarifyTextConfig] = None, + ): + """Initialize a config object for SHAP analysis. + + Args: + shap_baseline_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyShapBaselineConfig`): + The configuration for the SHAP baseline of the Kernel SHAP algorithm. + number_of_samples (int): Optional. The number of samples to be used for analysis by the + Kernel SHAP algorithm. The number of samples determines the size of the synthetic + dataset, which has an impact on latency of explainability requests. For more + information, see the `Synthetic data` of `Configure and create an endpoint + `__. + seed (int): Optional. The starting value used to initialize the random number generator + in the explainer. Provide a value for this parameter to obtain a deterministic SHAP + result.
+ use_logit (bool): Optional. A Boolean toggle to indicate if you want to use the logit + function (true) or log-odds units (false) for model predictions. (Default: false) + text_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyTextConfig`): + Optional. A parameter that indicates if text features are treated as text and + explanations are provided for individual units of text. Required for NLP + explainability only. + """ # noqa E501 # pylint: disable=line-too-long + self.number_of_samples = number_of_samples + self.seed = seed + self.shap_baseline_config = shap_baseline_config + self.text_config = text_config + self.use_logit = use_logit + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = { + "ShapBaselineConfig": self.shap_baseline_config._to_request_dict(), + "UseLogit": self.use_logit, + } + if self.number_of_samples is not None: + request_dict["NumberOfSamples"] = self.number_of_samples + + if self.seed is not None: + request_dict["Seed"] = self.seed + + if self.text_config is not None: + request_dict["TextConfig"] = self.text_config._to_request_dict() + + return request_dict + + +class ClarifyInferenceConfig(object): + """The inference configuration parameter for the model container.""" + + def __init__( + self, + feature_headers: Optional[List[str]] = None, + feature_types: Optional[List[str]] = None, + features_attribute: Optional[str] = None, + probability_index: Optional[int] = None, + probability_attribute: Optional[str] = None, + label_index: Optional[int] = None, + label_attribute: Optional[str] = None, + label_headers: Optional[List[str]] = None, + max_payload_in_mb: Optional[int] = 6, + max_record_count: Optional[int] = None, + content_template: Optional[str] = None, + ): + """Initialize a config object for model container. + + Args: + feature_headers (list[str]): Optional. The names of the features. 
If provided, these are + included in the endpoint response payload to help readability of the + ``InvokeEndpoint`` output. + feature_types (list[str]): Optional. A list of data types of the features. Applicable + only to NLP explainability. If provided, ``feature_types`` must have at least one + ``'text'`` string (for example, ``['text']``). If ``feature_types`` is not provided, + the explainer infers the feature types based on the baseline data. The feature + types are included in the endpoint response payload. + features_attribute (str): Optional. Provides the JMESPath expression to extract the + features from a model container input in JSON Lines format. For example, + if ``features_attribute`` is the JMESPath expression ``'myfeatures'``, it extracts a + list of features ``[1,2,3]`` from request data ``'{"myfeatures":[1,2,3]}'``. + probability_index (int): Optional. A zero-based index used to extract a probability + value (score) or list from model container output in CSV format. If this value is + not provided, the entire model container output will be treated as a probability + value (score) or list. See examples `here + `__. + probability_attribute (str): Optional. A JMESPath expression used to extract the + probability (or score) from the model container output if the model container + is in JSON Lines format. See examples `here + `__. + label_index (int): Optional. A zero-based index used to extract a label header or list + of label headers from model container output in CSV format. + label_attribute (str): Optional. A JMESPath expression used to locate the list of label + headers in the model container output. + label_headers (list[str]): Optional. For multiclass classification problems, the label + headers are the names of the classes. Otherwise, the label header is the name of + the predicted label. These are used to help readability for the output of the + ``InvokeEndpoint`` API. + max_payload_in_mb (int): Optional. 
The maximum payload size (MB) allowed of a request + from the explainer to the model container. (Default: 6) + max_record_count (int): Optional. The maximum number of records in a request that the + model container can process when querying the model container for the predictions + of a `synthetic dataset + `__. + A record is a unit of input data that inference can be made on, for example, a + single line in CSV data. If ``max_record_count`` is ``1``, the model container + expects one record per request. A value of 2 or greater means that the model expects + batch requests, which can reduce overhead and speed up the inferencing process. If + this parameter is not provided, the explainer will tune the record count per request + according to the model container's capacity at runtime. + content_template (str): Optional. A template string used to format a JSON record into an + acceptable model container input. For example, a ``ContentTemplate`` string ``'{ + "myfeatures":$features}'`` will format a list of features ``[1,2,3]`` into the + record string ``'{"myfeatures":[1,2,3]}'``. Required only when the model + container input is in JSON Lines format. 
+ """ # noqa E501 # pylint: disable=line-too-long + self.feature_headers = feature_headers + self.feature_types = feature_types + self.features_attribute = features_attribute + self.probability_index = probability_index + self.probability_attribute = probability_attribute + self.label_index = label_index + self.label_attribute = label_attribute + self.label_headers = label_headers + self.max_payload_in_mb = max_payload_in_mb + self.max_record_count = max_record_count + self.content_template = content_template + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = {} + + if self.feature_headers is not None and self.feature_headers: + request_dict["FeatureHeaders"] = self.feature_headers + if self.feature_types is not None: + request_dict["FeatureTypes"] = self.feature_types + if self.features_attribute is not None: + request_dict["FeaturesAttribute"] = self.features_attribute + if self.probability_index is not None: + request_dict["ProbabilityIndex"] = self.probability_index + if self.probability_attribute is not None: + request_dict["ProbabilityAttribute"] = self.probability_attribute + if self.label_index is not None: + request_dict["LabelIndex"] = self.label_index + if self.label_attribute is not None: + request_dict["LabelAttribute"] = self.label_attribute + if self.label_headers is not None: + request_dict["LabelHeaders"] = self.label_headers + if self.max_payload_in_mb is not None: + request_dict["MaxPayloadInMB"] = self.max_payload_in_mb + if self.max_record_count is not None: + request_dict["MaxRecordCount"] = self.max_record_count + if self.content_template is not None: + request_dict["ContentTemplate"] = self.content_template + return request_dict + + +class ClarifyExplainerConfig(object): + """The configuration parameters for the SageMaker Clarify explainer.""" + + def __init__( + self, + shap_config: ClarifyShapConfig, + enable_explanations: Optional[str] = None, + 
inference_config: Optional[ClarifyInferenceConfig] = None, + ): + """Initialize a config object for online explainability with AWS SageMaker Clarify. + + Args: + shap_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyShapConfig`): + The configuration for SHAP analysis. + enable_explanations (str): Optional. A `JMESPath boolean expression + `__ + used to filter which records to explain (Default: None). If not specified, + explanations are activated by default. + inference_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyInferenceConfig`): + Optional. The inference configuration parameter for the model container. (Default: None) + """ # noqa E501 # pylint: disable=line-too-long + self.enable_explanations = enable_explanations + self.shap_config = shap_config + self.inference_config = inference_config + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = { + "ShapConfig": self.shap_config._to_request_dict(), + } + + if self.enable_explanations is not None: + request_dict["EnableExplanations"] = self.enable_explanations + + if self.inference_config is not None: + request_dict["InferenceConfig"] = self.inference_config._to_request_dict() + + return request_dict diff --git a/src/sagemaker/explainer/explainer_config.py b/src/sagemaker/explainer/explainer_config.py new file mode 100644 index 0000000000..6a174b27d5 --- /dev/null +++ b/src/sagemaker/explainer/explainer_config.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""A member of ``CreateEndpointConfig`` that enables explainers.""" + +from __future__ import print_function, absolute_import +from typing import Optional +from sagemaker.explainer.clarify_explainer_config import ClarifyExplainerConfig + + +class ExplainerConfig(object): + """A parameter to activate explainers.""" + + def __init__( + self, + clarify_explainer_config: Optional[ClarifyExplainerConfig] = None, + ): + """Initializes a config object to activate explainer. + + Args: + clarify_explainer_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyExplainerConfig`): + Optional. A config that contains parameters for the SageMaker Clarify explainer. (Default: None) + """ # noqa E501 # pylint: disable=line-too-long + self.clarify_explainer_config = clarify_explainer_config + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + request_dict = {} + + if self.clarify_explainer_config: + request_dict[ + "ClarifyExplainerConfig" + ] = self.clarify_explainer_config._to_request_dict() + + return request_dict diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 437b749a90..ff0f805f7b 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -210,6 +210,7 @@ def deploy( model_data_download_timeout=None, container_startup_health_check_timeout=None, inference_recommendation_id=None, + explainer_config=None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -286,6 +287,8 @@ def deploy( inference_recommendation_id (str): The recommendation id which specifies the recommendation you picked from inference recommendation job results and would like to deploy the model and endpoint with recommended parameters.
+ explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability + configuration for use with Amazon SageMaker Clarify. (default: None) Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -322,6 +325,7 @@ def deploy( model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, + explainer_config=explainer_config, ) def register( diff --git a/src/sagemaker/inference_recommender/inference_recommender_mixin.py b/src/sagemaker/inference_recommender/inference_recommender_mixin.py index af421382b9..90b460a23e 100644 --- a/src/sagemaker/inference_recommender/inference_recommender_mixin.py +++ b/src/sagemaker/inference_recommender/inference_recommender_mixin.py @@ -215,6 +215,7 @@ def _update_params( accelerator_type = kwargs["accelerator_type"] async_inference_config = kwargs["async_inference_config"] serverless_inference_config = kwargs["serverless_inference_config"] + explainer_config = kwargs["explainer_config"] inference_recommendation_id = kwargs["inference_recommendation_id"] inference_recommender_job_results = kwargs["inference_recommender_job_results"] if inference_recommendation_id is not None: @@ -225,6 +226,7 @@ def _update_params( async_inference_config=async_inference_config, serverless_inference_config=serverless_inference_config, inference_recommendation_id=inference_recommendation_id, + explainer_config=explainer_config, ) elif inference_recommender_job_results is not None: inference_recommendation = self._update_params_for_right_size( @@ -233,6 +235,7 @@ def _update_params( accelerator_type, serverless_inference_config, async_inference_config, + explainer_config, ) return inference_recommendation or (instance_type, initial_instance_count) @@ -243,6 +246,7 @@ def _update_params_for_right_size( accelerator_type=None, 
serverless_inference_config=None, async_inference_config=None, + explainer_config=None, ): """Validates that Inference Recommendation parameters can be used in `model.deploy()` @@ -262,6 +266,8 @@ def _update_params_for_right_size( whether serverless_inference_config has been passed into `model.deploy()`. async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): whether async_inference_config has been passed into `model.deploy()`. + explainer_config (sagemaker.explainer.ExplainerConfig): whether explainer_config + has been passed into `model.deploy()`. Returns: (string, int) or None: Top instance_type and associated initial_instance_count @@ -285,6 +291,11 @@ def _update_params_for_right_size( "serverless_inference_config is specified. Overriding right_size() recommendations." ) return None + if explainer_config: + LOGGER.warning( + "explainer_config is specified. Overriding right_size() recommendations." + ) + return None instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"] initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][ @@ -300,6 +311,7 @@ def _update_params_for_recommendation_id( async_inference_config, serverless_inference_config, inference_recommendation_id, + explainer_config, ): """Update parameters with inference recommendation results. @@ -332,6 +344,8 @@ def _update_params_for_recommendation_id( the recommendation you picked from inference recommendation job results and would like to deploy the model and endpoint with recommended parameters. + explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability + configuration for use with Amazon SageMaker Clarify. Default: None. 
Raises: ValueError: If arguments combination check failed in these circumstances: - If only one of instance type or instance count specified or @@ -367,6 +381,8 @@ def _update_params_for_recommendation_id( raise ValueError( "serverless_inference_config is not compatible with inference_recommendation_id." ) + if explainer_config is not None: + raise ValueError("explainer_config is not compatible with inference_recommendation_id.") # Validate recommendation id if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id): diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 0ca017e1f4..38286f5205 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -42,6 +42,7 @@ from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.explainer import ExplainerConfig from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import PredictorBase from sagemaker.serverless import ServerlessInferenceConfig @@ -1080,6 +1081,7 @@ def deploy( model_data_download_timeout=None, container_startup_health_check_timeout=None, inference_recommendation_id=None, + explainer_config=None, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1158,6 +1160,8 @@ def deploy( inference_recommendation_id (str): The recommendation id which specifies the recommendation you picked from inference recommendation job results and would like to deploy the model and endpoint with recommended parameters. + explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability + configuration for use with Amazon SageMaker Clarify. Default: None. 
Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1204,6 +1208,7 @@ def deploy( accelerator_type=accelerator_type, async_inference_config=async_inference_config, serverless_inference_config=serverless_inference_config, + explainer_config=explainer_config, inference_recommendation_id=inference_recommendation_id, inference_recommender_job_results=self.inference_recommender_job_results, ) @@ -1212,6 +1217,10 @@ def deploy( if is_async and not isinstance(async_inference_config, AsyncInferenceConfig): raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object") + is_explainer_enabled = explainer_config is not None + if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig): + raise ValueError("explainer_config needs to be a ExplainerConfig object") + is_serverless = serverless_inference_config is not None if not is_serverless and not (instance_type and initial_instance_count): raise ValueError( @@ -1279,6 +1288,10 @@ def deploy( ) async_inference_config_dict = async_inference_config._to_request_dict() + explainer_config_dict = None + if is_explainer_enabled: + explainer_config_dict = explainer_config._to_request_dict() + self.sagemaker_session.endpoint_from_production_variants( name=self.endpoint_name, production_variants=[production_variant], @@ -1286,6 +1299,7 @@ def deploy( kms_key=kms_key, wait=wait, data_capture_config_dict=data_capture_config_dict, + explainer_config_dict=explainer_config_dict, async_inference_config_dict=async_inference_config_dict, ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 15a5fc9b77..8bfc682748 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3663,6 +3663,7 @@ def create_endpoint_config( volume_size=None, model_data_download_timeout=None, container_startup_health_check_timeout=None, + explainer_config_dict=None, ): """Create an Amazon SageMaker endpoint configuration. 
@@ -3696,6 +3697,8 @@ def create_endpoint_config( inference container to pass health check by SageMaker Hosting. For more information about health check see: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests + explainer_config_dict (dict): Specifies configuration to enable explainers. + Default: None. Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -3751,6 +3754,9 @@ def create_endpoint_config( ) request["DataCaptureConfig"] = inferred_data_capture_config_dict + if explainer_config_dict is not None: + request["ExplainerConfig"] = explainer_config_dict + self.sagemaker_client.create_endpoint_config(**request) return name @@ -3762,6 +3768,7 @@ def create_endpoint_config_from_existing( new_kms_key=None, new_data_capture_config_dict=None, new_production_variants=None, + new_explainer_config_dict=None, ): """Create an Amazon SageMaker endpoint configuration from an existing one. @@ -3789,6 +3796,9 @@ def create_endpoint_config_from_existing( new_production_variants (list[dict]): The configuration for which model(s) to host and the resources to deploy for hosting the model(s). If not specified, the ``ProductionVariants`` of the existing endpoint configuration is used. + new_explainer_config_dict (dict): Specifies configuration to enable explainers. + (default: None). If not specified, the explainer configuration of the existing + endpoint configuration is used. Returns: str: Name of the endpoint point configuration created. 
@@ -3856,6 +3866,13 @@ def create_endpoint_config_from_existing( ) request["AsyncInferenceConfig"] = inferred_async_inference_config_dict + request_explainer_config_dict = ( + new_explainer_config_dict or existing_endpoint_config_desc.get("ExplainerConfig", None) + ) + + if request_explainer_config_dict is not None: + request["ExplainerConfig"] = request_explainer_config_dict + self.sagemaker_client.create_endpoint_config(**request) def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): @@ -4372,6 +4389,7 @@ def endpoint_from_production_variants( wait=True, data_capture_config_dict=None, async_inference_config_dict=None, + explainer_config_dict=None, ): """Create an SageMaker ``Endpoint`` from a list of production variants. @@ -4389,6 +4407,9 @@ def endpoint_from_production_variants( async_inference_config_dict (dict) : specifies configuration related to async endpoint. Use this configuration when trying to create async endpoint and make async inference (default: None) + explainer_config_dict (dict) : Specifies configuration related to explainer. + Use this configuration when trying to use online explainability. + (default: None) Returns: str: The name of the created ``Endpoint``. 
""" @@ -4422,6 +4443,8 @@ def endpoint_from_production_variants( sagemaker_session=self, ) config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict + if explainer_config_dict is not None: + config_options["ExplainerConfig"] = explainer_config_dict LOGGER.info("Creating endpoint-config with name %s", name) self.sagemaker_client.create_endpoint_config(**config_options) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index b0eaf753fb..4b9c23621b 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -324,6 +324,7 @@ def deploy( model_data_download_timeout=None, container_startup_health_check_timeout=None, inference_recommendation_id=None, + explainer_config=None, ): """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``.""" @@ -349,6 +350,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, update_endpoint=update_endpoint, inference_recommendation_id=inference_recommendation_id, + explainer_config=explainer_config, ) def _eia_supported(self): diff --git a/tests/integ/test_explainer.py b/tests/integ/test_explainer.py new file mode 100644 index 0000000000..a1a5c08065 --- /dev/null +++ b/tests/integ/test_explainer.py @@ -0,0 +1,131 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import codecs +import json +import os +import pytest + +import tests.integ +import tests.integ.timeout + +from sagemaker import image_uris +from sagemaker.model import Model +from sagemaker.utils import unique_name_from_base +from sagemaker.explainer.explainer_config import ExplainerConfig +from sagemaker.explainer.clarify_explainer_config import ( + ClarifyExplainerConfig, + ClarifyShapConfig, + ClarifyShapBaselineConfig, +) + +from tests.integ import DATA_DIR + + +ROLE = "SageMakerRole" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.c5.xlarge" +TEST_CSV_DATA = "42,42,42,42,42,42,42" +SHAP_BASELINE = "1,2,3,4,5,6,7" +XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model") + +CLARIFY_SHAP_BASELINE_CONFIG = ClarifyShapBaselineConfig(shap_baseline=SHAP_BASELINE) +CLARIFY_SHAP_CONFIG = ClarifyShapConfig(shap_baseline_config=CLARIFY_SHAP_BASELINE_CONFIG) +CLARIFY_EXPLAINER_CONFIG = ClarifyExplainerConfig( + shap_config=CLARIFY_SHAP_CONFIG, enable_explanations="`true`" +) +EXPLAINER_CONFIG = ExplainerConfig(clarify_explainer_config=CLARIFY_EXPLAINER_CONFIG) + + +@pytest.yield_fixture(scope="module") +def endpoint_name(sagemaker_session): + endpoint_name = unique_name_from_base("clarify-explainer-enabled-endpoint-integ") + xgb_model_data = sagemaker_session.upload_data( + path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + + xgb_image = image_uris.retrieve( + "xgboost", + sagemaker_session.boto_region_name, + version="1", + image_scope="inference", + ) + + with tests.integ.timeout.timeout_and_delete_endpoint_by_name( + endpoint_name=endpoint_name, sagemaker_session=sagemaker_session, hours=2 + ): + xgb_model = Model( + model_data=xgb_model_data, + image_uri=xgb_image, + name=endpoint_name, + role=ROLE, + sagemaker_session=sagemaker_session, + ) + xgb_model.deploy( + INSTANCE_COUNT, + INSTANCE_TYPE, + endpoint_name=endpoint_name, + explainer_config=EXPLAINER_CONFIG, + ) + 
yield endpoint_name + + +def test_describe_explainer_config(sagemaker_session, endpoint_name): + endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(EndpointName=endpoint_name) + + endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config( + EndpointConfigName=endpoint_desc["EndpointConfigName"] + ) + assert endpoint_config_desc["ExplainerConfig"] == EXPLAINER_CONFIG._to_request_dict() + + +def test_invoke_explainer_enabled_endpoint(sagemaker_session, endpoint_name): + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=endpoint_name, + Body=TEST_CSV_DATA, + ContentType="text/csv", + Accept="text/csv", + ) + assert response + # Explainer enabled endpoint content-type should always be "application/json" + assert response.get("ContentType") == "application/json" + response_body_stream = response["Body"] + try: + response_body_json = json.load(codecs.getreader("utf-8")(response_body_stream)) + assert response_body_json + assert response_body_json.get("explanations") + assert response_body_json.get("predictions") + finally: + response_body_stream.close() + + +def test_invoke_endpoint_with_on_demand_explanations(sagemaker_session, endpoint_name): + response = sagemaker_session.sagemaker_runtime_client.invoke_endpoint( + EndpointName=endpoint_name, + EnableExplanations="`false`", + Body=TEST_CSV_DATA, + ContentType="text/csv", + Accept="text/csv", + ) + assert response + response_body_stream = response["Body"] + try: + response_body_json = json.load(codecs.getreader("utf-8")(response_body_stream)) + assert response_body_json + # no records are explained when EnableExplanations="`false`" + assert response_body_json.get("explanations") == {} + assert response_body_json.get("predictions") + finally: + response_body_stream.close() diff --git a/tests/unit/sagemaker/explainer/__init__.py b/tests/unit/sagemaker/explainer/__init__.py new file mode 100644 index 0000000000..a6987bc6a6 --- /dev/null +++ 
b/tests/unit/sagemaker/explainer/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import diff --git a/tests/unit/sagemaker/explainer/test_explainer_config.py b/tests/unit/sagemaker/explainer/test_explainer_config.py new file mode 100644 index 0000000000..a25fca1f02 --- /dev/null +++ b/tests/unit/sagemaker/explainer/test_explainer_config.py @@ -0,0 +1,142 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from sagemaker.explainer import ( + ExplainerConfig, + ClarifyExplainerConfig, + ClarifyShapConfig, + ClarifyInferenceConfig, + ClarifyShapBaselineConfig, + ClarifyTextConfig, +) + + +OPTIONAL_MIME_TYPE = "application/jsonlines" +DEFAULT_MIME_TYPE = "text/csv" +SHAP_BASELINE = '1,2,3,"good product"' +SHAP_BASELINE_PATH = "s3://testbucket/baseline.csv" +OPTIONAL_LABEL_HEADERS = ["Label1", "Label2", "Label3"] +OPTIONAL_FEATURE_HEADERS = ["Feature1", "Feature2", "Feature3", "Feature4"] +OPTIONAL_FEATURE_TYPES = ["numerical", "numerical", "categorical", "text"] +OPTIONAL_CONTENT_TEMPLATE = '{"features":$features}' +OPTIONAL_FEATURE_ATTRIBUTION = "features" +OPTIONAL_ENABLE_EXPLAINABITIONS = "`true`" +OPTIONAL_MAX_RECORD_COUNT = 2 +OPTIONAL_MAX_PAYLOAD_IN_MB = 5 +DEFAULT_MAX_PAYLOAD_IN_MB = 6 +OPTIONAL_PROBABILITY_INDEX = 0 +OPTIONAL_LABEL_INDEX = 1 +OPTIONAL_PROBABILITY_ATTRIBUTE = "probabilities" +OPTIONAL_LABEL_ATTRIBUTE = "labels" +OPTIONAL_NUM_OF_SAMPLES = 100 +OPTIONAL_USE_LOGIT = True +DEFAULT_USE_LOGIT = False +OPTIONAL_SEED = 987 +GRANULARITY = "token" +LANGUAGE = "en" + +BASIC_CLARIFY_EXPLAINER_CONFIG_DICT = { + "ShapConfig": { + "ShapBaselineConfig": { + "MimeType": DEFAULT_MIME_TYPE, + "ShapBaseline": SHAP_BASELINE, + }, + "UseLogit": DEFAULT_USE_LOGIT, + } +} + +CLARIFY_EXPLAINER_CONFIG_DICT_WITH_ALL_OPTIONAL = { + "EnableExplanations": OPTIONAL_ENABLE_EXPLAINABITIONS, + "InferenceConfig": { + "FeaturesAttribute": OPTIONAL_FEATURE_ATTRIBUTION, + "ContentTemplate": OPTIONAL_CONTENT_TEMPLATE, + "MaxRecordCount": OPTIONAL_MAX_RECORD_COUNT, + "MaxPayloadInMB": OPTIONAL_MAX_PAYLOAD_IN_MB, + "ProbabilityIndex": OPTIONAL_PROBABILITY_INDEX, + "LabelIndex": OPTIONAL_LABEL_INDEX, + "ProbabilityAttribute": OPTIONAL_PROBABILITY_ATTRIBUTE, + "LabelAttribute": OPTIONAL_LABEL_ATTRIBUTE, + "LabelHeaders": OPTIONAL_LABEL_HEADERS, + "FeatureHeaders": OPTIONAL_FEATURE_HEADERS, + "FeatureTypes": OPTIONAL_FEATURE_TYPES, + }, + 
"ShapConfig": { + "ShapBaselineConfig": { + "MimeType": OPTIONAL_MIME_TYPE, + "ShapBaseline": SHAP_BASELINE, + "ShapBaselineUri": SHAP_BASELINE_PATH, + }, + "NumberOfSamples": OPTIONAL_NUM_OF_SAMPLES, + "UseLogit": OPTIONAL_USE_LOGIT, + "Seed": OPTIONAL_SEED, + "TextConfig": { + "Granularity": GRANULARITY, + "Language": LANGUAGE, + }, + }, +} + + +def test_init_with_basic_input(): + shap_baseline_config = ClarifyShapBaselineConfig(shap_baseline=SHAP_BASELINE) + shap_config = ClarifyShapConfig(shap_baseline_config=shap_baseline_config) + clarify_explainer_config = ClarifyExplainerConfig( + shap_config=shap_config, + ) + explainer_config = ExplainerConfig(clarify_explainer_config=clarify_explainer_config) + assert ( + explainer_config.clarify_explainer_config._to_request_dict() + == BASIC_CLARIFY_EXPLAINER_CONFIG_DICT + ) + + +def test_init_with_all_optionals(): + shap_baseline_config = ClarifyShapBaselineConfig( + mime_type=OPTIONAL_MIME_TYPE, + # the config won't take shap_baseline and shap_baseline_uri both but we have both + # here for testing purpose + shap_baseline=SHAP_BASELINE, + shap_baseline_uri=SHAP_BASELINE_PATH, + ) + test_config = ClarifyTextConfig(granularity=GRANULARITY, language=LANGUAGE) + shap_config = ClarifyShapConfig( + shap_baseline_config=shap_baseline_config, + number_of_samples=OPTIONAL_NUM_OF_SAMPLES, + seed=OPTIONAL_SEED, + use_logit=OPTIONAL_USE_LOGIT, + text_config=test_config, + ) + inference_config = ClarifyInferenceConfig( + content_template=OPTIONAL_CONTENT_TEMPLATE, + feature_headers=OPTIONAL_FEATURE_HEADERS, + features_attribute=OPTIONAL_FEATURE_ATTRIBUTION, + feature_types=OPTIONAL_FEATURE_TYPES, + label_attribute=OPTIONAL_LABEL_ATTRIBUTE, + label_headers=OPTIONAL_LABEL_HEADERS, + label_index=OPTIONAL_LABEL_INDEX, + max_payload_in_mb=OPTIONAL_MAX_PAYLOAD_IN_MB, + max_record_count=OPTIONAL_MAX_RECORD_COUNT, + probability_attribute=OPTIONAL_PROBABILITY_ATTRIBUTE, + probability_index=OPTIONAL_PROBABILITY_INDEX, + ) + 
clarify_explainer_config = ClarifyExplainerConfig( + shap_config=shap_config, + inference_config=inference_config, + enable_explanations=OPTIONAL_ENABLE_EXPLAINABITIONS, + ) + explainer_config = ExplainerConfig(clarify_explainer_config=clarify_explainer_config) + assert ( + explainer_config.clarify_explainer_config._to_request_dict() + == CLARIFY_EXPLAINER_CONFIG_DICT_WITH_ALL_OPTIONAL + ) diff --git a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py index c936655ae8..8afd9cd2e0 100644 --- a/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py +++ b/tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py @@ -10,6 +10,7 @@ ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.explainer import ExplainerConfig import pytest @@ -566,6 +567,7 @@ def test_deploy_right_size_with_model_package_succeeds( sagemaker_session.endpoint_from_production_variants.assert_called_with( async_inference_config_dict=None, data_capture_config_dict=None, + explainer_config_dict=None, kms_key=None, name="ir-endpoint-test", production_variants=IR_PRODUCTION_VARIANTS, @@ -587,6 +589,7 @@ def test_deploy_right_size_with_both_overrides_succeeds( sagemaker_session.endpoint_from_production_variants.assert_called_with( async_inference_config_dict=None, data_capture_config_dict=None, + explainer_config_dict=None, kms_key=None, name="ir-endpoint-test", production_variants=IR_OVERRIDDEN_PRODUCTION_VARIANTS, @@ -639,6 +642,7 @@ def test_deploy_right_size_serverless_override(sagemaker_session, default_right_ wait=True, data_capture_config_dict=None, async_inference_config_dict=None, + explainer_config_dict=None, ) @@ -660,6 +664,36 @@ def test_deploy_right_size_async_override(sagemaker_session, default_right_sized wait=True, data_capture_config_dict=None, 
async_inference_config_dict={"OutputConfig": {"S3OutputPath": "s3://some-path"}}, + explainer_config_dict=None, + ) + + +@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME)) +def test_deploy_right_size_explainer_config_override(sagemaker_session, default_right_sized_model): + default_right_sized_model.name = MODEL_NAME + mock_clarify_explainer_config = MagicMock() + mock_clarify_explainer_config_dict = { + "EnableExplanations": "`false`", + } + mock_clarify_explainer_config._to_request_dict.return_value = mock_clarify_explainer_config_dict + explainer_config = ExplainerConfig(clarify_explainer_config=mock_clarify_explainer_config) + explainer_config_dict = {"ClarifyExplainerConfig": mock_clarify_explainer_config_dict} + + default_right_sized_model.deploy( + instance_type="ml.c5.2xlarge", + initial_instance_count=1, + explainer_config=explainer_config, + ) + + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=MODEL_NAME, + production_variants=[ANY], + tags=None, + kms_key=None, + wait=True, + data_capture_config_dict=None, + async_inference_config_dict=None, + explainer_config_dict=explainer_config_dict, ) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index cecf1ecc62..ba28e80251 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -21,6 +21,7 @@ from sagemaker.model import Model from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.explainer import ExplainerConfig from tests.unit.sagemaker.inference_recommender.constants import ( DESCRIBE_COMPILATION_JOB_RESPONSE, DESCRIBE_MODEL_PACKAGE_RESPONSE, @@ -64,6 +65,9 @@ "InitialVariantWeight": 1, } +SHAP_BASELINE = '1,2,3,"good product"' +CSV_MIME_TYPE = "text/csv" + @pytest.fixture def sagemaker_session(): @@ -117,6 +121,7 @@ def test_deploy(name_from_base, prepare_container_def, 
production_variant, sagem tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -160,6 +165,7 @@ def test_deploy_accelerator_type( tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -185,6 +191,7 @@ def test_deploy_endpoint_name(sagemaker_session): tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -259,6 +266,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, name_from_base, tags=tags, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -281,6 +289,7 @@ def test_deploy_kms_key(production_variant, name_from_base, sagemaker_session): tags=None, kms_key=key, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -302,6 +311,7 @@ def test_deploy_async(production_variant, name_from_base, sagemaker_session): tags=None, kms_key=None, wait=False, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -331,11 +341,57 @@ def test_deploy_data_capture_config(production_variant, name_from_base, sagemake tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=data_capture_config_dict, async_inference_config_dict=None, ) +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_explainer_config(production_variant, name_from_base, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + mock_clarify_explainer_config = Mock() + mock_clarify_explainer_config_dict = 
{ + "EnableExplanations": "`true`", + } + mock_clarify_explainer_config._to_request_dict.return_value = mock_clarify_explainer_config_dict + explainer_config = ExplainerConfig(clarify_explainer_config=mock_clarify_explainer_config) + explainer_config_dict = {"ClarifyExplainerConfig": mock_clarify_explainer_config_dict} + + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + explainer_config=explainer_config, + ) + + sagemaker_session.endpoint_from_production_variants.assert_called_with( + name=ENDPOINT_NAME, + production_variants=[BASE_PRODUCTION_VARIANT], + tags=None, + kms_key=None, + wait=True, + explainer_config_dict=explainer_config_dict, + data_capture_config_dict=None, + async_inference_config_dict=None, + ) + + +def test_deploy_wrong_explainer_config(sagemaker_session): + model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session, role=ROLE) + + with pytest.raises(ValueError, match="explainer_config needs to be a ExplainerConfig object"): + model.deploy( + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + explainer_config={}, + ) + + @patch("sagemaker.model.Model._create_sagemaker_model", Mock()) @patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) @patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) @@ -363,6 +419,7 @@ def test_deploy_async_inference(production_variant, name_from_base, sagemaker_se tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=async_inference_config_dict, ) @@ -408,6 +465,7 @@ def test_deploy_serverless_inference(production_variant, create_sagemaker_model, tags=None, kms_key=None, wait=True, + explainer_config_dict=None, data_capture_config_dict=None, async_inference_config_dict=None, ) @@ -744,6 +802,7 @@ def test_deploy_customized_volume_size_and_timeout( tags=None, kms_key=None, wait=True, + explainer_config_dict=None, 
data_capture_config_dict=None, async_inference_config_dict=None, ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 4e5c26d64e..3e7cbbd7b0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2913,6 +2913,7 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): wait=True, data_capture_config_dict=None, async_inference_config_dict=None, + explainer_config_dict=None, ) sagemaker_session.create_model.assert_called_with( @@ -2963,6 +2964,7 @@ def test_fit_deploy_tags(name_from_base, sagemaker_session): wait=True, data_capture_config_dict=None, async_inference_config_dict=None, + explainer_config_dict=None, ) sagemaker_session.create_model.assert_called_with( @@ -3311,6 +3313,12 @@ def test_generic_to_deploy_bad_arguments_combination(sagemaker_session): ): e.deploy(serverless_inference_config={}) + with pytest.raises( + ValueError, + match="explainer_config needs to be a ExplainerConfig object", + ): + e.deploy(explainer_config={}) + def test_generic_to_deploy_network_isolation(sagemaker_session): e = Estimator( @@ -3367,6 +3375,7 @@ def test_generic_to_deploy_kms(create_model, sagemaker_session): model_data_download_timeout=None, container_startup_health_check_timeout=None, inference_recommendation_id=None, + explainer_config=None, ) @@ -3553,6 +3562,7 @@ def test_deploy_with_customized_volume_size_timeout(create_model, sagemaker_sess model_data_download_timeout=model_data_download_timeout_sec, container_startup_health_check_timeout=startup_health_check_timeout_sec, inference_recommendation_id=None, + explainer_config=None, ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0242ea6feb..9c3f38572f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -27,6 +27,7 @@ import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.explainer 
import ExplainerConfig from sagemaker.session import ( _tuning_job_status, _transform_job_status, @@ -2671,6 +2672,21 @@ def test_create_endpoint_config_with_tags(sagemaker_session): ) +def test_create_endpoint_config_with_explainer_config(sagemaker_session): + explainer_config = ExplainerConfig + + sagemaker_session.create_endpoint_config( + "endpoint-test", "simple-model", 1, "local", explainer_config_dict=explainer_config + ) + + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="endpoint-test", + ProductionVariants=ANY, + Tags=ANY, + ExplainerConfig=explainer_config, + ) + + def test_endpoint_from_production_variants_with_tags(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"}) @@ -2792,6 +2808,32 @@ def test_endpoint_from_production_variants_with_async_config(sagemaker_session): ) +def test_endpoint_from_production_variants_with_clarify_explainer_config(sagemaker_session): + ims = sagemaker_session + ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"}) + pvs = [ + sagemaker.production_variant("A", "ml.p2.xlarge"), + sagemaker.production_variant("B", "p299.4096xlarge"), + ] + ex = ClientError( + {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + ) + ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) + sagemaker_session.endpoint_from_production_variants( + "some-endpoint", + pvs, + explainer_config_dict=ExplainerConfig, + ) + sagemaker_session.sagemaker_client.create_endpoint.assert_called_with( + EndpointConfigName="some-endpoint", EndpointName="some-endpoint", Tags=[] + ) + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName="some-endpoint", + ProductionVariants=pvs, + ExplainerConfig=ExplainerConfig, + ) + + def test_update_endpoint_succeed(sagemaker_session): 
sagemaker_session.sagemaker_client.describe_endpoint = Mock( return_value={"EndpointStatus": "InService"} @@ -2858,6 +2900,39 @@ def test_create_endpoint_config_from_existing(sagemaker_session): ) +def test_create_endpoint_config_from_existing_with_explainer_config(sagemaker_session): + pvs = [sagemaker.production_variant("A", "ml.m4.xlarge")] + tags = [{"Key": "aws:cloudformation:stackname", "Value": "this-tag-should-be-ignored"}] + existing_endpoint_arn = "arn:aws:sagemaker:us-west-2:123412341234:endpoint-config/foo" + kms_key = "kms" + existing_explainer_config = ExplainerConfig + sagemaker_session.sagemaker_client.describe_endpoint_config.return_value = { + "Tags": tags, + "ProductionVariants": pvs, + "EndpointConfigArn": existing_endpoint_arn, + "KmsKeyId": kms_key, + "ExplainerConfig": existing_explainer_config, + } + sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": tags} + + existing_endpoint_name = "foo" + new_endpoint_name = "new-foo" + sagemaker_session.create_endpoint_config_from_existing( + existing_endpoint_name, new_endpoint_name + ) + + sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_with( + EndpointConfigName=existing_endpoint_name + ) + + sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( + EndpointConfigName=new_endpoint_name, + ProductionVariants=pvs, + KmsKeyId=kms_key, + ExplainerConfig=existing_explainer_config, + ) + + @patch("time.sleep") def test_wait_for_tuning_job(sleep, sagemaker_session): hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "Completed"}