diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 89feb3741b..ab6fed6d80 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -300,6 +300,37 @@ def get_explainability_config(self):
         return None
 
 
+class PDPConfig(ExplainabilityConfig):
+    """Config class for Partial Dependence Plots (PDP).
+
+    If PDP is requested, the Partial Dependence Plots will be included in the report, and the
+    corresponding values will be included in the analysis output.
+    """
+
+    def __init__(self, features=None, grid_resolution=15, top_k_features=10):
+        """Initializes config for PDP.
+
+        Args:
+            features (None or list): List of feature names or indices for which partial dependence
+                plots are computed and plotted. When ShapConfig is provided, this parameter is
+                optional, as Clarify computes the partial dependence plots for the top features
+                based on SHAP attributions. When ShapConfig is not provided, 'features' must be
+                provided.
+            grid_resolution (int): For numerical features, the number of buckets into which the
+                range of values is divided. This determines the granularity of the grid on which
+                the partial dependence plots are computed.
+            top_k_features (int): Number of top SHAP-attributed features for which partial
+                dependence plots are computed.
+        """
+        self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
+        if features is not None:
+            self.pdp_config["features"] = features
+
+    def get_explainability_config(self):
+        """Returns config."""
+        return copy.deepcopy({"pdp": self.pdp_config})
+
+
 class SHAPConfig(ExplainabilityConfig):
     """Config class of SHAP."""
 
@@ -792,8 +823,9 @@ def run_explainability(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                 endpoint to be created.
-            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
-                specific explainability method. Currently, only SHAP is supported.
+            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
+                Config of the specific explainability method or a list of ExplainabilityConfig
+                objects. Currently, SHAP and PDP are the two supported methods.
             model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
                 model output for the predicted scores to be explained. This is not required
                 if the model output is a single score. Alternatively, an instance of
@@ -827,7 +859,30 @@ def run_explainability(
             predictor_config.update(predicted_label_config)
         else:
             _set(model_scores, "label", predictor_config)
-        analysis_config["methods"] = explainability_config.get_explainability_config()
+
+        explainability_methods = {}
+        if isinstance(explainability_config, list):
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if not len(explainability_methods.keys()) == len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and explainability_methods["pdp"].get("features", None) is None
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        else:
+            if (
+                isinstance(explainability_config, PDPConfig)
+                and explainability_config.get_explainability_config()["pdp"].get("features", None)
+                is None
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            explainability_methods = explainability_config.get_explainability_config()
+        analysis_config["methods"] = explainability_methods
         analysis_config["predictor"] = predictor_config
         if job_name is None:
             if self.job_name_prefix:
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 7a4441bf8d..0b2bf1b2ec 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -13,6 +13,8 @@
 
 from __future__ import print_function, absolute_import
 
+import copy
+
 from mock import patch, Mock, MagicMock
 
 import pytest
@@ -23,6 +25,7 @@
     ModelConfig,
     ModelPredictedLabelConfig,
     SHAPConfig,
+    PDPConfig,
 )
 from sagemaker import image_uris, Processor
 
@@ -304,6 +307,14 @@ def test_shap_config_no_parameters():
     assert expected_config == shap_config.get_explainability_config()
 
 
+def test_pdp_config():
+    pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20)
+    expected_config = {
+        "pdp": {"features": ["f1", "f2"], "grid_resolution": 20, "top_k_features": 10}
+    }
+    assert expected_config == pdp_config.get_explainability_config()
+
+
 def test_invalid_shap_config():
     with pytest.raises(ValueError) as error:
         SHAPConfig(
@@ -409,13 +420,18 @@ def shap_config():
                 0.26124998927116394,
                 0.2824999988079071,
                 0.06875000149011612,
-            ]
+            ],
         ],
         num_samples=100,
         agg_method="mean_sq",
     )
 
 
+@pytest.fixture(scope="module")
+def pdp_config():
+    return PDPConfig(features=["F1", "F2"], grid_resolution=20)
+
+
 @patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
 def test_pre_training_bias(
     name_from_base,
@@ -594,21 +610,30 @@ def test_run_on_s3_analysis_config_file(
     )
 
 
-def _run_test_shap(
+def _run_test_explain(
     name_from_base,
     clarify_processor,
     clarify_processor_with_job_name_prefix,
     data_config,
     model_config,
     shap_config,
+    pdp_config,
     model_scores,
     expected_predictor_config,
 ):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
+        explanation_configs = None
+        if shap_config and pdp_config:
+            explanation_configs = [shap_config, pdp_config]
+        elif shap_config:
+            explanation_configs = shap_config
+        elif pdp_config:
+            explanation_configs = pdp_config
+
         clarify_processor.run_explainability(
             data_config,
             model_config,
-            shap_config,
+            explanation_configs,
             model_scores=model_scores,
             wait=True,
             job_name="test",
@@ -623,23 +648,30 @@ def _run_test_shap(
                 "F3",
             ],
             "label": "Label",
-            "methods": {
-                "shap": {
- "baseline": [ - [ - 0.26124998927116394, - 0.2824999988079071, - 0.06875000149011612, - ] - ], - "num_samples": 100, - "agg_method": "mean_sq", - "use_logit": False, - "save_local_shap_values": True, - } - }, "predictor": expected_predictor_config, } + expected_explanation_configs = {} + if shap_config: + expected_explanation_configs["shap"] = { + "baseline": [ + [ + 0.26124998927116394, + 0.2824999988079071, + 0.06875000149011612, + ] + ], + "num_samples": 100, + "agg_method": "mean_sq", + "use_logit": False, + "save_local_shap_values": True, + } + if pdp_config: + expected_explanation_configs["pdp"] = { + "features": ["F1", "F2"], + "grid_resolution": 20, + "top_k_features": 10, + } + expected_analysis_config["methods"] = expected_explanation_configs mock_method.assert_called_with( data_config, expected_analysis_config, @@ -652,7 +684,7 @@ def _run_test_shap( clarify_processor_with_job_name_prefix.run_explainability( data_config, model_config, - shap_config, + explanation_configs, model_scores=model_scores, wait=True, experiment_config={"ExperimentName": "AnExperiment"}, @@ -669,6 +701,34 @@ def _run_test_shap( ) +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) +def test_pdp( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + shap_config, + pdp_config, +): + expected_predictor_config = { + "model_name": "xgboost-model", + "instance_type": "ml.c5.xlarge", + "initial_instance_count": 1, + } + _run_test_explain( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + None, + pdp_config, + None, + expected_predictor_config, + ) + + @patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) def test_shap( name_from_base, @@ -683,7 +743,7 @@ def test_shap( "instance_type": "ml.c5.xlarge", "initial_instance_count": 1, } - _run_test_shap( + _run_test_explain( name_from_base, clarify_processor, clarify_processor_with_job_name_prefix, @@ -691,10 +751,70 @@ def test_shap( model_config, shap_config, None, + None, expected_predictor_config, ) +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) +def test_explainability_with_invalid_config( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, +): + expected_predictor_config = { + "model_name": "xgboost-model", + "instance_type": "ml.c5.xlarge", + "initial_instance_count": 1, + } + with pytest.raises( + AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'" + ): + _run_test_explain( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + None, + None, + None, + expected_predictor_config, + ) + + +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) +def test_explainability_with_multiple_shap_config( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + shap_config, +): + expected_predictor_config = { + "model_name": "xgboost-model", + "instance_type": "ml.c5.xlarge", + "initial_instance_count": 1, + } + with pytest.raises(ValueError, match="Duplicate explainability configs are provided"): + second_shap_config = copy.deepcopy(shap_config) + second_shap_config.shap_config["num_samples"] = 200 + _run_test_explain( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + [shap_config, second_shap_config], + None, + None, + 
+            expected_predictor_config,
+        )
+
+
 @patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
 def test_shap_with_predicted_label(
     name_from_base,
@@ -703,6 +823,7 @@ def test_shap_with_predicted_label(
     data_config,
     model_config,
     shap_config,
+    pdp_config,
 ):
     probability = "pr"
     label_headers = ["success"]
@@ -717,13 +838,14 @@
         "probability": probability,
         "label_headers": label_headers,
     }
-    _run_test_shap(
+    _run_test_explain(
         name_from_base,
         clarify_processor,
         clarify_processor_with_job_name_prefix,
         data_config,
         model_config,
         shap_config,
+        pdp_config,
         model_scores,
         expected_predictor_config,
     )
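Note: the snippet below is a minimal usage sketch of the PDPConfig introduced by this patch, showing how it can be passed to run_explainability either on its own or together with a SHAPConfig. The role ARN, S3 paths, headers, baseline, and model name are placeholder values for illustration only and are not taken from this change.

import sagemaker
from sagemaker.clarify import (
    DataConfig,
    ModelConfig,
    PDPConfig,
    SageMakerClarifyProcessor,
    SHAPConfig,
)

session = sagemaker.Session()

# Processor that runs the Clarify analysis job (placeholder role/instance values).
clarify_processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",
    instance_count=1,
    instance_type="ml.c5.xlarge",
    sagemaker_session=session,
)

# Dataset and model under analysis (placeholder S3 paths, headers, and model name).
data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/clarify/train.csv",
    s3_output_path="s3://my-bucket/clarify/output",
    label="Label",
    headers=["Label", "F1", "F2", "F3"],
    dataset_type="text/csv",
)
model_config = ModelConfig(
    model_name="xgboost-model",
    instance_type="ml.c5.xlarge",
    instance_count=1,
)

# PDP alone requires explicit features; combined with SHAP, Clarify can pick the top features.
shap_config = SHAPConfig(
    baseline=[[0.26, 0.28, 0.07]],
    num_samples=100,
    agg_method="mean_sq",
)
pdp_config = PDPConfig(features=["F1", "F2"], grid_resolution=20)

# With this patch, explainability_config accepts a single config or a list of configs.
clarify_processor.run_explainability(
    data_config,
    model_config,
    [shap_config, pdp_config],
)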