Skip to content

Commit 65bc046

Browse files
feature: Add support for Partial Dependence Plots (PDP) in SageMaker Clarify
1 parent 25da5cc commit 65bc046

File tree

2 files changed

+201
-24
lines changed

2 files changed

+201
-24
lines changed

src/sagemaker/clarify.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,37 @@ def get_explainability_config(self):
300300
return None
301301

302302

303+
class PDPConfig(ExplainabilityConfig):
    """Config class for Partial Dependence Plots (PDP).

    If PDP is requested, the Partial Dependence Plots will be included in the report, and the
    corresponding values will be included in the analysis output.
    """

    def __init__(self, features=None, grid_resolution=15, top_k_features=10):
        """Initializes config for PDP.

        Args:
            features (None or list): List of feature names or indices for which partial
                dependence plots must be computed and plotted. When ShapConfig is provided,
                this parameter is optional, as Clarify will compute the partial dependence
                plots for the top features based on SHAP attributions. When ShapConfig is
                not provided, ``features`` must be provided.
            grid_resolution (int): For numerical features, the number of buckets that the
                range of values is divided into. This decides the granularity of the grid
                on which the PDP are plotted.
            top_k_features (int): Number of top SHAP attributes to select for computing
                partial dependence plots; only relevant when ``features`` is not provided.
        """
        self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
        # "features" is optional at this level; the requirement that it be present when no
        # ShapConfig is supplied is enforced later by run_explainability.
        if features is not None:
            self.pdp_config["features"] = features

    def get_explainability_config(self):
        """Returns a deep copy of the PDP config, keyed by the analysis method name "pdp"."""
        return copy.deepcopy({"pdp": self.pdp_config})
332+
333+
303334
class SHAPConfig(ExplainabilityConfig):
304335
"""Config class of SHAP."""
305336

@@ -776,8 +807,9 @@ def run_explainability(
776807
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
777808
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
778809
endpoint to be created.
779-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
780-
specific explainability method. Currently, only SHAP is supported.
810+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
811+
Config of the specific explainability method or a list of ExplainabilityConfig
812+
objects. Currently, SHAP and PDP are the two methods supported.
781813
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
782814
model output for the predicted scores to be explained. This is not required if the
783815
model output is a single score. Alternatively, an instance of
@@ -811,7 +843,30 @@ def run_explainability(
811843
predictor_config.update(predicted_label_config)
812844
else:
813845
_set(model_scores, "label", predictor_config)
814-
analysis_config["methods"] = explainability_config.get_explainability_config()
846+
847+
explainability_methods = {}
848+
if isinstance(explainability_config, list):
849+
if len(explainability_config) == 0:
850+
raise ValueError("Please provide at least one explainability config.")
851+
for config in explainability_config:
852+
explain_config = config.get_explainability_config()
853+
explainability_methods.update(explain_config)
854+
if not len(explainability_methods.keys()) == len(explainability_config):
855+
raise ValueError("Duplicate explainability configs are provided")
856+
if (
857+
"shap" not in explainability_methods
858+
and explainability_methods["pdp"].get("features", None) is None
859+
):
860+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
861+
else:
862+
if (
863+
isinstance(explainability_config, PDPConfig)
864+
and explainability_config.get_explainability_config()["pdp"].get("features", None)
865+
is None
866+
):
867+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
868+
explainability_methods = explainability_config.get_explainability_config()
869+
analysis_config["methods"] = explainability_methods
815870
analysis_config["predictor"] = predictor_config
816871
if job_name is None:
817872
if self.job_name_prefix:

tests/unit/test_clarify.py

Lines changed: 143 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16+
import copy
17+
1618
from mock import patch, Mock, MagicMock
1719
import pytest
1820

@@ -23,6 +25,7 @@
2325
ModelConfig,
2426
ModelPredictedLabelConfig,
2527
SHAPConfig,
28+
PDPConfig,
2629
)
2730
from sagemaker import image_uris, Processor
2831

@@ -268,6 +271,14 @@ def test_shap_config():
268271
assert expected_config == shap_config.get_explainability_config()
269272

270273

274+
def test_pdp_config():
    """PDPConfig emits a "pdp" section with defaults filled in for omitted fields."""
    config = PDPConfig(features=["f1", "f2"], grid_resolution=20)
    actual = config.get_explainability_config()
    # top_k_features falls back to its default of 10.
    assert actual == {
        "pdp": {"features": ["f1", "f2"], "grid_resolution": 20, "top_k_features": 10}
    }
280+
281+
271282
def test_invalid_shap_config():
272283
with pytest.raises(ValueError) as error:
273284
SHAPConfig(
@@ -367,13 +378,18 @@ def shap_config():
367378
0.26124998927116394,
368379
0.2824999988079071,
369380
0.06875000149011612,
370-
]
381+
],
371382
],
372383
num_samples=100,
373384
agg_method="mean_sq",
374385
)
375386

376387

388+
@pytest.fixture(scope="module")
def pdp_config():
    """Module-scoped PDP config shared by the explainability tests."""
    config = PDPConfig(features=["F1", "F2"], grid_resolution=20)
    return config
391+
392+
377393
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
378394
def test_pre_training_bias(
379395
name_from_base,
@@ -552,21 +568,30 @@ def test_run_on_s3_analysis_config_file(
552568
)
553569

554570

555-
def _run_test_shap(
571+
def _run_test_explain(
556572
name_from_base,
557573
clarify_processor,
558574
clarify_processor_with_job_name_prefix,
559575
data_config,
560576
model_config,
561577
shap_config,
578+
pdp_config,
562579
model_scores,
563580
expected_predictor_config,
564581
):
565582
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
583+
explanation_configs = None
584+
if shap_config and pdp_config:
585+
explanation_configs = [shap_config, pdp_config]
586+
elif shap_config:
587+
explanation_configs = shap_config
588+
elif pdp_config:
589+
explanation_configs = pdp_config
590+
566591
clarify_processor.run_explainability(
567592
data_config,
568593
model_config,
569-
shap_config,
594+
explanation_configs,
570595
model_scores=model_scores,
571596
wait=True,
572597
job_name="test",
@@ -581,23 +606,30 @@ def _run_test_shap(
581606
"F3",
582607
],
583608
"label": "Label",
584-
"methods": {
585-
"shap": {
586-
"baseline": [
587-
[
588-
0.26124998927116394,
589-
0.2824999988079071,
590-
0.06875000149011612,
591-
]
592-
],
593-
"num_samples": 100,
594-
"agg_method": "mean_sq",
595-
"use_logit": False,
596-
"save_local_shap_values": True,
597-
}
598-
},
599609
"predictor": expected_predictor_config,
600610
}
611+
expected_explanation_configs = {}
612+
if shap_config:
613+
expected_explanation_configs["shap"] = {
614+
"baseline": [
615+
[
616+
0.26124998927116394,
617+
0.2824999988079071,
618+
0.06875000149011612,
619+
]
620+
],
621+
"num_samples": 100,
622+
"agg_method": "mean_sq",
623+
"use_logit": False,
624+
"save_local_shap_values": True,
625+
}
626+
if pdp_config:
627+
expected_explanation_configs["pdp"] = {
628+
"features": ["F1", "F2"],
629+
"grid_resolution": 20,
630+
"top_k_features": 10,
631+
}
632+
expected_analysis_config["methods"] = expected_explanation_configs
601633
mock_method.assert_called_with(
602634
data_config,
603635
expected_analysis_config,
@@ -610,7 +642,7 @@ def _run_test_shap(
610642
clarify_processor_with_job_name_prefix.run_explainability(
611643
data_config,
612644
model_config,
613-
shap_config,
645+
explanation_configs,
614646
model_scores=model_scores,
615647
wait=True,
616648
experiment_config={"ExperimentName": "AnExperiment"},
@@ -627,6 +659,34 @@ def _run_test_shap(
627659
)
628660

629661

662+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_pdp(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
    shap_config,
    pdp_config,
):
    """Runs explainability with only a PDP config (no SHAP, no model scores)."""
    predictor_config = dict(
        model_name="xgboost-model",
        instance_type="ml.c5.xlarge",
        initial_instance_count=1,
    )
    _run_test_explain(
        name_from_base,
        clarify_processor,
        clarify_processor_with_job_name_prefix,
        data_config,
        model_config,
        None,  # no SHAP config
        pdp_config,
        None,  # no model scores
        predictor_config,
    )
688+
689+
630690
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
631691
def test_shap(
632692
name_from_base,
@@ -641,18 +701,78 @@ def test_shap(
641701
"instance_type": "ml.c5.xlarge",
642702
"initial_instance_count": 1,
643703
}
644-
_run_test_shap(
704+
_run_test_explain(
645705
name_from_base,
646706
clarify_processor,
647707
clarify_processor_with_job_name_prefix,
648708
data_config,
649709
model_config,
650710
shap_config,
651711
None,
712+
None,
652713
expected_predictor_config,
653714
)
654715

655716

717+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_invalid_config(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
):
    """Passing no explainability config at all should surface an AttributeError."""
    predictor_config = dict(
        model_name="xgboost-model",
        instance_type="ml.c5.xlarge",
        initial_instance_count=1,
    )
    expected_error = "'NoneType' object has no attribute 'get_explainability_config'"
    with pytest.raises(AttributeError, match=expected_error):
        _run_test_explain(
            name_from_base,
            clarify_processor,
            clarify_processor_with_job_name_prefix,
            data_config,
            model_config,
            None,  # no SHAP config
            None,  # no PDP config
            None,  # no model scores
            predictor_config,
        )
744+
745+
746+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_multiple_shap_config(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
    shap_config,
):
    """Two SHAP configs in the list collide on the "shap" key and must be rejected."""
    predictor_config = dict(
        model_name="xgboost-model",
        instance_type="ml.c5.xlarge",
        initial_instance_count=1,
    )
    # A distinct copy — duplicate detection is keyed on method name, not config equality.
    second_shap_config = copy.deepcopy(shap_config)
    second_shap_config.shap_config["num_samples"] = 200
    with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
        _run_test_explain(
            name_from_base,
            clarify_processor,
            clarify_processor_with_job_name_prefix,
            data_config,
            model_config,
            [shap_config, second_shap_config],
            None,  # no PDP config
            None,  # no model scores
            predictor_config,
        )
774+
775+
656776
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
657777
def test_shap_with_predicted_label(
658778
name_from_base,
@@ -661,6 +781,7 @@ def test_shap_with_predicted_label(
661781
data_config,
662782
model_config,
663783
shap_config,
784+
pdp_config,
664785
):
665786
probability = "pr"
666787
label_headers = ["success"]
@@ -675,13 +796,14 @@ def test_shap_with_predicted_label(
675796
"probability": probability,
676797
"label_headers": label_headers,
677798
}
678-
_run_test_shap(
799+
_run_test_explain(
679800
name_from_base,
680801
clarify_processor,
681802
clarify_processor_with_job_name_prefix,
682803
data_config,
683804
model_config,
684805
shap_config,
806+
pdp_config,
685807
model_scores,
686808
expected_predictor_config,
687809
)

0 commit comments

Comments
 (0)