Skip to content

Commit 1db1ff8

Browse files
feature: Add support for Partial Dependence Plots(PDP) in SageMaker Clarify
1 parent 25da5cc commit 1db1ff8

File tree

3 files changed

+231
-34
lines changed

3 files changed

+231
-34
lines changed

src/sagemaker/clarify.py

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import tempfile
2222
import re
23+
2324
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2425
from sagemaker import image_uris, s3, utils
2526

@@ -297,7 +298,30 @@ class ExplainabilityConfig(ABC):
297298
@abstractmethod
298299
def get_explainability_config(self):
299300
"""Returns config."""
300-
return None
301+
302+
303+
class PDPConfig(ExplainabilityConfig):
304+
"""Config class for Partial Dependence Plots (PDP)."""
305+
306+
def __init__(self, features=None, grid_resolution=None):
307+
"""Initializes config for PDP.
308+
309+
Args:
310+
features (None or list): List of feature names or indices for which partial dependence
311+
plots must be computed and plotted.
312+
grid_resolution (int): In case of numerical features, this number represents the
313+
number of buckets that the range of values must be divided into. This decides the
314+
granularity of the grid in which the PDP are plotted.
315+
"""
316+
self.pdp_config = {}
317+
if features is not None:
318+
self.pdp_config["features"] = features
319+
if grid_resolution is not None:
320+
self.pdp_config["grid_resolution"] = grid_resolution
321+
322+
def get_explainability_config(self):
323+
"""Returns config."""
324+
return copy.deepcopy({"pdp": self.pdp_config})
301325

302326

303327
class SHAPConfig(ExplainabilityConfig):
@@ -471,7 +495,10 @@ def _run(
471495
will be unassociated.
472496
* `TrialComponentDisplayName` is used for display in Studio.
473497
"""
474-
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
498+
analysis_config["methods"]["report"] = {
499+
"name": "report",
500+
"title": "Analysis Report",
501+
}
475502
with tempfile.TemporaryDirectory() as tmpdirname:
476503
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
477504
with open(analysis_config_file, "w") as f:
@@ -573,7 +600,15 @@ def run_pre_training_bias(
573600
job_name = utils.name_from_base(self.job_name_prefix)
574601
else:
575602
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
576-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
603+
self._run(
604+
data_config,
605+
analysis_config,
606+
wait,
607+
logs,
608+
job_name,
609+
kms_key,
610+
experiment_config,
611+
)
577612

578613
def run_post_training_bias(
579614
self,
@@ -651,7 +686,15 @@ def run_post_training_bias(
651686
job_name = utils.name_from_base(self.job_name_prefix)
652687
else:
653688
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
654-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
689+
self._run(
690+
data_config,
691+
analysis_config,
692+
wait,
693+
logs,
694+
job_name,
695+
kms_key,
696+
experiment_config,
697+
)
655698

656699
def run_bias(
657700
self,
@@ -746,7 +789,15 @@ def run_bias(
746789
job_name = utils.name_from_base(self.job_name_prefix)
747790
else:
748791
job_name = utils.name_from_base("Clarify-Bias")
749-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
792+
self._run(
793+
data_config,
794+
analysis_config,
795+
wait,
796+
logs,
797+
job_name,
798+
kms_key,
799+
experiment_config,
800+
)
750801

751802
def run_explainability(
752803
self,
@@ -776,8 +827,9 @@ def run_explainability(
776827
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
777828
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
778829
endpoint to be created.
779-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
780-
specific explainability method. Currently, only SHAP is supported.
830+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
831+
Config of the specific explainability method or a list of ExplainabilityConfig
832+
objects. Currently, SHAP and PDP are the two methods supported.
781833
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
782834
model output for the predicted scores to be explained. This is not required if the
783835
model output is a single score. Alternatively, an instance of
@@ -786,7 +838,7 @@ def run_explainability(
786838
logs (bool): Whether to show the logs produced by the job.
787839
Only meaningful when ``wait`` is True (default: True).
788840
job_name (str): Processing job name. When ``job_name`` is not specified, if
789-
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name
841+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name
790842
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
791843
"Clarify-Explainability" as prefix.
792844
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -806,19 +858,44 @@ def run_explainability(
806858
analysis_config = data_config.get_config()
807859
predictor_config = model_config.get_predictor_config()
808860
if isinstance(model_scores, ModelPredictedLabelConfig):
809-
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
861+
(
862+
probability_threshold,
863+
predicted_label_config,
864+
) = model_scores.get_predictor_config()
810865
_set(probability_threshold, "probability_threshold", analysis_config)
811866
predictor_config.update(predicted_label_config)
812867
else:
813868
_set(model_scores, "label", predictor_config)
814-
analysis_config["methods"] = explainability_config.get_explainability_config()
869+
870+
explainability_methods = {}
871+
if isinstance(explainability_config, list):
872+
assert (
873+
len(explainability_config) > 0
874+
), "Please provide at least one explaianbility config."
875+
for config in explainability_config:
876+
explain_config = config.get_explainability_config()
877+
explainability_methods.update(explain_config)
878+
assert len(explainability_methods.keys()) == len(
879+
explainability_config
880+
), "There are duplicate explainability configs"
881+
else:
882+
explainability_methods = explainability_config.get_explainability_config()
883+
analysis_config["methods"] = explainability_methods
815884
analysis_config["predictor"] = predictor_config
816885
if job_name is None:
817886
if self.job_name_prefix:
818887
job_name = utils.name_from_base(self.job_name_prefix)
819888
else:
820889
job_name = utils.name_from_base("Clarify-Explainability")
821-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
890+
self._run(
891+
data_config,
892+
analysis_config,
893+
wait,
894+
logs,
895+
job_name,
896+
kms_key,
897+
experiment_config,
898+
)
822899

823900

824901
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

tests/integ/test_clarify.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16-
1716
import json
1817
import math
1918
import numpy as np
@@ -31,14 +30,14 @@
3130
ModelConfig,
3231
ModelPredictedLabelConfig,
3332
SHAPConfig,
33+
PDPConfig,
3434
)
3535

3636
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor
3737
from sagemaker import utils
3838
from tests import integ
3939
from tests.integ import timeout
4040

41-
4241
CLARIFY_DEFAULT_TIMEOUT_MINUTES = 15
4342

4443

@@ -177,6 +176,11 @@ def shap_config():
177176
)
178177

179178

179+
@pytest.fixture(scope="module")
180+
def pdp_config():
181+
return PDPConfig(features=["F1"], grid_resolution=10)
182+
183+
180184
def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sagemaker_session):
181185
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
182186
clarify_processor.run_pre_training_bias(

0 commit comments

Comments
 (0)