Skip to content

feature: support create Clarify explainer enabled endpoint for Clarify Online Explainability #3727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions doc/api/inference/explainer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Online Explainability
---------------------

This module contains classes related to Amazon SageMaker Clarify Online Explainability.

.. automodule:: sagemaker.explainer.explainer_config
:members:
:undoc-members:
:show-inheritance:

.. automodule:: sagemaker.explainer.clarify_explainer_config
:members:
:undoc-members:
:show-inheritance:


4 changes: 4 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,7 @@ def deploy(
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
**kwargs,
):
"""Deploy the trained model to an Amazon SageMaker endpoint.
Expand Down Expand Up @@ -1458,6 +1459,8 @@ def deploy(
inference_recommendation_id (str): The recommendation id which specifies the
recommendation you picked from inference recommendation job results and
would like to deploy the model and endpoint with recommended parameters.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. (default: None)
**kwargs: Passed to invocation of ``create_model()``.
Implementations may customize ``create_model()`` to accept
``**kwargs`` to customize model creation during deploy.
Expand Down Expand Up @@ -1516,6 +1519,7 @@ def deploy(
data_capture_config=data_capture_config,
serverless_inference_config=serverless_inference_config,
async_inference_config=async_inference_config,
explainer_config=explainer_config,
volume_size=volume_size,
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/explainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Imports the classes in this module to simplify customer imports"""

from __future__ import absolute_import

from sagemaker.explainer.explainer_config import ExplainerConfig # noqa: F401
from sagemaker.explainer.clarify_explainer_config import ( # noqa: F401
ClarifyExplainerConfig,
ClarifyInferenceConfig,
ClarifyShapConfig,
ClarifyShapBaselineConfig,
ClarifyTextConfig,
)
298 changes: 298 additions & 0 deletions src/sagemaker/explainer/clarify_explainer_config.py

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions src/sagemaker/explainer/explainer_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A member of ``CreateEndpointConfig`` that enables explainers."""

from __future__ import print_function, absolute_import
from typing import Optional
from sagemaker.explainer.clarify_explainer_config import ClarifyExplainerConfig


class ExplainerConfig(object):
    """A parameter to activate explainers.

    Wraps the optional SageMaker Clarify explainer configuration; an instance of
    this class is serialized into the ``ExplainerConfig`` member of a
    ``CreateEndpointConfig`` request to enable online explainability.
    """

    def __init__(
        self,
        clarify_explainer_config: Optional[ClarifyExplainerConfig] = None,
    ):
        """Initializes a config object to activate explainer.

        Args:
            clarify_explainer_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyExplainerConfig`):
                Optional. Config object containing parameters for the SageMaker Clarify
                explainer. (Default: None)
        """  # noqa E501 # pylint: disable=line-too-long
        self.clarify_explainer_config = clarify_explainer_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class.

        Returns:
            dict: The ``ExplainerConfig`` request structure. Empty when no Clarify
            explainer config was supplied, so callers can detect "nothing to enable".
        """
        request_dict = {}

        # Only emit the key when a Clarify config is present; the delegated
        # _to_request_dict() builds the nested ClarifyExplainerConfig structure.
        if self.clarify_explainer_config:
            request_dict[
                "ClarifyExplainerConfig"
            ] = self.clarify_explainer_config._to_request_dict()

        return request_dict
4 changes: 4 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def deploy(
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -286,6 +287,8 @@ def deploy(
inference_recommendation_id (str): The recommendation id which specifies the
recommendation you picked from inference recommendation job results and
would like to deploy the model and endpoint with recommended parameters.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. (default: None)
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If no role is specified or
Expand Down Expand Up @@ -322,6 +325,7 @@ def deploy(
model_data_download_timeout=model_data_download_timeout,
container_startup_health_check_timeout=container_startup_health_check_timeout,
inference_recommendation_id=inference_recommendation_id,
explainer_config=explainer_config,
)

def register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _update_params(
accelerator_type = kwargs["accelerator_type"]
async_inference_config = kwargs["async_inference_config"]
serverless_inference_config = kwargs["serverless_inference_config"]
explainer_config = kwargs["explainer_config"]
inference_recommendation_id = kwargs["inference_recommendation_id"]
inference_recommender_job_results = kwargs["inference_recommender_job_results"]
if inference_recommendation_id is not None:
Expand All @@ -225,6 +226,7 @@ def _update_params(
async_inference_config=async_inference_config,
serverless_inference_config=serverless_inference_config,
inference_recommendation_id=inference_recommendation_id,
explainer_config=explainer_config,
)
elif inference_recommender_job_results is not None:
inference_recommendation = self._update_params_for_right_size(
Expand All @@ -233,6 +235,7 @@ def _update_params(
accelerator_type,
serverless_inference_config,
async_inference_config,
explainer_config,
)
return inference_recommendation or (instance_type, initial_instance_count)

Expand All @@ -243,6 +246,7 @@ def _update_params_for_right_size(
accelerator_type=None,
serverless_inference_config=None,
async_inference_config=None,
explainer_config=None,
):
"""Validates that Inference Recommendation parameters can be used in `model.deploy()`

Expand All @@ -262,6 +266,8 @@ def _update_params_for_right_size(
whether serverless_inference_config has been passed into `model.deploy()`.
async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
whether async_inference_config has been passed into `model.deploy()`.
explainer_config (sagemaker.explainer.ExplainerConfig): whether explainer_config
has been passed into `model.deploy()`.

Returns:
(string, int) or None: Top instance_type and associated initial_instance_count
Expand All @@ -285,6 +291,11 @@ def _update_params_for_right_size(
"serverless_inference_config is specified. Overriding right_size() recommendations."
)
return None
if explainer_config:
LOGGER.warning(
"explainer_config is specified. Overriding right_size() recommendations."
)
return None

instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"]
initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
Expand All @@ -300,6 +311,7 @@ def _update_params_for_recommendation_id(
async_inference_config,
serverless_inference_config,
inference_recommendation_id,
explainer_config,
):
"""Update parameters with inference recommendation results.

Expand Down Expand Up @@ -332,6 +344,8 @@ def _update_params_for_recommendation_id(
the recommendation you picked from inference recommendation job
results and would like to deploy the model and endpoint with
recommended parameters.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. Default: None.
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If only one of instance type or instance count specified or
Expand Down Expand Up @@ -367,6 +381,8 @@ def _update_params_for_recommendation_id(
raise ValueError(
"serverless_inference_config is not compatible with inference_recommendation_id."
)
if explainer_config is not None:
raise ValueError("explainer_config is not compatible with inference_recommendation_id.")

# Validate recommendation id
if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sagemaker.model_metrics import ModelMetrics
from sagemaker.deprecations import removed_kwargs
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.explainer import ExplainerConfig
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.predictor import PredictorBase
from sagemaker.serverless import ServerlessInferenceConfig
Expand Down Expand Up @@ -1080,6 +1081,7 @@ def deploy(
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
**kwargs,
):
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
Expand Down Expand Up @@ -1158,6 +1160,8 @@ def deploy(
inference_recommendation_id (str): The recommendation id which specifies the
recommendation you picked from inference recommendation job results and
would like to deploy the model and endpoint with recommended parameters.
explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
configuration for use with Amazon SageMaker Clarify. Default: None.
Raises:
ValueError: If arguments combination check failed in these circumstances:
- If no role is specified or
Expand Down Expand Up @@ -1204,6 +1208,7 @@ def deploy(
accelerator_type=accelerator_type,
async_inference_config=async_inference_config,
serverless_inference_config=serverless_inference_config,
explainer_config=explainer_config,
inference_recommendation_id=inference_recommendation_id,
inference_recommender_job_results=self.inference_recommender_job_results,
)
Expand All @@ -1212,6 +1217,10 @@ def deploy(
if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")

is_explainer_enabled = explainer_config is not None
if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig):
raise ValueError("explainer_config needs to be a ExplainerConfig object")

is_serverless = serverless_inference_config is not None
if not is_serverless and not (instance_type and initial_instance_count):
raise ValueError(
Expand Down Expand Up @@ -1279,13 +1288,18 @@ def deploy(
)
async_inference_config_dict = async_inference_config._to_request_dict()

explainer_config_dict = None
if is_explainer_enabled:
explainer_config_dict = explainer_config._to_request_dict()

self.sagemaker_session.endpoint_from_production_variants(
name=self.endpoint_name,
production_variants=[production_variant],
tags=tags,
kms_key=kms_key,
wait=wait,
data_capture_config_dict=data_capture_config_dict,
explainer_config_dict=explainer_config_dict,
async_inference_config_dict=async_inference_config_dict,
)

Expand Down
23 changes: 23 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3663,6 +3663,7 @@ def create_endpoint_config(
volume_size=None,
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
explainer_config_dict=None,
):
"""Create an Amazon SageMaker endpoint configuration.

Expand Down Expand Up @@ -3696,6 +3697,8 @@ def create_endpoint_config(
inference container to pass health check by SageMaker Hosting. For more information
about health check see:
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
explainer_config_dict (dict): Specifies configuration to enable explainers.
Default: None.

Example:
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
Expand Down Expand Up @@ -3751,6 +3754,9 @@ def create_endpoint_config(
)
request["DataCaptureConfig"] = inferred_data_capture_config_dict

if explainer_config_dict is not None:
request["ExplainerConfig"] = explainer_config_dict

self.sagemaker_client.create_endpoint_config(**request)
return name

Expand All @@ -3762,6 +3768,7 @@ def create_endpoint_config_from_existing(
new_kms_key=None,
new_data_capture_config_dict=None,
new_production_variants=None,
new_explainer_config_dict=None,
):
"""Create an Amazon SageMaker endpoint configuration from an existing one.

Expand Down Expand Up @@ -3789,6 +3796,9 @@ def create_endpoint_config_from_existing(
new_production_variants (list[dict]): The configuration for which model(s) to host and
the resources to deploy for hosting the model(s). If not specified,
the ``ProductionVariants`` of the existing endpoint configuration is used.
new_explainer_config_dict (dict): Specifies configuration to enable explainers.
(default: None). If not specified, the explainer configuration of the existing
endpoint configuration is used.

Returns:
str: Name of the endpoint point configuration created.
Expand Down Expand Up @@ -3856,6 +3866,13 @@ def create_endpoint_config_from_existing(
)
request["AsyncInferenceConfig"] = inferred_async_inference_config_dict

request_explainer_config_dict = (
new_explainer_config_dict or existing_endpoint_config_desc.get("ExplainerConfig", None)
)

if request_explainer_config_dict is not None:
request["ExplainerConfig"] = request_explainer_config_dict

self.sagemaker_client.create_endpoint_config(**request)

def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
Expand Down Expand Up @@ -4372,6 +4389,7 @@ def endpoint_from_production_variants(
wait=True,
data_capture_config_dict=None,
async_inference_config_dict=None,
explainer_config_dict=None,
):
"""Create an SageMaker ``Endpoint`` from a list of production variants.

Expand All @@ -4389,6 +4407,9 @@ def endpoint_from_production_variants(
async_inference_config_dict (dict) : specifies configuration related to async endpoint.
Use this configuration when trying to create async endpoint and make async inference
(default: None)
explainer_config_dict (dict) : Specifies configuration related to explainer.
Use this configuration when trying to use online explainability.
(default: None)
Returns:
str: The name of the created ``Endpoint``.
"""
Expand Down Expand Up @@ -4422,6 +4443,8 @@ def endpoint_from_production_variants(
sagemaker_session=self,
)
config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
if explainer_config_dict is not None:
config_options["ExplainerConfig"] = explainer_config_dict

LOGGER.info("Creating endpoint-config with name %s", name)
self.sagemaker_client.create_endpoint_config(**config_options)
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def deploy(
model_data_download_timeout=None,
container_startup_health_check_timeout=None,
inference_recommendation_id=None,
explainer_config=None,
):
"""Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""

Expand All @@ -349,6 +350,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
update_endpoint=update_endpoint,
inference_recommendation_id=inference_recommendation_id,
explainer_config=explainer_config,
)

def _eia_supported(self):
Expand Down
Loading