
feature: Add support for Partial Dependence Plots (PDP) in SageMaker Clarify #2716


Merged
merged 2 commits into from Oct 29, 2021
Changes from 1 commit
61 changes: 58 additions & 3 deletions src/sagemaker/clarify.py
@@ -300,6 +300,37 @@ def get_explainability_config(self):
return None


class PDPConfig(ExplainabilityConfig):
"""Config class for Partial Dependence Plots (PDP).

If PDP is requested, the Partial Dependence Plots will be included in the report, and the
corresponding values will be included in the analysis output.
"""

def __init__(self, features=None, grid_resolution=15, top_k_features=10):
"""Initializes config for PDP.

Args:
features (None or list): List of feature names or indices for which partial dependence
plots are computed and plotted. When a SHAPConfig is provided, this parameter is
optional, as Clarify computes the partial dependence plots for the top features
based on SHAP attributions. When a SHAPConfig is not provided, 'features' must
be provided.
grid_resolution (int): For numerical features, the number of buckets into which the
range of values is divided. This determines the granularity of the grid on which
the partial dependence plots are computed.
top_k_features (int): The number of top SHAP-attributed features selected to compute
partial dependence plots.
"""
self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
if features is not None:
self.pdp_config["features"] = features

def get_explainability_config(self):
"""Returns config."""
return copy.deepcopy({"pdp": self.pdp_config})
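
For reference, a minimal usage sketch of the new class (the feature names below are hypothetical; the defaults follow the signature above):

from sagemaker.clarify import PDPConfig

# "age" and "income" are hypothetical feature names used only for illustration.
pdp_config = PDPConfig(features=["age", "income"], grid_resolution=20)

# Returns the "pdp" fragment that run_explainability merges into analysis_config["methods"]:
# {"pdp": {"grid_resolution": 20, "top_k_features": 10, "features": ["age", "income"]}}
print(pdp_config.get_explainability_config())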


class SHAPConfig(ExplainabilityConfig):
"""Config class of SHAP."""

@@ -792,8 +823,9 @@ def run_explainability(
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
endpoint to be created.
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
specific explainability method. Currently, only SHAP is supported.
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
Config of the specific explainability method or a list of ExplainabilityConfig
objects. Currently, SHAP and PDP are the two methods supported.
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
model output for the predicted scores to be explained. This is not required if the
model output is a single score. Alternatively, an instance of
@@ -827,7 +859,30 @@ def run_explainability(
predictor_config.update(predicted_label_config)
else:
_set(model_scores, "label", predictor_config)
analysis_config["methods"] = explainability_config.get_explainability_config()

explainability_methods = {}
if isinstance(explainability_config, list):
if len(explainability_config) == 0:
raise ValueError("Please provide at least one explainability config.")
for config in explainability_config:
explain_config = config.get_explainability_config()
explainability_methods.update(explain_config)
if not len(explainability_methods.keys()) == len(explainability_config):
raise ValueError("Duplicate explainability configs are provided")
if (
"shap" not in explainability_methods
and explainability_methods["pdp"].get("features", None) is None
):
raise ValueError("PDP features must be provided when ShapConfig is not provided")
else:
if (
isinstance(explainability_config, PDPConfig)
and explainability_config.get_explainability_config()["pdp"].get("features", None)
is None
):
raise ValueError("PDP features must be provided when ShapConfig is not provided")
explainability_methods = explainability_config.get_explainability_config()
analysis_config["methods"] = explainability_methods
analysis_config["predictor"] = predictor_config
if job_name is None:
if self.job_name_prefix:
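Taken together, a rough sketch of how the updated run_explainability can be invoked from user code (a sketch only: clarify_processor, data_config, and model_config are assumed to be constructed as in the tests below, and the SHAP baseline values are placeholders):

from sagemaker.clarify import PDPConfig, SHAPConfig

shap_config = SHAPConfig(
    baseline=[[0.26, 0.28, 0.07]],  # placeholder baseline row
    num_samples=100,
    agg_method="mean_sq",
)
# 'features' can be omitted here because a SHAPConfig is passed alongside;
# Clarify then selects the top features from the SHAP attributions.
pdp_config = PDPConfig(grid_resolution=20)

# A single config still works; passing two configs of the same kind raises
# "Duplicate explainability configs are provided".
clarify_processor.run_explainability(
    data_config,
    model_config,
    [shap_config, pdp_config],
    model_scores=None,
)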
164 changes: 143 additions & 21 deletions tests/unit/test_clarify.py
@@ -13,6 +13,8 @@

from __future__ import print_function, absolute_import

import copy

from mock import patch, Mock, MagicMock
import pytest

@@ -23,6 +25,7 @@
ModelConfig,
ModelPredictedLabelConfig,
SHAPConfig,
PDPConfig,
)
from sagemaker import image_uris, Processor

@@ -304,6 +307,14 @@ def test_shap_config_no_parameters():
assert expected_config == shap_config.get_explainability_config()


def test_pdp_config():
pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20)
expected_config = {
"pdp": {"features": ["f1", "f2"], "grid_resolution": 20, "top_k_features": 10}
}
assert expected_config == pdp_config.get_explainability_config()


def test_invalid_shap_config():
with pytest.raises(ValueError) as error:
SHAPConfig(
@@ -409,13 +420,18 @@ def shap_config():
0.26124998927116394,
0.2824999988079071,
0.06875000149011612,
]
],
],
num_samples=100,
agg_method="mean_sq",
)


@pytest.fixture(scope="module")
def pdp_config():
return PDPConfig(features=["F1", "F2"], grid_resolution=20)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_pre_training_bias(
name_from_base,
@@ -594,21 +610,30 @@ def test_run_on_s3_analysis_config_file(
)


def _run_test_shap(
def _run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
pdp_config,
model_scores,
expected_predictor_config,
):
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
explanation_configs = None
if shap_config and pdp_config:
explanation_configs = [shap_config, pdp_config]
elif shap_config:
explanation_configs = shap_config
elif pdp_config:
explanation_configs = pdp_config

clarify_processor.run_explainability(
data_config,
model_config,
shap_config,
explanation_configs,
model_scores=model_scores,
wait=True,
job_name="test",
@@ -623,23 +648,30 @@ def _run_test_shap(
"F3",
],
"label": "Label",
"methods": {
"shap": {
"baseline": [
[
0.26124998927116394,
0.2824999988079071,
0.06875000149011612,
]
],
"num_samples": 100,
"agg_method": "mean_sq",
"use_logit": False,
"save_local_shap_values": True,
}
},
"predictor": expected_predictor_config,
}
expected_explanation_configs = {}
if shap_config:
expected_explanation_configs["shap"] = {
"baseline": [
[
0.26124998927116394,
0.2824999988079071,
0.06875000149011612,
]
],
"num_samples": 100,
"agg_method": "mean_sq",
"use_logit": False,
"save_local_shap_values": True,
}
if pdp_config:
expected_explanation_configs["pdp"] = {
"features": ["F1", "F2"],
"grid_resolution": 20,
"top_k_features": 10,
}
expected_analysis_config["methods"] = expected_explanation_configs
mock_method.assert_called_with(
data_config,
expected_analysis_config,
@@ -652,7 +684,7 @@ def _run_test_shap(
clarify_processor_with_job_name_prefix.run_explainability(
data_config,
model_config,
shap_config,
explanation_configs,
model_scores=model_scores,
wait=True,
experiment_config={"ExperimentName": "AnExperiment"},
@@ -669,6 +701,34 @@ def _run_test_shap(
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_pdp(
Contributor

Ideally there should also be test cases for invalid input, for example:

  • explainability_config is None
  • explainability_config is an empty list
  • Multiple shap_config (or pdp_config) in explainability_config

Member

+1

Member Author

Done. Additional tests are added.

name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
pdp_config,
):
expected_predictor_config = {
"model_name": "xgboost-model",
"instance_type": "ml.c5.xlarge",
"initial_instance_count": 1,
}
_run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
None,
pdp_config,
None,
expected_predictor_config,
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_shap(
name_from_base,
@@ -683,18 +743,78 @@ def test_shap(
"instance_type": "ml.c5.xlarge",
"initial_instance_count": 1,
}
_run_test_shap(
_run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
None,
None,
expected_predictor_config,
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_invalid_config(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
):
expected_predictor_config = {
"model_name": "xgboost-model",
"instance_type": "ml.c5.xlarge",
"initial_instance_count": 1,
}
with pytest.raises(
AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'"
):
_run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
None,
None,
None,
expected_predictor_config,
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_explainability_with_multiple_shap_config(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
):
expected_predictor_config = {
"model_name": "xgboost-model",
"instance_type": "ml.c5.xlarge",
"initial_instance_count": 1,
}
with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
second_shap_config = copy.deepcopy(shap_config)
second_shap_config.shap_config["num_samples"] = 200
_run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
[shap_config, second_shap_config],
None,
None,
expected_predictor_config,
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_shap_with_predicted_label(
name_from_base,
@@ -703,6 +823,7 @@ def test_shap_with_predicted_label(
data_config,
model_config,
shap_config,
pdp_config,
):
probability = "pr"
label_headers = ["success"]
@@ -717,13 +838,14 @@ def test_shap_with_predicted_label(
"probability": probability,
"label_headers": label_headers,
}
_run_test_shap(
_run_test_explain(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
pdp_config,
model_scores,
expected_predictor_config,
)