Skip to content

Commit 13c3b7d

Browse files
feature: Add support for Partial Dependence Plots (PDP) in SageMaker Clarify
1 parent 2594ffb commit 13c3b7d

File tree

4 files changed

+165
-26
lines changed

4 files changed

+165
-26
lines changed

src/sagemaker/clarify.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import os
2121
import tempfile
2222
import re
23+
from typing import List
24+
2325
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2426
from sagemaker import image_uris, s3, utils
2527

@@ -292,7 +294,18 @@ class ExplainabilityConfig(ABC):
292294
@abstractmethod
293295
def get_explainability_config(self):
294296
"""Returns config."""
295-
return None
297+
298+
299+
class PDPConfig(ExplainabilityConfig):
300+
def __init__(self, features=None, grid_resolution=None):
301+
self.pdp_config = {}
302+
if features is not None:
303+
self.pdp_config["features"] = features
304+
if grid_resolution is not None:
305+
self.pdp_config["grid_resolution"] = grid_resolution
306+
307+
def get_explainability_config(self):
308+
return {"pdp": copy.deepcopy(self.pdp_config)}
296309

297310

298311
class SHAPConfig(ExplainabilityConfig):
@@ -771,8 +784,9 @@ def run_explainability(
771784
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
772785
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
773786
endpoint to be created.
774-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
775-
specific explainability method. Currently, only SHAP is supported.
787+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`| list of
788+
:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the specific explainability method or a
789+
list of ExplainabilityConfig objects. Currently, SHAP and PDP are the two methods supported.
776790
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
777791
model output for the predicted scores to be explained. This is not required if the
778792
model output is a single score. Alternatively, an instance of
@@ -806,7 +820,21 @@ def run_explainability(
806820
predictor_config.update(predicted_label_config)
807821
else:
808822
_set(model_scores, "label", predictor_config)
809-
analysis_config["methods"] = explainability_config.get_explainability_config()
823+
824+
explainability_methods = {}
825+
if isinstance(explainability_config, List):
826+
for config in explainability_config:
827+
if not isinstance(config, ExplainabilityConfig):
828+
raise ValueError(
829+
f"Invalid input: Expected ExplainabilityConfig, got {type(config)} instead"
830+
)
831+
explain_config = config.get_explainability_config()
832+
explainability_methods[list(explain_config.keys())[0]] = explain_config[
833+
list(explain_config.keys())[0]
834+
]
835+
elif isinstance(explainability_config, ExplainabilityConfig):
836+
explainability_methods = explainability_config.get_explainability_config()
837+
analysis_config["methods"] = explainability_methods
810838
analysis_config["predictor"] = predictor_config
811839
if job_name is None:
812840
if self.job_name_prefix:

tests/integ/test_clarify.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16-
1716
import json
1817
import math
1918
import numpy as np
@@ -31,14 +30,14 @@
3130
ModelConfig,
3231
ModelPredictedLabelConfig,
3332
SHAPConfig,
33+
PDPConfig,
3434
)
3535

3636
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor
3737
from sagemaker import utils
3838
from tests import integ
3939
from tests.integ import timeout
4040

41-
4241
CLARIFY_DEFAULT_TIMEOUT_MINUTES = 15
4342

4443

@@ -177,6 +176,11 @@ def shap_config():
177176
)
178177

179178

179+
@pytest.fixture(scope="module")
180+
def pdp_config():
181+
return PDPConfig(features=["F1"], grid_resolution=10)
182+
183+
180184
def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sagemaker_session):
181185
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
182186
clarify_processor.run_pre_training_bias(
@@ -258,6 +262,57 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
258262
check_analysis_config(data_config, sagemaker_session, "shap")
259263

260264

265+
def test_pdp(clarify_processor, data_config, model_config, pdp_config, sagemaker_session):
266+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
267+
clarify_processor.run_explainability(
268+
data_config,
269+
model_config,
270+
pdp_config,
271+
model_scores="score",
272+
job_name=utils.unique_name_from_base("clarify-explainability-pdp"),
273+
wait=True,
274+
)
275+
analysis_result_json = s3.S3Downloader.read_file(
276+
data_config.s3_output_path + "/analysis.json",
277+
sagemaker_session,
278+
)
279+
analysis_result = json.loads(analysis_result_json)
280+
print(analysis_result)
281+
assert analysis_result["explanations"]["pdp"][0]["feature_name"] == "F1"
282+
283+
check_analysis_config(data_config, sagemaker_session, "pdp")
284+
285+
286+
def test_shap_and_pdp(
287+
clarify_processor, data_config, model_config, shap_config, pdp_config, sagemaker_session
288+
):
289+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
290+
clarify_processor.run_explainability(
291+
data_config,
292+
model_config,
293+
[shap_config, pdp_config],
294+
model_scores="score",
295+
job_name=utils.unique_name_from_base("clarify-explainability"),
296+
wait=True,
297+
)
298+
analysis_result_json = s3.S3Downloader.read_file(
299+
data_config.s3_output_path + "/analysis.json",
300+
sagemaker_session,
301+
)
302+
analysis_result = json.loads(analysis_result_json)
303+
print(analysis_result)
304+
assert (
305+
math.fabs(
306+
analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"]
307+
)
308+
<= 1
309+
)
310+
assert analysis_result["explanations"]["pdp"][0]["feature_name"] == "F1"
311+
312+
check_analysis_config(data_config, sagemaker_session, "pdp")
313+
check_analysis_config(data_config, sagemaker_session, "shap")
314+
315+
261316
def check_analysis_config(data_config, sagemaker_session, method):
262317
analysis_config_json = s3.S3Downloader.read_file(
263318
data_config.s3_output_path + "/analysis_config.json",

tests/integ/test_clarify_model_monitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
ModelConfig,
4343
ModelPredictedLabelConfig,
4444
SHAPConfig,
45+
PDPConfig,
4546
)
4647
from sagemaker.model import Model
4748

tests/unit/test_clarify.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ModelConfig,
2424
ModelPredictedLabelConfig,
2525
SHAPConfig,
26+
PDPConfig,
2627
)
2728
from sagemaker import image_uris
2829

@@ -268,6 +269,12 @@ def test_shap_config():
268269
assert expected_config == shap_config.get_explainability_config()
269270

270271

272+
def test_pdp_config():
273+
pdp_config = PDPConfig(features=["f1", "f2"], grid_resolution=20)
274+
expected_config = {"pdp": {"features": ["f1", "f2"], "grid_resolution": 20}}
275+
assert expected_config == pdp_config.get_explainability_config()
276+
277+
271278
def test_invalid_shap_config():
272279
with pytest.raises(ValueError) as error:
273280
SHAPConfig(
@@ -374,6 +381,11 @@ def shap_config():
374381
)
375382

376383

384+
@pytest.fixture(scope="module")
385+
def pdp_config():
386+
return PDPConfig(features=["F1", "F2"], grid_resolution=20)
387+
388+
377389
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
378390
def test_pre_training_bias(
379391
name_from_base,
@@ -499,21 +511,30 @@ def test_post_training_bias(
499511
)
500512

501513

502-
def _run_test_shap(
514+
def _run_test_explain(
503515
name_from_base,
504516
clarify_processor,
505517
clarify_processor_with_job_name_prefix,
506518
data_config,
507519
model_config,
508520
shap_config,
521+
pdp_config,
509522
model_scores,
510523
expected_predictor_config,
511524
):
512525
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
526+
explanation_configs = None
527+
if shap_config and pdp_config:
528+
explanation_configs = [shap_config, pdp_config]
529+
elif shap_config:
530+
explanation_configs = shap_config
531+
elif pdp_config:
532+
explanation_configs = pdp_config
533+
513534
clarify_processor.run_explainability(
514535
data_config,
515536
model_config,
516-
shap_config,
537+
explanation_configs,
517538
model_scores=model_scores,
518539
wait=True,
519540
job_name="test",
@@ -528,23 +549,26 @@ def _run_test_shap(
528549
"F3",
529550
],
530551
"label": "Label",
531-
"methods": {
532-
"shap": {
533-
"baseline": [
534-
[
535-
0.26124998927116394,
536-
0.2824999988079071,
537-
0.06875000149011612,
538-
]
539-
],
540-
"num_samples": 100,
541-
"agg_method": "mean_sq",
542-
"use_logit": False,
543-
"save_local_shap_values": True,
544-
}
545-
},
546552
"predictor": expected_predictor_config,
547553
}
554+
expected_explanation_configs = {}
555+
if shap_config:
556+
expected_explanation_configs["shap"] = {
557+
"baseline": [
558+
[
559+
0.26124998927116394,
560+
0.2824999988079071,
561+
0.06875000149011612,
562+
]
563+
],
564+
"num_samples": 100,
565+
"agg_method": "mean_sq",
566+
"use_logit": False,
567+
"save_local_shap_values": True,
568+
}
569+
if pdp_config:
570+
expected_explanation_configs["pdp"] = {"features": ["F1", "F2"], "grid_resolution": 20}
571+
expected_analysis_config["methods"] = expected_explanation_configs
548572
mock_method.assert_called_with(
549573
data_config,
550574
expected_analysis_config,
@@ -557,7 +581,7 @@ def _run_test_shap(
557581
clarify_processor_with_job_name_prefix.run_explainability(
558582
data_config,
559583
model_config,
560-
shap_config,
584+
explanation_configs,
561585
model_scores=model_scores,
562586
wait=True,
563587
experiment_config={"ExperimentName": "AnExperiment"},
@@ -574,6 +598,34 @@ def _run_test_shap(
574598
)
575599

576600

601+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
602+
def test_pdp(
603+
name_from_base,
604+
clarify_processor,
605+
clarify_processor_with_job_name_prefix,
606+
data_config,
607+
model_config,
608+
shap_config,
609+
pdp_config,
610+
):
611+
expected_predictor_config = {
612+
"model_name": "xgboost-model",
613+
"instance_type": "ml.c5.xlarge",
614+
"initial_instance_count": 1,
615+
}
616+
_run_test_explain(
617+
name_from_base,
618+
clarify_processor,
619+
clarify_processor_with_job_name_prefix,
620+
data_config,
621+
model_config,
622+
None,
623+
pdp_config,
624+
None,
625+
expected_predictor_config,
626+
)
627+
628+
577629
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
578630
def test_shap(
579631
name_from_base,
@@ -588,14 +640,15 @@ def test_shap(
588640
"instance_type": "ml.c5.xlarge",
589641
"initial_instance_count": 1,
590642
}
591-
_run_test_shap(
643+
_run_test_explain(
592644
name_from_base,
593645
clarify_processor,
594646
clarify_processor_with_job_name_prefix,
595647
data_config,
596648
model_config,
597649
shap_config,
598650
None,
651+
None,
599652
expected_predictor_config,
600653
)
601654

@@ -608,6 +661,7 @@ def test_shap_with_predicted_label(
608661
data_config,
609662
model_config,
610663
shap_config,
664+
pdp_config,
611665
):
612666
probability = "pr"
613667
label_headers = ["success"]
@@ -622,13 +676,14 @@ def test_shap_with_predicted_label(
622676
"probability": probability,
623677
"label_headers": label_headers,
624678
}
625-
_run_test_shap(
679+
_run_test_explain(
626680
name_from_base,
627681
clarify_processor,
628682
clarify_processor_with_job_name_prefix,
629683
data_config,
630684
model_config,
631685
shap_config,
686+
pdp_config,
632687
model_scores,
633688
expected_predictor_config,
634689
)

0 commit comments

Comments
 (0)