
Commit 1167339

xiaoyi-cheng authored and claytonparnell committed
feature: support creating a Clarify explainer enabled endpoint for Clarify Online Explainability (#3727)
Co-authored-by: Clayton Parnell <[email protected]>
1 parent 53fea69 commit 1167339

17 files changed, +909 -0 lines changed

doc/api/inference/explainer.rst

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+Online Explainability
+---------------------
+
+This module contains classes related to Amazon SageMaker Clarify Online Explainability.
+
+.. automodule:: sagemaker.explainer.explainer_config
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: sagemaker.explainer.clarify_explainer_config
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+
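The two modules documented here define the objects a caller builds to turn on online explainability for an endpoint. A minimal sketch of how they compose, assuming the constructor arguments mirror the corresponding CreateEndpointConfig fields (the full clarify_explainer_config.py diff is not rendered in this view, so the exact signatures below are assumptions):

from sagemaker.explainer import (
    ClarifyExplainerConfig,
    ClarifyShapBaselineConfig,
    ClarifyShapConfig,
    ExplainerConfig,
)

# A single SHAP baseline record in the model's input format (CSV assumed here).
shap_baseline_config = ClarifyShapBaselineConfig(
    mime_type="text/csv",
    shap_baseline="1,2,3,4",
)
shap_config = ClarifyShapConfig(shap_baseline_config=shap_baseline_config)

# ExplainerConfig wraps the Clarify settings and is the object passed to deploy().
explainer_config = ExplainerConfig(
    clarify_explainer_config=ClarifyExplainerConfig(shap_config=shap_config)
)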

src/sagemaker/estimator.py

Lines changed: 4 additions & 0 deletions
@@ -1378,6 +1378,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -1458,6 +1459,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. (default: None)
             **kwargs: Passed to invocation of ``create_model()``.
                 Implementations may customize ``create_model()`` to accept
                 ``**kwargs`` to customize model creation during deploy.
@@ -1516,6 +1519,7 @@ def deploy(
             data_capture_config=data_capture_config,
             serverless_inference_config=serverless_inference_config,
             async_inference_config=async_inference_config,
+            explainer_config=explainer_config,
             volume_size=volume_size,
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
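With this parameter, a fitted estimator can opt its endpoint into online explainability directly from deploy(). A hedged usage sketch (the estimator, instance settings, and explainer_config object are placeholders; explainer_config is built as in the earlier sketch):

# Assumes `estimator` is an already-fitted SageMaker Estimator and
# `explainer_config` is an ExplainerConfig instance (see earlier sketch).
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
    explainer_config=explainer_config,  # new argument; defaults to None
)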

src/sagemaker/explainer/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Imports the classes in this module to simplify customer imports"""
+
+from __future__ import absolute_import
+
+from sagemaker.explainer.explainer_config import ExplainerConfig  # noqa: F401
+from sagemaker.explainer.clarify_explainer_config import (  # noqa: F401
+    ClarifyExplainerConfig,
+    ClarifyInferenceConfig,
+    ClarifyShapConfig,
+    ClarifyShapBaselineConfig,
+    ClarifyTextConfig,
+)
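Because the package __init__ re-exports these classes, callers can import them from sagemaker.explainer directly instead of spelling out the submodules:

# Equivalent to importing from sagemaker.explainer.explainer_config and
# sagemaker.explainer.clarify_explainer_config individually.
from sagemaker.explainer import ExplainerConfig, ClarifyExplainerConfig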

src/sagemaker/explainer/clarify_explainer_config.py

Lines changed: 298 additions & 0 deletions
Large diffs are not rendered by default.
src/sagemaker/explainer/explainer_config.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""A member of ``CreateEndpointConfig`` that enables explainers."""
+
+from __future__ import print_function, absolute_import
+from typing import Optional
+from sagemaker.explainer.clarify_explainer_config import ClarifyExplainerConfig
+
+
+class ExplainerConfig(object):
+    """A parameter to activate explainers."""
+
+    def __init__(
+        self,
+        clarify_explainer_config: Optional[ClarifyExplainerConfig] = None,
+    ):
+        """Initializes a config object to activate an explainer.
+
+        Args:
+            clarify_explainer_config (:class:`~sagemaker.explainer.clarify_explainer_config.ClarifyExplainerConfig`):
+                Optional. A config object that contains parameters for the SageMaker Clarify explainer. (Default: None)
+        """  # noqa E501  # pylint: disable=line-too-long
+        self.clarify_explainer_config = clarify_explainer_config
+
+    def _to_request_dict(self):
+        """Generates a request dictionary using the parameters provided to the class."""
+        request_dict = {}
+
+        if self.clarify_explainer_config:
+            request_dict[
+                "ClarifyExplainerConfig"
+            ] = self.clarify_explainer_config._to_request_dict()
+
+        return request_dict
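From _to_request_dict() above, the wrapper serializes to the ExplainerConfig member of a CreateEndpointConfig request: an empty dict when no Clarify config is attached, otherwise a dict keyed by "ClarifyExplainerConfig". A small sketch grounded in the code shown here (the inner Clarify payload comes from ClarifyExplainerConfig._to_request_dict(), whose shape is not rendered in this commit view):

from sagemaker.explainer import ExplainerConfig

# No Clarify config attached: the request dictionary is empty.
assert ExplainerConfig()._to_request_dict() == {}

# With a Clarify config attached, the result nests its serialized form:
# {"ClarifyExplainerConfig": {...}}; the inner shape is produced by
# ClarifyExplainerConfig._to_request_dict() and is not shown in this diff.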

src/sagemaker/huggingface/model.py

Lines changed: 4 additions & 0 deletions
@@ -210,6 +210,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -286,6 +287,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. (default: None)
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -322,6 +325,7 @@ def deploy(
             model_data_download_timeout=model_data_download_timeout,
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             inference_recommendation_id=inference_recommendation_id,
+            explainer_config=explainer_config,
         )

     def register(

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 16 additions & 0 deletions
@@ -215,6 +215,7 @@ def _update_params(
         accelerator_type = kwargs["accelerator_type"]
         async_inference_config = kwargs["async_inference_config"]
         serverless_inference_config = kwargs["serverless_inference_config"]
+        explainer_config = kwargs["explainer_config"]
         inference_recommendation_id = kwargs["inference_recommendation_id"]
         inference_recommender_job_results = kwargs["inference_recommender_job_results"]
         if inference_recommendation_id is not None:
@@ -225,6 +226,7 @@ def _update_params(
                 async_inference_config=async_inference_config,
                 serverless_inference_config=serverless_inference_config,
                 inference_recommendation_id=inference_recommendation_id,
+                explainer_config=explainer_config,
             )
         elif inference_recommender_job_results is not None:
             inference_recommendation = self._update_params_for_right_size(
@@ -233,6 +235,7 @@ def _update_params(
                 accelerator_type,
                 serverless_inference_config,
                 async_inference_config,
+                explainer_config,
             )
         return inference_recommendation or (instance_type, initial_instance_count)

@@ -243,6 +246,7 @@ def _update_params_for_right_size(
         accelerator_type=None,
         serverless_inference_config=None,
         async_inference_config=None,
+        explainer_config=None,
     ):
         """Validates that Inference Recommendation parameters can be used in `model.deploy()`
@@ -262,6 +266,8 @@ def _update_params_for_right_size(
                 whether serverless_inference_config has been passed into `model.deploy()`.
             async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
                 whether async_inference_config has been passed into `model.deploy()`.
+            explainer_config (sagemaker.explainer.ExplainerConfig): whether explainer_config
+                has been passed into `model.deploy()`.

         Returns:
             (string, int) or None: Top instance_type and associated initial_instance_count
@@ -285,6 +291,11 @@ def _update_params_for_right_size(
                 "serverless_inference_config is specified. Overriding right_size() recommendations."
             )
             return None
+        if explainer_config:
+            LOGGER.warning(
+                "explainer_config is specified. Overriding right_size() recommendations."
+            )
+            return None

         instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"]
         initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
@@ -300,6 +311,7 @@ def _update_params_for_recommendation_id(
         async_inference_config,
         serverless_inference_config,
         inference_recommendation_id,
+        explainer_config,
     ):
         """Update parameters with inference recommendation results.
@@ -332,6 +344,8 @@ def _update_params_for_recommendation_id(
                 the recommendation you picked from inference recommendation job
                 results and would like to deploy the model and endpoint with
                 recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. Default: None.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If only one of instance type or instance count specified or
@@ -367,6 +381,8 @@ def _update_params_for_recommendation_id(
             raise ValueError(
                 "serverless_inference_config is not compatible with inference_recommendation_id."
             )
+        if explainer_config is not None:
+            raise ValueError("explainer_config is not compatible with inference_recommendation_id.")

         # Validate recommendation id
         if not re.match(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$", inference_recommendation_id):
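These checks make explainer_config effectively mutually exclusive with inference-recommender-driven deployment: it overrides right_size() recommendations with a warning and raises when combined with inference_recommendation_id. A hedged sketch of the error path (the model object and the recommendation id are placeholders):

# Hypothetical: `model` is a sagemaker.model.Model and the id below is made up.
try:
    model.deploy(
        inference_recommendation_id="my-recommendation-job/abcd1234",
        explainer_config=explainer_config,
    )
except ValueError as err:
    # "explainer_config is not compatible with inference_recommendation_id."
    print(err)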

src/sagemaker/model.py

Lines changed: 14 additions & 0 deletions
@@ -42,6 +42,7 @@
 from sagemaker.model_metrics import ModelMetrics
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.drift_check_baselines import DriftCheckBaselines
+from sagemaker.explainer import ExplainerConfig
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.predictor import PredictorBase
 from sagemaker.serverless import ServerlessInferenceConfig
@@ -1080,6 +1081,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
         **kwargs,
     ):
         """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
@@ -1158,6 +1160,8 @@ def deploy(
             inference_recommendation_id (str): The recommendation id which specifies the
                 recommendation you picked from inference recommendation job results and
                 would like to deploy the model and endpoint with recommended parameters.
+            explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability
+                configuration for use with Amazon SageMaker Clarify. Default: None.
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
@@ -1204,6 +1208,7 @@ def deploy(
             accelerator_type=accelerator_type,
             async_inference_config=async_inference_config,
             serverless_inference_config=serverless_inference_config,
+            explainer_config=explainer_config,
             inference_recommendation_id=inference_recommendation_id,
             inference_recommender_job_results=self.inference_recommender_job_results,
         )
@@ -1212,6 +1217,10 @@ def deploy(
         if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
             raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")

+        is_explainer_enabled = explainer_config is not None
+        if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig):
+            raise ValueError("explainer_config needs to be an ExplainerConfig object")
+
         is_serverless = serverless_inference_config is not None
         if not is_serverless and not (instance_type and initial_instance_count):
             raise ValueError(
@@ -1282,13 +1291,18 @@ def deploy(
             )
             async_inference_config_dict = async_inference_config._to_request_dict()

+        explainer_config_dict = None
+        if is_explainer_enabled:
+            explainer_config_dict = explainer_config._to_request_dict()
+
         self.sagemaker_session.endpoint_from_production_variants(
             name=self.endpoint_name,
             production_variants=[production_variant],
             tags=tags,
             kms_key=kms_key,
             wait=wait,
             data_capture_config_dict=data_capture_config_dict,
+            explainer_config_dict=explainer_config_dict,
             async_inference_config_dict=async_inference_config_dict,
         )
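Model.deploy() type-checks the new argument and serializes it with _to_request_dict() before handing it to endpoint_from_production_variants(). A hedged usage sketch (model construction, role, and instance settings are placeholders; explainer_config is built as in the earlier sketch):

# Assumes `model` is a constructed sagemaker.model.Model and
# `explainer_config` is an ExplainerConfig instance.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
    explainer_config=explainer_config,
)

# Passing anything other than an ExplainerConfig raises:
# ValueError: explainer_config needs to be an ExplainerConfig object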

src/sagemaker/session.py

Lines changed: 23 additions & 0 deletions
@@ -3663,6 +3663,7 @@ def create_endpoint_config(
         volume_size=None,
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
+        explainer_config_dict=None,
     ):
         """Create an Amazon SageMaker endpoint configuration.
@@ -3696,6 +3697,8 @@ def create_endpoint_config(
                 inference container to pass health check by SageMaker Hosting. For more information
                 about health check see:
                 https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
+            explainer_config_dict (dict): Specifies configuration to enable explainers.
+                Default: None.

         Example:
             >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -3751,6 +3754,9 @@ def create_endpoint_config(
             )
             request["DataCaptureConfig"] = inferred_data_capture_config_dict

+        if explainer_config_dict is not None:
+            request["ExplainerConfig"] = explainer_config_dict
+
         self.sagemaker_client.create_endpoint_config(**request)
         return name
@@ -3762,6 +3768,7 @@ def create_endpoint_config_from_existing(
         new_kms_key=None,
         new_data_capture_config_dict=None,
         new_production_variants=None,
+        new_explainer_config_dict=None,
     ):
         """Create an Amazon SageMaker endpoint configuration from an existing one.
@@ -3789,6 +3796,9 @@ def create_endpoint_config_from_existing(
             new_production_variants (list[dict]): The configuration for which model(s) to host and
                 the resources to deploy for hosting the model(s). If not specified,
                 the ``ProductionVariants`` of the existing endpoint configuration is used.
+            new_explainer_config_dict (dict): Specifies configuration to enable explainers.
+                (default: None). If not specified, the explainer configuration of the existing
+                endpoint configuration is used.

         Returns:
             str: Name of the endpoint point configuration created.
@@ -3856,6 +3866,13 @@ def create_endpoint_config_from_existing(
             )
             request["AsyncInferenceConfig"] = inferred_async_inference_config_dict

+        request_explainer_config_dict = (
+            new_explainer_config_dict or existing_endpoint_config_desc.get("ExplainerConfig", None)
+        )
+
+        if request_explainer_config_dict is not None:
+            request["ExplainerConfig"] = request_explainer_config_dict
+
         self.sagemaker_client.create_endpoint_config(**request)

     def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
@@ -4372,6 +4389,7 @@ def endpoint_from_production_variants(
         wait=True,
         data_capture_config_dict=None,
         async_inference_config_dict=None,
+        explainer_config_dict=None,
     ):
         """Create an SageMaker ``Endpoint`` from a list of production variants.
@@ -4389,6 +4407,9 @@ def endpoint_from_production_variants(
             async_inference_config_dict (dict) : specifies configuration related to async endpoint.
                 Use this configuration when trying to create async endpoint and make async inference
                 (default: None)
+            explainer_config_dict (dict) : Specifies configuration related to explainer.
+                Use this configuration when trying to use online explainability.
+                (default: None)
         Returns:
             str: The name of the created ``Endpoint``.
         """
@@ -4422,6 +4443,8 @@ def endpoint_from_production_variants(
                 sagemaker_session=self,
             )
             config_options["AsyncInferenceConfig"] = inferred_async_inference_config_dict
+        if explainer_config_dict is not None:
+            config_options["ExplainerConfig"] = explainer_config_dict

         LOGGER.info("Creating endpoint-config with name %s", name)
         self.sagemaker_client.create_endpoint_config(**config_options)
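At the session layer the explainer settings travel as a plain dictionary (explainer_config_dict) placed under the ExplainerConfig key of the CreateEndpointConfig request. A hedged sketch of calling the low-level helper directly (the nested keys follow the CreateEndpointConfig API shape and should be treated as assumptions; `session` and `production_variant` are placeholders):

# Hypothetical low-level usage; normally Model.deploy() builds this dict
# from an ExplainerConfig via _to_request_dict().
explainer_config_dict = {
    "ClarifyExplainerConfig": {
        "ShapConfig": {
            "ShapBaselineConfig": {
                "MimeType": "text/csv",
                "ShapBaseline": "1,2,3,4",
            }
        }
    }
}

session.endpoint_from_production_variants(
    name="my-explainable-endpoint",
    production_variants=[production_variant],
    explainer_config_dict=explainer_config_dict,
)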

src/sagemaker/tensorflow/model.py

Lines changed: 2 additions & 0 deletions
@@ -324,6 +324,7 @@ def deploy(
         model_data_download_timeout=None,
         container_startup_health_check_timeout=None,
         inference_recommendation_id=None,
+        explainer_config=None,
     ):
         """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``."""

@@ -349,6 +350,7 @@ def deploy(
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             update_endpoint=update_endpoint,
             inference_recommendation_id=inference_recommendation_id,
+            explainer_config=explainer_config,
         )

     def _eia_supported(self):
