diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index e22a82b431..5d36eb229d 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -723,8 +723,10 @@ def run_explainability(
                 endpoint to be created.
             explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
                 specific explainability method. Currently, only SHAP is supported.
-            model_scores: Index or JSONPath location in the model output for the predicted scores
-                to be explained. This is not required if the model output is a single score.
+            model_scores (str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
+                model output for the predicted scores to be explained. This is not required if the
+                model output is a single score. Alternatively, an instance of
+                ModelPredictedLabelConfig can be provided.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when ``wait`` is True (default: True).
@@ -740,7 +742,12 @@ def run_explainability(
         """
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
-        _set(model_scores, "label", predictor_config)
+        if isinstance(model_scores, ModelPredictedLabelConfig):
+            probability_threshold, predicted_label_config = model_scores.get_predictor_config()
+            _set(probability_threshold, "probability_threshold", analysis_config)
+            predictor_config.update(predicted_label_config)
+        else:
+            _set(model_scores, "label", predictor_config)
         analysis_config["methods"] = explainability_config.get_explainability_config()
         analysis_config["predictor"] = predictor_config
         if job_name is None:
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 9cf17bcf09..7c6ae7e8c9 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -442,21 +442,22 @@ def test_post_training_bias(
     )
 
 
-@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
-def test_shap(
+def _run_test_shap(
     name_from_base,
     clarify_processor,
     clarify_processor_with_job_name_prefix,
     data_config,
     model_config,
     shap_config,
+    model_scores,
+    expected_predictor_config,
 ):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
         clarify_processor.run_explainability(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             job_name="test",
             experiment_config={"ExperimentName": "AnExperiment"},
@@ -485,11 +486,7 @@ def test_shap(
                 "save_local_shap_values": True,
             }
         },
-        "predictor": {
-            "model_name": "xgboost-model",
-            "instance_type": "ml.c5.xlarge",
-            "initial_instance_count": 1,
-        },
+        "predictor": expected_predictor_config,
     }
     mock_method.assert_called_with(
         data_config,
@@ -504,7 +501,7 @@ def test_shap(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             experiment_config={"ExperimentName": "AnExperiment"},
         )
@@ -518,3 +515,63 @@ def test_shap(
         None,
         {"ExperimentName": "AnExperiment"},
     )
+
+
+@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
+def test_shap(
+    name_from_base,
+    clarify_processor,
+    clarify_processor_with_job_name_prefix,
+    data_config,
+    model_config,
+    shap_config,
+):
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+    }
+    _run_test_shap(
+        name_from_base,
+        clarify_processor,
+        clarify_processor_with_job_name_prefix,
+        data_config,
+        model_config,
+        shap_config,
+        None,
+        expected_predictor_config,
+    )
+
+
+@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
+def test_shap_with_predicted_label(
+    name_from_base,
+    clarify_processor,
+    clarify_processor_with_job_name_prefix,
+    data_config,
+    model_config,
+    shap_config,
+):
+    probability = "pr"
+    label_headers = ["success"]
+    model_scores = ModelPredictedLabelConfig(
+        probability=probability,
+        label_headers=label_headers,
+    )
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+        "probability": probability,
+        "label_headers": label_headers,
+    }
+    _run_test_shap(
+        name_from_base,
+        clarify_processor,
+        clarify_processor_with_job_name_prefix,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
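
For reviewers, a minimal sketch of the new call path this diff enables. The model name, instance type, probability key ("pr"), and label headers (["success"]) are taken from the test fixtures above; the role ARN, S3 paths, and SHAP baseline are hypothetical placeholders, and running this as-is would launch a real Clarify processing job.

    from sagemaker import clarify

    # Placeholder role/session values; only model_scores below exercises the new branch.
    processor = clarify.SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/ClarifyRole",  # placeholder
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )
    data_config = clarify.DataConfig(
        s3_data_input_path="s3://my-bucket/input",  # placeholder paths
        s3_output_path="s3://my-bucket/output",
        label="target",
        dataset_type="text/csv",
    )
    model_config = clarify.ModelConfig(
        model_name="xgboost-model",
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )
    shap_config = clarify.SHAPConfig(
        baseline=[[0.5] * 10],  # placeholder baseline row
        num_samples=100,
        agg_method="mean_abs",
    )
    # New in this change: model_scores may be a ModelPredictedLabelConfig instead of
    # an index or JSONPath string; its fields are merged into the predictor config.
    predicted_label_config = clarify.ModelPredictedLabelConfig(
        probability="pr",
        label_headers=["success"],
    )
    processor.run_explainability(
        data_config,
        model_config,
        shap_config,
        model_scores=predicted_label_config,
    )

Note the split the new branch performs: ModelPredictedLabelConfig.get_predictor_config() returns a (probability_threshold, config) tuple, so the threshold (when set) lands in the top-level analysis config while the remaining keys are merged into predictor_config, which is exactly what expected_predictor_config asserts in test_shap_with_predicted_label.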