Skip to content

Commit 515ef02

Browse files
author
Michael Trinh
committed
fix: Update Clarify SHAPConfig baseline to allow JSON structures
1 parent 162e922 commit 515ef02

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@
9494
{object: object},
9595
)
9696
],
97+
# Arbitrary JSON object as baseline
98+
{object: object},
9799
),
98100
SchemaOptional("num_clusters"): int,
99101
SchemaOptional("use_logit"): bool,
@@ -1201,7 +1203,7 @@ class SHAPConfig(ExplainabilityConfig):
12011203

12021204
def __init__(
12031205
self,
1204-
baseline: Optional[Union[str, List]] = None,
1206+
baseline: Optional[Union[str, List, Dict]] = None,
12051207
num_samples: Optional[int] = None,
12061208
agg_method: Optional[str] = None,
12071209
use_logit: bool = False,
@@ -1214,7 +1216,7 @@ def __init__(
12141216
"""Initializes config for SHAP analysis.
12151217
12161218
Args:
1217-
baseline (None or str or list): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
1219+
baseline (None or str or list or dict): `Baseline dataset <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-shap-baselines.html>`_
12181220
for the Kernel SHAP algorithm, accepted in the form of:
12191221
S3 object URI, a list of rows (with at least one element),
12201222
or None (for no input baseline). The baseline dataset must have the same format

tests/unit/test_clarify.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -561,14 +561,20 @@ def test_invalid_model_predicted_label_config():
561561
)
562562

563563

564-
def test_shap_config():
565-
baseline = [
566-
[
567-
0.26124998927116394,
568-
0.2824999988079071,
569-
0.06875000149011612,
570-
]
571-
]
564+
@pytest.mark.parametrize(
565+
"baseline",
566+
[
567+
([[0.26124998927116394, 0.2824999988079071, 0.06875000149011612]]),
568+
(
569+
{
570+
"instances": [
571+
{"features": [0.26124998927116394, 0.2824999988079071, 0.06875000149011612]}
572+
]
573+
}
574+
),
575+
],
576+
)
577+
def test_valid_shap_config(baseline):
572578
num_samples = 100
573579
agg_method = "mean_sq"
574580
use_logit = True

0 commit comments

Comments
 (0)