Skip to content

Commit aaece32

Browse files
feature: Add support for Partial Dependence Plots (PDP) in SageMaker Clarify
1 parent db56e57 commit aaece32

File tree

2 files changed

+201
-24
lines changed

2 files changed

+201
-24
lines changed

src/sagemaker/clarify.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,37 @@ def get_explainability_config(self):
300300
return None
301301

302302

303+
class PDPConfig(ExplainabilityConfig):
    """Config class for Partial Dependence Plots (PDP).

    If PDP is requested, the Partial Dependence Plots will be included in the report, and the
    corresponding values will be included in the analysis output.
    """

    def __init__(self, features=None, grid_resolution=15, top_k_features=10):
        """Initializes config for PDP.

        Args:
            features (None or list): List of feature names or indices for which partial
                dependence plots must be computed and plotted. When ShapConfig is provided,
                this parameter is optional, as Clarify will compute the partial dependence
                plots for the top features based on SHAP attributions. When ShapConfig is
                not provided, ``features`` must be provided.
            grid_resolution (int): For numerical features, the number of buckets into which
                the range of values is divided. This decides the granularity of the grid on
                which the PDP are plotted.
            top_k_features (int): Number of top SHAP-attributed features selected to compute
                partial dependence plots. Only used when ``features`` is not provided.
        """
        # "features" is only serialized when explicitly given, so the analysis can fall
        # back to SHAP-based top-feature selection when it is absent.
        self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
        if features is not None:
            self.pdp_config["features"] = features

    def get_explainability_config(self):
        """Returns a deep copy of the PDP config, keyed by the analysis method name "pdp"."""
        return copy.deepcopy({"pdp": self.pdp_config})
332+
333+
303334
class SHAPConfig(ExplainabilityConfig):
304335
"""Config class of SHAP."""
305336

@@ -792,8 +823,9 @@ def run_explainability(
792823
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
793824
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
794825
endpoint to be created.
795-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
796-
specific explainability method. Currently, only SHAP is supported.
826+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
827+
Config of the specific explainability method or a list of ExplainabilityConfig
828+
objects. Currently, SHAP and PDP are the two methods supported.
797829
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
798830
model output for the predicted scores to be explained. This is not required if the
799831
model output is a single score. Alternatively, an instance of
@@ -827,7 +859,30 @@ def run_explainability(
827859
predictor_config.update(predicted_label_config)
828860
else:
829861
_set(model_scores, "label", predictor_config)
830-
analysis_config["methods"] = explainability_config.get_explainability_config()
862+
863+
explainability_methods = {}
864+
if isinstance(explainability_config, list):
865+
if len(explainability_config) == 0:
866+
raise ValueError("Please provide at least one explainability config.")
867+
for config in explainability_config:
868+
explain_config = config.get_explainability_config()
869+
explainability_methods.update(explain_config)
870+
if not len(explainability_methods.keys()) == len(explainability_config):
871+
raise ValueError("Duplicate explainability configs are provided")
872+
if (
873+
"shap" not in explainability_methods
874+
and explainability_methods["pdp"].get("features", None) is None
875+
):
876+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
877+
else:
878+
if (
879+
isinstance(explainability_config, PDPConfig)
880+
and explainability_config.get_explainability_config()["pdp"].get("features", None)
881+
is None
882+
):
883+
raise ValueError("PDP features must be provided when ShapConfig is not provided")
884+
explainability_methods = explainability_config.get_explainability_config()
885+
analysis_config["methods"] = explainability_methods
831886
analysis_config["predictor"] = predictor_config
832887
if job_name is None:
833888
if self.job_name_prefix:

tests/unit/test_clarify.py

Lines changed: 143 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16+
import copy
17+
1618
from mock import patch, Mock, MagicMock
1719
import pytest
1820

@@ -23,6 +25,7 @@
2325
ModelConfig,
2426
ModelPredictedLabelConfig,
2527
SHAPConfig,
28+
PDPConfig,
2629
)
2730
from sagemaker import image_uris, Processor
2831

@@ -304,6 +307,14 @@ def test_shap_config_no_parameters():
304307
assert expected_config == shap_config.get_explainability_config()
305308

306309

310+
def test_pdp_config():
    """PDPConfig serializes features and grid_resolution, defaulting top_k_features to 10."""
    config = PDPConfig(features=["f1", "f2"], grid_resolution=20)
    actual = config.get_explainability_config()
    assert actual == {
        "pdp": {"features": ["f1", "f2"], "grid_resolution": 20, "top_k_features": 10}
    }
316+
317+
307318
def test_invalid_shap_config():
308319
with pytest.raises(ValueError) as error:
309320
SHAPConfig(
@@ -409,13 +420,18 @@ def shap_config():
409420
0.26124998927116394,
410421
0.2824999988079071,
411422
0.06875000149011612,
412-
]
423+
],
413424
],
414425
num_samples=100,
415426
agg_method="mean_sq",
416427
)
417428

418429

430+
@pytest.fixture(scope="module")
def pdp_config():
    """Module-scoped PDP config with explicit features, shared by explainability tests."""
    config = PDPConfig(features=["F1", "F2"], grid_resolution=20)
    return config
433+
434+
419435
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
420436
def test_pre_training_bias(
421437
name_from_base,
@@ -594,21 +610,30 @@ def test_run_on_s3_analysis_config_file(
594610
)
595611

596612

597-
def _run_test_shap(
613+
def _run_test_explain(
598614
name_from_base,
599615
clarify_processor,
600616
clarify_processor_with_job_name_prefix,
601617
data_config,
602618
model_config,
603619
shap_config,
620+
pdp_config,
604621
model_scores,
605622
expected_predictor_config,
606623
):
607624
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
625+
explanation_configs = None
626+
if shap_config and pdp_config:
627+
explanation_configs = [shap_config, pdp_config]
628+
elif shap_config:
629+
explanation_configs = shap_config
630+
elif pdp_config:
631+
explanation_configs = pdp_config
632+
608633
clarify_processor.run_explainability(
609634
data_config,
610635
model_config,
611-
shap_config,
636+
explanation_configs,
612637
model_scores=model_scores,
613638
wait=True,
614639
job_name="test",
@@ -623,23 +648,30 @@ def _run_test_shap(
623648
"F3",
624649
],
625650
"label": "Label",
626-
"methods": {
627-
"shap": {
628-
"baseline": [
629-
[
630-
0.26124998927116394,
631-
0.2824999988079071,
632-
0.06875000149011612,
633-
]
634-
],
635-
"num_samples": 100,
636-
"agg_method": "mean_sq",
637-
"use_logit": False,
638-
"save_local_shap_values": True,
639-
}
640-
},
641651
"predictor": expected_predictor_config,
642652
}
653+
expected_explanation_configs = {}
654+
if shap_config:
655+
expected_explanation_configs["shap"] = {
656+
"baseline": [
657+
[
658+
0.26124998927116394,
659+
0.2824999988079071,
660+
0.06875000149011612,
661+
]
662+
],
663+
"num_samples": 100,
664+
"agg_method": "mean_sq",
665+
"use_logit": False,
666+
"save_local_shap_values": True,
667+
}
668+
if pdp_config:
669+
expected_explanation_configs["pdp"] = {
670+
"features": ["F1", "F2"],
671+
"grid_resolution": 20,
672+
"top_k_features": 10,
673+
}
674+
expected_analysis_config["methods"] = expected_explanation_configs
643675
mock_method.assert_called_with(
644676
data_config,
645677
expected_analysis_config,
@@ -652,7 +684,7 @@ def _run_test_shap(
652684
clarify_processor_with_job_name_prefix.run_explainability(
653685
data_config,
654686
model_config,
655-
shap_config,
687+
explanation_configs,
656688
model_scores=model_scores,
657689
wait=True,
658690
experiment_config={"ExperimentName": "AnExperiment"},
@@ -669,6 +701,34 @@ def _run_test_shap(
669701
)
670702

671703

704+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_pdp(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
    pdp_config,
):
    """Runs explainability with only a PDPConfig and verifies the analysis config.

    Note: the unused ``shap_config`` fixture request was removed — this test exercises
    the PDP-only path, where ``pdp_config`` carries explicit features.
    """
    expected_predictor_config = {
        "model_name": "xgboost-model",
        "instance_type": "ml.c5.xlarge",
        "initial_instance_count": 1,
    }
    # shap_config=None and model_scores=None: PDP-only run with a single-score model.
    _run_test_explain(
        name_from_base,
        clarify_processor,
        clarify_processor_with_job_name_prefix,
        data_config,
        model_config,
        None,
        pdp_config,
        None,
        expected_predictor_config,
    )
730+
731+
672732
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
673733
def test_shap(
674734
name_from_base,
@@ -683,18 +743,78 @@ def test_shap(
683743
"instance_type": "ml.c5.xlarge",
684744
"initial_instance_count": 1,
685745
}
686-
_run_test_shap(
746+
_run_test_explain(
687747
name_from_base,
688748
clarify_processor,
689749
clarify_processor_with_job_name_prefix,
690750
data_config,
691751
model_config,
692752
shap_config,
693753
None,
754+
None,
694755
expected_predictor_config,
695756
)
696757

697758

759+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_invalid_config(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
):
    """Passing no explainability config at all must fail before any job is started."""
    predictor_config = {
        "model_name": "xgboost-model",
        "instance_type": "ml.c5.xlarge",
        "initial_instance_count": 1,
    }
    expected_message = "'NoneType' object has no attribute 'get_explainability_config'"
    with pytest.raises(AttributeError, match=expected_message):
        # Both shap_config and pdp_config are None -> explainability_config is None.
        _run_test_explain(
            name_from_base,
            clarify_processor,
            clarify_processor_with_job_name_prefix,
            data_config,
            model_config,
            None,
            None,
            None,
            predictor_config,
        )
786+
787+
788+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_multiple_shap_config(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
    shap_config,
):
    """Two SHAP configs in one request must be rejected as duplicate method configs."""
    expected_predictor_config = {
        "model_name": "xgboost-model",
        "instance_type": "ml.c5.xlarge",
        "initial_instance_count": 1,
    }
    # Build the duplicate config outside the raises block so that an unexpected
    # setup error (e.g. from deepcopy) is not mistaken for the expected ValueError.
    second_shap_config = copy.deepcopy(shap_config)
    second_shap_config.shap_config["num_samples"] = 200
    with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
        _run_test_explain(
            name_from_base,
            clarify_processor,
            clarify_processor_with_job_name_prefix,
            data_config,
            model_config,
            [shap_config, second_shap_config],
            None,
            None,
            expected_predictor_config,
        )
816+
817+
698818
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
699819
def test_shap_with_predicted_label(
700820
name_from_base,
@@ -703,6 +823,7 @@ def test_shap_with_predicted_label(
703823
data_config,
704824
model_config,
705825
shap_config,
826+
pdp_config,
706827
):
707828
probability = "pr"
708829
label_headers = ["success"]
@@ -717,13 +838,14 @@ def test_shap_with_predicted_label(
717838
"probability": probability,
718839
"label_headers": label_headers,
719840
}
720-
_run_test_shap(
841+
_run_test_explain(
721842
name_from_base,
722843
clarify_processor,
723844
clarify_processor_with_job_name_prefix,
724845
data_config,
725846
model_config,
726847
shap_config,
848+
pdp_config,
727849
model_scores,
728850
expected_predictor_config,
729851
)

0 commit comments

Comments
 (0)