change: Add configuration option with headers for Clarify Explainability #2446

Merged
merged 7 commits on Jun 25, 2021
13 changes: 10 additions & 3 deletions src/sagemaker/clarify.py
@@ -674,8 +674,10 @@ def run_explainability(
                 endpoint to be created.
             explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
                 specific explainability method. Currently, only SHAP is supported.
-            model_scores: Index or JSONPath location in the model output for the predicted scores
-                to be explained. This is not required if the model output is a single score.
+            model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
+                model output for the predicted scores to be explained. This is not required if the
+                model output is a single score. Alternatively, an instance of
+                ModelPredictedLabelConfig can be provided.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when ``wait`` is True (default: True).
@@ -689,7 +691,12 @@ def run_explainability(
"""
analysis_config = data_config.get_config()
predictor_config = model_config.get_predictor_config()
_set(model_scores, "label", predictor_config)
if isinstance(model_scores, ModelPredictedLabelConfig):
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
_set(probability_threshold, "probability_threshold", analysis_config)
predictor_config.update(predicted_label_config)
else:
_set(model_scores, "label", predictor_config)
analysis_config["methods"] = explainability_config.get_explainability_config()
analysis_config["predictor"] = predictor_config
if job_name is None:
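For orientation, here is a minimal usage sketch of the new option. The role, bucket paths, headers, and model/instance names below are placeholder assumptions for illustration, not values taken from this PR:

    import sagemaker
    from sagemaker import clarify

    session = sagemaker.Session()

    clarify_processor = clarify.SageMakerClarifyProcessor(
        role="my-clarify-role",  # hypothetical IAM role
        instance_count=1,
        instance_type="ml.c5.xlarge",
        sagemaker_session=session,
    )
    data_config = clarify.DataConfig(
        s3_data_input_path="s3://my-bucket/input.csv",  # hypothetical paths
        s3_output_path="s3://my-bucket/clarify-output",
        headers=["f0", "f1", "target"],
        label="target",
        dataset_type="text/csv",
    )
    model_config = clarify.ModelConfig(
        model_name="xgboost-model",
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )
    shap_config = clarify.SHAPConfig(
        baseline=[[0, 0]],
        num_samples=100,
        agg_method="mean_abs",
    )

    # New in this PR: model_scores may be a ModelPredictedLabelConfig
    # (here carrying label headers) instead of a bare index or JSONPath.
    model_scores = clarify.ModelPredictedLabelConfig(
        probability="pr",
        label_headers=["success"],
    )
    clarify_processor.run_explainability(
        data_config,
        model_config,
        shap_config,
        model_scores=model_scores,
    )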
58 changes: 51 additions & 7 deletions tests/unit/test_clarify.py
@@ -379,13 +379,20 @@ def test_post_training_bias(
     )


-def test_shap(clarify_processor, data_config, model_config, shap_config):
+def _run_test_shap(
+    clarify_processor,
+    data_config,
+    model_config,
+    shap_config,
+    model_scores,
+    expected_predictor_config,
+):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
         clarify_processor.run_explainability(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             job_name="test",
             experiment_config={"ExperimentName": "AnExperiment"},
@@ -414,11 +421,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
"save_local_shap_values": True,
}
},
"predictor": {
"model_name": "xgboost-model",
"instance_type": "ml.c5.xlarge",
"initial_instance_count": 1,
},
"predictor": expected_predictor_config,
}
mock_method.assert_called_once_with(
data_config,
@@ -429,3 +432,44 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
             None,
             {"ExperimentName": "AnExperiment"},
         )
+
+
+def test_shap(clarify_processor, data_config, model_config, shap_config):
+    model_scores = None
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
+
+
+def test_shap_with_predicted_label(clarify_processor, data_config, model_config, shap_config):
+    probability = "pr"
+    label_headers = ["success"]
+    model_scores = ModelPredictedLabelConfig(
+        probability=probability,
+        label_headers=label_headers,
+    )
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+        "probability": probability,
+        "label_headers": label_headers,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
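As the new test asserts, the fields from ModelPredictedLabelConfig are merged into the "predictor" section of the analysis config. A standalone sketch of that merge, assuming get_predictor_config() behaves as exercised above (the endpoint settings are placeholders):

    from sagemaker.clarify import ModelPredictedLabelConfig

    model_scores = ModelPredictedLabelConfig(probability="pr", label_headers=["success"])

    # Per the change in clarify.py, get_predictor_config() returns a
    # (probability_threshold, config_dict) pair; the threshold is None here
    # because it was not set above.
    probability_threshold, predicted_label_config = model_scores.get_predictor_config()

    predictor_config = {
        "model_name": "xgboost-model",  # assumed endpoint settings
        "instance_type": "ml.c5.xlarge",
        "initial_instance_count": 1,
    }
    predictor_config.update(predicted_label_config)
    # predictor_config now also carries "probability" and "label_headers",
    # matching expected_predictor_config in test_shap_with_predicted_label.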