diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 2f7f3dc53d..14bb675681 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -94,6 +94,8 @@ {object: object}, ) ], + # Arbitrary JSON object as baseline + {object: object}, ), SchemaOptional("num_clusters"): int, SchemaOptional("use_logit"): bool, @@ -1211,7 +1213,7 @@ class SHAPConfig(ExplainabilityConfig): def __init__( self, - baseline: Optional[Union[str, List]] = None, + baseline: Optional[Union[str, List, Dict]] = None, num_samples: Optional[int] = None, agg_method: Optional[str] = None, use_logit: bool = False, @@ -1224,7 +1226,7 @@ def __init__( """Initializes config for SHAP analysis. Args: - baseline (None or str or list): `Baseline dataset `_ + baseline (None or str or list or dict): `Baseline dataset `_ for the Kernel SHAP algorithm, accepted in the form of: S3 object URI, a list of rows (with at least one element), or None (for no input baseline). The baseline dataset must have the same format diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 714f9d316c..4fe1ecb5f2 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -563,14 +563,20 @@ def test_invalid_model_predicted_label_config(): ) -def test_shap_config(): - baseline = [ - [ - 0.26124998927116394, - 0.2824999988079071, - 0.06875000149011612, - ] - ] +@pytest.mark.parametrize( + "baseline", + [ + ([[0.26124998927116394, 0.2824999988079071, 0.06875000149011612]]), + ( + { + "instances": [ + {"features": [0.26124998927116394, 0.2824999988079071, 0.06875000149011612]} + ] + } + ), + ], +) +def test_valid_shap_config(baseline): num_samples = 100 agg_method = "mean_sq" use_logit = True