Skip to content

Commit 8c40262

Browse files
feature: Add support for Partial Dependence Plots(PDP) in SageMaker Clarify
1 parent b082eb4 commit 8c40262

File tree

4 files changed

+228
-35
lines changed

4 files changed

+228
-35
lines changed

src/sagemaker/clarify.py

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import os
2121
import tempfile
2222
import re
23+
from typing import List
24+
2325
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
2426
from sagemaker import image_uris, s3, utils
2527

@@ -54,7 +56,11 @@ def __init__(
5456
"ShardedByS3Key".
5557
s3_compression_type (str): Valid options are "None" or "Gzip".
5658
"""
57-
if dataset_type not in ["text/csv", "application/jsonlines", "application/x-parquet"]:
59+
if dataset_type not in [
60+
"text/csv",
61+
"application/jsonlines",
62+
"application/x-parquet",
63+
]:
5864
raise ValueError(
5965
f"Invalid dataset_type '{dataset_type}'."
6066
f" Please check the API documentation for the supported dataset types."
@@ -292,7 +298,30 @@ class ExplainabilityConfig(ABC):
292298
@abstractmethod
293299
def get_explainability_config(self):
294300
"""Returns config."""
295-
return None
301+
302+
303+
class PDPConfig(ExplainabilityConfig):
304+
"""Config class for Partial Dependence Plots (PDP)"""
305+
306+
def __init__(self, features=None, grid_resolution=None):
307+
"""Initializes config for PDP.
308+
309+
Args:
310+
features (None or list): List of features names or indices for which partial dependence
311+
plots must be computed and plotted.
312+
grid_resolution (int): In case of numerical features, this number represents that
313+
number of buckets that range of values must be divided into. This decides the
314+
granularity of the grid in which the PDP are plotted.
315+
"""
316+
self.pdp_config = {}
317+
if features is not None:
318+
self.pdp_config["features"] = features
319+
if grid_resolution is not None:
320+
self.pdp_config["grid_resolution"] = grid_resolution
321+
322+
def get_explainability_config(self):
323+
"""Returns config."""
324+
return {"pdp": copy.deepcopy(self.pdp_config)}
296325

297326

298327
class SHAPConfig(ExplainabilityConfig):
@@ -466,7 +495,10 @@ def _run(
466495
will be unassociated.
467496
* `TrialComponentDisplayName` is used for display in Studio.
468497
"""
469-
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
498+
analysis_config["methods"]["report"] = {
499+
"name": "report",
500+
"title": "Analysis Report",
501+
}
470502
with tempfile.TemporaryDirectory() as tmpdirname:
471503
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
472504
with open(analysis_config_file, "w") as f:
@@ -568,7 +600,15 @@ def run_pre_training_bias(
568600
job_name = utils.name_from_base(self.job_name_prefix)
569601
else:
570602
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
571-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
603+
self._run(
604+
data_config,
605+
analysis_config,
606+
wait,
607+
logs,
608+
job_name,
609+
kms_key,
610+
experiment_config,
611+
)
572612

573613
def run_post_training_bias(
574614
self,
@@ -646,7 +686,15 @@ def run_post_training_bias(
646686
job_name = utils.name_from_base(self.job_name_prefix)
647687
else:
648688
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
649-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
689+
self._run(
690+
data_config,
691+
analysis_config,
692+
wait,
693+
logs,
694+
job_name,
695+
kms_key,
696+
experiment_config,
697+
)
650698

651699
def run_bias(
652700
self,
@@ -741,7 +789,15 @@ def run_bias(
741789
job_name = utils.name_from_base(self.job_name_prefix)
742790
else:
743791
job_name = utils.name_from_base("Clarify-Bias")
744-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
792+
self._run(
793+
data_config,
794+
analysis_config,
795+
wait,
796+
logs,
797+
job_name,
798+
kms_key,
799+
experiment_config,
800+
)
745801

746802
def run_explainability(
747803
self,
@@ -772,7 +828,8 @@ def run_explainability(
772828
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
773829
endpoint to be created.
774830
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
775-
specific explainability method. Currently, only SHAP is supported.
831+
specific explainability method or a list of ExplainabilityConfig objects. Currently,
832+
SHAP and PDP are the two methods supported.
776833
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
777834
model output for the predicted scores to be explained. This is not required if the
778835
model output is a single score. Alternatively, an instance of
@@ -781,7 +838,7 @@ def run_explainability(
781838
logs (bool): Whether to show the logs produced by the job.
782839
Only meaningful when ``wait`` is True (default: True).
783840
job_name (str): Processing job name. When ``job_name`` is not specified, if
784-
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name
841+
`job_name_prefix` in :class:`SageMakerClarifyProcessor` specified, the job name
785842
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
786843
"Clarify-Explainability" as prefix.
787844
kms_key (str): The ARN of the KMS key that is used to encrypt the
@@ -801,19 +858,44 @@ def run_explainability(
801858
analysis_config = data_config.get_config()
802859
predictor_config = model_config.get_predictor_config()
803860
if isinstance(model_scores, ModelPredictedLabelConfig):
804-
probability_threshold, predicted_label_config = model_scores.get_predictor_config()
861+
(
862+
probability_threshold,
863+
predicted_label_config,
864+
) = model_scores.get_predictor_config()
805865
_set(probability_threshold, "probability_threshold", analysis_config)
806866
predictor_config.update(predicted_label_config)
807867
else:
808868
_set(model_scores, "label", predictor_config)
809-
analysis_config["methods"] = explainability_config.get_explainability_config()
869+
870+
explainability_methods = {}
871+
if isinstance(explainability_config, List): # pylint: disable=W1116
872+
for config in explainability_config:
873+
if not isinstance(config, ExplainabilityConfig):
874+
raise ValueError(
875+
f"Invalid input: Excepted ExplainabilityConfig, got {type(config)} instead"
876+
)
877+
explain_config = config.get_explainability_config()
878+
explainability_methods[list(explain_config.keys())[0]] = explain_config[
879+
list(explain_config.keys())[0]
880+
]
881+
elif isinstance(explainability_config, ExplainabilityConfig):
882+
explainability_methods = explainability_config.get_explainability_config()
883+
analysis_config["methods"] = explainability_methods
810884
analysis_config["predictor"] = predictor_config
811885
if job_name is None:
812886
if self.job_name_prefix:
813887
job_name = utils.name_from_base(self.job_name_prefix)
814888
else:
815889
job_name = utils.name_from_base("Clarify-Explainability")
816-
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
890+
self._run(
891+
data_config,
892+
analysis_config,
893+
wait,
894+
logs,
895+
job_name,
896+
kms_key,
897+
experiment_config,
898+
)
817899

818900

819901
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

tests/integ/test_clarify.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from __future__ import print_function, absolute_import
1515

16-
1716
import json
1817
import math
1918
import numpy as np
@@ -31,14 +30,14 @@
3130
ModelConfig,
3231
ModelPredictedLabelConfig,
3332
SHAPConfig,
33+
PDPConfig,
3434
)
3535

3636
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor
3737
from sagemaker import utils
3838
from tests import integ
3939
from tests.integ import timeout
4040

41-
4241
CLARIFY_DEFAULT_TIMEOUT_MINUTES = 15
4342

4443

@@ -177,6 +176,11 @@ def shap_config():
177176
)
178177

179178

179+
@pytest.fixture(scope="module")
180+
def pdp_config():
181+
return PDPConfig(features=["F1"], grid_resolution=10)
182+
183+
180184
def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sagemaker_session):
181185
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
182186
clarify_processor.run_pre_training_bias(
@@ -258,6 +262,57 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
258262
check_analysis_config(data_config, sagemaker_session, "shap")
259263

260264

265+
def test_pdp(clarify_processor, data_config, model_config, pdp_config, sagemaker_session):
266+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
267+
clarify_processor.run_explainability(
268+
data_config,
269+
model_config,
270+
pdp_config,
271+
model_scores="score",
272+
job_name=utils.unique_name_from_base("clarify-explainability-pdp"),
273+
wait=True,
274+
)
275+
analysis_result_json = s3.S3Downloader.read_file(
276+
data_config.s3_output_path + "/analysis.json",
277+
sagemaker_session,
278+
)
279+
analysis_result = json.loads(analysis_result_json)
280+
print(analysis_result)
281+
assert analysis_result["explanations"]["pdp"][0]["feature_name"] == "F1"
282+
283+
check_analysis_config(data_config, sagemaker_session, "pdp")
284+
285+
286+
def test_shap_and_pdp(
287+
clarify_processor, data_config, model_config, shap_config, pdp_config, sagemaker_session
288+
):
289+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
290+
clarify_processor.run_explainability(
291+
data_config,
292+
model_config,
293+
[shap_config, pdp_config],
294+
model_scores="score",
295+
job_name=utils.unique_name_from_base("clarify-explainability"),
296+
wait=True,
297+
)
298+
analysis_result_json = s3.S3Downloader.read_file(
299+
data_config.s3_output_path + "/analysis.json",
300+
sagemaker_session,
301+
)
302+
analysis_result = json.loads(analysis_result_json)
303+
print(analysis_result)
304+
assert (
305+
math.fabs(
306+
analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"]
307+
)
308+
<= 1
309+
)
310+
assert analysis_result["explanations"]["pdp"][0]["feature_name"] == "F1"
311+
312+
check_analysis_config(data_config, sagemaker_session, "pdp")
313+
check_analysis_config(data_config, sagemaker_session, "shap")
314+
315+
261316
def check_analysis_config(data_config, sagemaker_session, method):
262317
analysis_config_json = s3.S3Downloader.read_file(
263318
data_config.s3_output_path + "/analysis_config.json",

0 commit comments

Comments
 (0)