Skip to content

Commit 4bf250b

Browse files
feature: Add support for Partial Dependence Plots (PDP) in SageMaker Clarify
1 parent 554ba08 commit 4bf250b

File tree

3 files changed

+196
-34
lines changed

3 files changed

+196
-34
lines changed

src/sagemaker/clarify.py

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
import tempfile
2222
import re
23+
2324
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2425
from sagemaker import image_uris, s3, utils
2526

@@ -292,7 +293,30 @@ class ExplainabilityConfig(ABC):
292293
@abstractmethod
293294
def get_explainability_config(self):
294295
"""Returns config."""
295-
return None
296+
297+
298+
class PDPConfig(ExplainabilityConfig):
    """Configuration for the Partial Dependence Plots (PDP) explainability method."""

    def __init__(self, features=None, grid_resolution=None):
        """Builds the PDP configuration from the options that were supplied.

        Args:
            features (None or list): Feature names or indices for which partial
                dependence plots are computed and plotted. Left out of the
                config when None.
            grid_resolution (int): For numerical features, the number of buckets
                that the range of values is divided into. This sets the
                granularity of the grid on which the plots are drawn. Left out
                of the config when None.
        """
        # Only options that were explicitly provided end up in the config.
        supplied = (("features", features), ("grid_resolution", grid_resolution))
        self.pdp_config = {key: value for key, value in supplied if value is not None}

    def get_explainability_config(self):
        """Returns a deep copy of the PDP settings nested under the "pdp" key."""
        return {"pdp": copy.deepcopy(self.pdp_config)}
296320

297321

298322
class SHAPConfig(ExplainabilityConfig):
@@ -466,7 +490,10 @@ def _run(
466490
will be unassociated.
467491
* `TrialComponentDisplayName` is used for display in Studio.
468492
"""
469-
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
493+
analysis_config["methods"]["report"] = {
494+
"name": "report",
495+
"title": "Analysis Report",
496+
}
470497
with tempfile.TemporaryDirectory() as tmpdirname:
471498
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
472499
with open(analysis_config_file, "w") as f:
@@ -568,7 +595,15 @@ def run_pre_training_bias(
568595
job_name = utils.name_from_base(self.job_name_prefix)
569596
else:
570597
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
571-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
598+
self._run(
599+
data_config,
600+
analysis_config,
601+
wait,
602+
logs,
603+
job_name,
604+
kms_key,
605+
experiment_config,
606+
)
572607

573608
def run_post_training_bias(
574609
self,
@@ -646,7 +681,15 @@ def run_post_training_bias(
646681
job_name = utils.name_from_base(self.job_name_prefix)
647682
else:
648683
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
649-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
684+
self._run(
685+
data_config,
686+
analysis_config,
687+
wait,
688+
logs,
689+
job_name,
690+
kms_key,
691+
experiment_config,
692+
)
650693

651694
def run_bias(
652695
self,
@@ -741,7 +784,15 @@ def run_bias(
741784
job_name = utils.name_from_base(self.job_name_prefix)
742785
else:
743786
job_name = utils.name_from_base("Clarify-Bias")
744-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
787+
self._run(
788+
data_config,
789+
analysis_config,
790+
wait,
791+
logs,
792+
job_name,
793+
kms_key,
794+
experiment_config,
795+
)
745796

746797
def run_explainability(
747798
self,
@@ -771,8 +822,9 @@ def run_explainability(
771822
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
772823
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
773824
endpoint to be created.
774-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
775-
specific explainability method. Currently, only SHAP is supported.
825+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
826+
Config of the specific explainability method or a list of ExplainabilityConfig
827+
objects. Currently, SHAP and PDP are the two methods supported.
776828
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
777829
model output for the predicted scores to be explained. This is not required if the
778830
model output is a single score. Alternatively, an instance of
@@ -781,7 +833,7 @@ def run_explainability(
781833
logs (bool): Whether to show the logs produced by the job.
782834
Only meaningful when ``wait`` is True (default: True).
783835
job_name (str): Processing job name. When ``job_name`` is not specified, if
784-
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name
836+
`job_name_prefix` in :class:`SageMakerClarifyProcessor` specified, the job name
785837
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
786838
"Clarify-Explainability" as prefix.
787839
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -801,19 +853,41 @@ def run_explainability(
801853
analysis_config = data_config.get_config()
802854
predictor_config = model_config.get_predictor_config()
803855
if isinstance(model_scores, ModelPredictedLabelConfig):
804-
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
856+
(
857+
probability_threshold,
858+
predicted_label_config,
859+
) = model_scores.get_predictor_config()
805860
_set(probability_threshold, "probability_threshold", analysis_config)
806861
predictor_config.update(predicted_label_config)
807862
else:
808863
_set(model_scores, "label", predictor_config)
809-
analysis_config["methods"] = explainability_config.get_explainability_config()
864+
865+
explainability_methods = {}
866+
if isinstance(explainability_config, list):
867+
assert (
868+
len(explainability_config) > 0
869+
), "Invalid Input: Expected list of ExplainabilityConfig but found empty list instead"
870+
for config in explainability_config:
871+
explain_config = config.get_explainability_config()
872+
explainability_methods.update(explain_config)
873+
else:
874+
explainability_methods = explainability_config.get_explainability_config()
875+
analysis_config["methods"] = explainability_methods
810876
analysis_config["predictor"] = predictor_config
811877
if job_name is None:
812878
if self.job_name_prefix:
813879
job_name = utils.name_from_base(self.job_name_prefix)
814880
else:
815881
job_name = utils.name_from_base("Clarify-Explainability")
816-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
882+
self._run(
883+
data_config,
884+
analysis_config,
885+
wait,
886+
logs,
887+
job_name,
888+
kms_key,
889+
experiment_config,
890+
)
817891

818892

819893
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

tests/integ/test_clarify.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16-
1716
import json
1817
import math
1918
import numpy as np
@@ -31,14 +30,14 @@
3130
ModelConfig,
3231
ModelPredictedLabelConfig,
3332
SHAPConfig,
33+
PDPConfig,
3434
)
3535

3636
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor
3737
from sagemaker import utils
3838
from tests import integ
3939
from tests.integ import timeout
4040

41-
4241
CLARIFY_DEFAULT_TIMEOUT_MINUTES = 15
4342

4443

@@ -177,6 +176,11 @@ def shap_config():
177176
)
178177

179178

179+
@pytest.fixture(scope="module")
def pdp_config():
    """Module-scoped PDP config for feature "F1" with a grid resolution of 10."""
    pdp_options = {"features": ["F1"], "grid_resolution": 10}
    return PDPConfig(**pdp_options)
182+
183+
180184
def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sagemaker_session):
181185
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
182186
clarify_processor.run_pre_training_bias(

0 commit comments

Comments
 (0)