
Commit 9f04b85

change: Add configuration option with headers for Clarify.
1 parent 12a6918 commit 9f04b85

2 files changed: +61 -10 lines changed


src/sagemaker/clarify.py

+10 -3

@@ -674,8 +674,10 @@ def run_explainability(
                 endpoint to be created.
             explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
                 specific explainability method. Currently, only SHAP is supported.
-            model_scores: Index or JSONPath location in the model output for the predicted scores
-                to be explained. This is not required if the model output is a single score.
+            model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
+                model output for the predicted scores to be explained. This is not required if the
+                model output is a single score. Alternatively, an instance of
+                ModelPredictedLabelConfig can be provided.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when ``wait`` is True (default: True).
@@ -689,7 +691,12 @@ def run_explainability(
         """
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
-        _set(model_scores, "label", predictor_config)
+        if isinstance(model_scores, ModelPredictedLabelConfig):
+            probability_threshold, predicted_label_config = model_scores.get_predictor_config()
+            _set(probability_threshold, "probability_threshold", analysis_config)
+            predictor_config.update(predicted_label_config)
+        else:
+            _set(model_scores, "label", predictor_config)
         analysis_config["methods"] = explainability_config.get_explainability_config()
         analysis_config["predictor"] = predictor_config
         if job_name is None:
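
In practice, the widened model_scores parameter means a caller can hand run_explainability a ModelPredictedLabelConfig instead of a bare index or JSONPath. A minimal usage sketch, not part of this commit; it assumes clarify_processor, data_config, model_config, and shap_config are already-constructed SageMakerClarifyProcessor, DataConfig, ModelConfig, and SHAPConfig objects, and the job name is hypothetical:

from sagemaker.clarify import ModelPredictedLabelConfig

# Describe where the predicted score lives in the model output and how to
# label it; the values mirror the unit test added below.
model_scores = ModelPredictedLabelConfig(
    probability="pr",           # index or JSONPath of the score in the model output
    label_headers=["success"],  # human-readable header(s) for the predicted label(s)
)

clarify_processor.run_explainability(
    data_config,
    model_config,
    shap_config,
    model_scores=model_scores,  # previously only an int index or JSONPath string
    wait=True,
    job_name="clarify-explainability",  # hypothetical job name
)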

tests/unit/test_clarify.py

+51 -7

@@ -379,13 +379,20 @@ def test_post_training_bias(
         )


-def test_shap(clarify_processor, data_config, model_config, shap_config):
+def _run_test_shap(
+    clarify_processor,
+    data_config,
+    model_config,
+    shap_config,
+    model_scores,
+    expected_predictor_config,
+):
     with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
         clarify_processor.run_explainability(
             data_config,
             model_config,
             shap_config,
-            model_scores=None,
+            model_scores=model_scores,
             wait=True,
             job_name="test",
             experiment_config={"ExperimentName": "AnExperiment"},
@@ -414,11 +421,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
                     "save_local_shap_values": True,
                 }
             },
-            "predictor": {
-                "model_name": "xgboost-model",
-                "instance_type": "ml.c5.xlarge",
-                "initial_instance_count": 1,
-            },
+            "predictor": expected_predictor_config,
         }
         mock_method.assert_called_once_with(
             data_config,
@@ -429,3 +432,44 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
             None,
             {"ExperimentName": "AnExperiment"},
         )
+
+
+def test_shap(clarify_processor, data_config, model_config, shap_config):
+    model_scores = None
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
+
+
+def test_shap_with_predicted_label(clarify_processor, data_config, model_config, shap_config):
+    probability = "pr"
+    label_headers = ["success"]
+    model_scores = ModelPredictedLabelConfig(
+        probability=probability,
+        label_headers=label_headers,
+    )
+    expected_predictor_config = {
+        "model_name": "xgboost-model",
+        "instance_type": "ml.c5.xlarge",
+        "initial_instance_count": 1,
+        "probability": probability,
+        "label_headers": label_headers,
+    }
+    _run_test_shap(
+        clarify_processor,
+        data_config,
+        model_config,
+        shap_config,
+        model_scores,
+        expected_predictor_config,
+    )
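
For reference, the mechanics the new test exercises: calling get_predictor_config() on a ModelPredictedLabelConfig returns a (probability_threshold, config) pair, and run_explainability merges the config dict into the predictor section of the analysis config. A small sketch of that merge with the same values as test_shap_with_predicted_label; the None-threshold behavior assumes the default when probability_threshold is not passed:

from sagemaker.clarify import ModelPredictedLabelConfig

model_scores = ModelPredictedLabelConfig(probability="pr", label_headers=["success"])
probability_threshold, predicted_label_config = model_scores.get_predictor_config()

# probability_threshold is None when it was not given to the constructor, so
# _set() leaves "probability_threshold" out of the analysis config entirely.
predictor_config = {
    "model_name": "xgboost-model",
    "instance_type": "ml.c5.xlarge",
    "initial_instance_count": 1,
}
predictor_config.update(predicted_label_config)
# predictor_config now also carries "probability": "pr" and
# "label_headers": ["success"], matching expected_predictor_config above.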
