From 9f04b85b8ce3b24672707b31ebc6aefb49408504 Mon Sep 17 00:00:00 2001
From: Mila Hardt
Date: Mon, 14 Jun 2021 12:53:35 -0700
Subject: [PATCH 1/2] change: Add configuration option with headers for Clarify.

---
 src/sagemaker/clarify.py   | 13 +++++++--
 tests/unit/test_clarify.py | 58 +++++++++++++++++++++++++++++++++-----
 2 files changed, 61 insertions(+), 10 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 680a426912..9ea0d59256 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -674,8 +674,10 @@ def run_explainability(
                 endpoint to be created.
             explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
                 specific explainability method. Currently, only SHAP is supported.
-            model_scores: Index or JSONPath location in the model output for the predicted scores
-                to be explained. This is not required if the model output is a single score.
+            model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
+                model output for the predicted scores to be explained. This is not required if the
+                model output is a single score. Alternatively, an instance of
+                ModelPredictedLabelConfig can be provided.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when ``wait`` is True (default: True).
@@ -689,7 +691,12 @@
         """
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
-        _set(model_scores, "label", predictor_config)
+        if isinstance(model_scores, ModelPredictedLabelConfig):
+            probability_threshold, predicted_label_config = model_scores.get_predictor_config()
+            _set(probability_threshold, "probability_threshold", analysis_config)
+            predictor_config.update(predicted_label_config)
+        else:
+            _set(model_scores, "label", predictor_config)
         analysis_config["methods"] = explainability_config.get_explainability_config()
         analysis_config["predictor"] = predictor_config
         if job_name is None:
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 03004f2faa..d4844742f8 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -379,13 +379,20 @@ def test_post_training_bias(
     )
 
 
-def test_shap(clarify_processor, data_config, model_config, shap_config):
+def _run_test_shap(
+    clarify_processor,
+    data_config,
+    model_config,
+    shap_config,
+    model_scores,
+    expected_predictor_config,
+):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
         clarify_processor.run_explainability(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             job_name="test",
             experiment_config={"ExperimentName": "AnExperiment"},
@@ -414,11 +421,7 @@
                     "save_local_shap_values": True,
                 }
             },
-            "predictor": {
-                "model_name": "xgboost-model",
-                "instance_type": "ml.c5.xlarge",
-                "initial_instance_count": 1,
-            },
+            "predictor": expected_predictor_config,
         }
         mock_method.assert_called_once_with(
             data_config,
             "test",
             True,
             True,
             None,
             {"ExperimentName": "AnExperiment"},
         )
+
+
+def test_shap(clarify_processor, data_config, model_config, shap_config):
+    model_scores = None
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
+
+
+def test_shap_with_predicted_label(clarify_processor, data_config, model_config, shap_config):
+    probability = "pr"
+    label_headers = ["success"]
+    model_scores = ModelPredictedLabelConfig(
+        probability=probability,
+        label_headers=label_headers,
+    )
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+        "probability": probability,
+        "label_headers": label_headers,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )

From 259d31d0e848f73a886c922961dde501a10518f2 Mon Sep 17 00:00:00 2001
From: Mila Hardt
Date: Wed, 23 Jun 2021 10:51:52 -0700
Subject: [PATCH 2/2] Redo test

---
 tests/unit/test_clarify.py | 75 +++++++++++++++++++++++++++++++++-----
 1 file changed, 66 insertions(+), 9 deletions(-)

diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 9cf17bcf09..7c6ae7e8c9 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -442,21 +442,22 @@ def test_post_training_bias(
     )
 
 
-@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
-def test_shap(
+def _run_test_shap(
     name_from_base,
     clarify_processor,
     clarify_processor_with_job_name_prefix,
     data_config,
     model_config,
     shap_config,
+    model_scores,
+    expected_predictor_config,
 ):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
         clarify_processor.run_explainability(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             job_name="test",
             experiment_config={"ExperimentName": "AnExperiment"},
@@ -485,11 +486,7 @@ def test_shap(
                     "save_local_shap_values": True,
                 }
             },
-            "predictor": {
-                "model_name": "xgboost-model",
-                "instance_type": "ml.c5.xlarge",
-                "initial_instance_count": 1,
-            },
+            "predictor": expected_predictor_config,
         }
         mock_method.assert_called_with(
             data_config,
@@ -504,7 +501,7 @@ def test_shap(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             experiment_config={"ExperimentName": "AnExperiment"},
         )
+
+
+@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
+def test_shap(
+    name_from_base,
+    clarify_processor,
+    clarify_processor_with_job_name_prefix,
+    data_config,
+    model_config,
+    shap_config,
+):
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+    }
+    _run_test_shap(
+        name_from_base,
+        clarify_processor,
+        clarify_processor_with_job_name_prefix,
+        data_config,
+        model_config,
+        shap_config,
+        None,
+        expected_predictor_config,
+    )
+
+
+@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
+def test_shap_with_predicted_label(
+    name_from_base,
+    clarify_processor,
+    clarify_processor_with_job_name_prefix,
+    data_config,
+    model_config,
+    shap_config,
+):
+    probability = "pr"
+    label_headers = ["success"]
+    model_scores = ModelPredictedLabelConfig(
+        probability=probability,
+        label_headers=label_headers,
+    )
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+        "probability": probability,
+        "label_headers": label_headers,
+    }
+    _run_test_shap(
+        name_from_base,
+        clarify_processor,
+        clarify_processor_with_job_name_prefix,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
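
Usage sketch (reviewer note, not part of the patch): with this change, run_explainability accepts a ModelPredictedLabelConfig in place of a plain index/JSONPath for model_scores, and any probability_threshold it carries is forwarded into the analysis config. The example below is a minimal sketch of the intended call pattern; the role ARN, S3 paths, model name, headers, and baseline values are placeholders, and the DataConfig/ModelConfig/SHAPConfig arguments are abbreviated to the fields exercised by the tests above.

    from sagemaker import clarify

    processor = clarify.SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/ClarifyRole",  # placeholder role
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )

    data_config = clarify.DataConfig(
        s3_data_input_path="s3://bucket/input",  # placeholder path
        s3_output_path="s3://bucket/output",     # placeholder path
        label="target",
        headers=["f0", "f1", "target"],
        dataset_type="text/csv",
    )

    model_config = clarify.ModelConfig(
        model_name="xgboost-model",  # placeholder model name
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )

    shap_config = clarify.SHAPConfig(
        baseline=[[0.5, 0.5]],  # one baseline row per feature set (placeholder values)
        num_samples=100,
        agg_method="mean_abs",
    )

    # New in this patch: model_scores may be a ModelPredictedLabelConfig
    # rather than a bare index/JSONPath. "probability" selects the score
    # to explain, and "label_headers" names the model outputs in the report.
    model_scores = clarify.ModelPredictedLabelConfig(
        probability="pr",
        label_headers=["success"],
    )

    processor.run_explainability(
        data_config,
        model_config,
        shap_config,
        model_scores=model_scores,
    )

When a plain string or int is passed instead, the old behavior is kept: the value is set as the "label" key of the predictor config, exactly as the else branch in the clarify.py hunk shows.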