Skip to content

Commit 4c1c118

Browse files
spoornMichael Trinh
and
Michael Trinh
authored
fix: Update Clarify SHAPConfig baseline to allow JSON structures (#3804)
Co-authored-by: Michael Trinh <[email protected]>
1 parent 68885dc commit 4c1c118

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@
9494
{object: object},
9595
)
9696
],
97+
# Arbitrary JSON object as baseline
98+
{object: object},
9799
),
98100
SchemaOptional("num_clusters"): int,
99101
SchemaOptional("use_logit"): bool,
@@ -1211,7 +1213,7 @@ class SHAPConfig(ExplainabilityConfig):
12111213

12121214
def __init__(
12131215
self,
1214-
baseline: Optional[Union[str, List]] = None,
1216+
baseline: Optional[Union[str, List, Dict]] = None,
12151217
num_samples: Optional[int] = None,
12161218
agg_method: Optional[str] = None,
12171219
use_logit: bool = False,
@@ -1224,7 +1226,7 @@ def __init__(
12241226
"""Initializes config for SHAP analysis.
12251227
12261228
Args:
1227-
baseline (None or str or list): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
1229+
baseline (None or str or list or dict): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
12281230
for the Kernel SHAP algorithm, accepted in the form of:
12291231
S3 object URI, a list of rows (with at least one element),
12301232
or None (for no input baseline). The baseline dataset must have the same format

tests/unit/test_clarify.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -563,14 +563,20 @@ def test_invalid_model_predicted_label_config():
563563
)
564564

565565

566-
def test_shap_config():
567-
baseline = [
568-
[
569-
0.26124998927116394,
570-
0.2824999988079071,
571-
0.06875000149011612,
572-
]
573-
]
566+
@pytest.mark.parametrize(
567+
"baseline",
568+
[
569+
([[0.26124998927116394, 0.2824999988079071, 0.06875000149011612]]),
570+
(
571+
{
572+
"instances": [
573+
{"features": [0.26124998927116394, 0.2824999988079071, 0.06875000149011612]}
574+
]
575+
}
576+
),
577+
],
578+
)
579+
def test_valid_shap_config(baseline):
574580
num_samples = 100
575581
agg_method = "mean_sq"
576582
use_logit = True

0 commit comments

Comments
 (0)