Skip to content

Commit dd954b1

Browse files
committed
change: Add label_headers option for Clarify ModelExplainabilityMonitor
The option has been added to SageMakerClarifyProcessor API by PR 2446, this commit adds the same option to ModelExplainabilityMonitor.
1 parent 9be4c8a commit dd954b1

File tree

4 files changed

+70
-21
lines changed

4 files changed

+70
-21
lines changed

src/sagemaker/clarify.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,15 @@ def __init__(
275275
probability_threshold (float): An optional value for binary prediction tasks in which
276276
the model returns a probability, to indicate the threshold to convert the
277277
prediction to a boolean value. Default is 0.5.
278-
label_headers (list): List of label values - one for each score of the ``probability``.
278+
label_headers (list[str]): List of headers, each for a predicted score in model output.
279+
For bias analysis, it is used to extract the label value with the highest score as
280+
predicted label. For an explainability job, it is used to beautify the analysis report
281+
by replacing placeholders like "label0".
279282
"""
280283
self.label = label
281284
self.probability = probability
282285
self.probability_threshold = probability_threshold
286+
self.label_headers = label_headers
283287
if probability_threshold is not None:
284288
try:
285289
float(probability_threshold)
@@ -830,13 +834,12 @@ def run_explainability(
830834
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
831835
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
832836
endpoint to be created.
833-
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
834-
Config of the specific explainability method or a list of ExplainabilityConfig
835-
objects. Currently, SHAP and PDP are the two methods supported.
836-
model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
837-
model output for the predicted scores to be explained. This is not required if the
838-
model output is a single score. Alternatively, an instance of
839-
ModelPredictedLabelConfig can be provided.
837+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig`): Config of the
838+
specific explainability method. Currently, only SHAP is supported.
839+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
840+
Index or JSONPath to locate the predicted scores in the model output. This is not
841+
required if the model output is a single score. Alternatively, it can be an instance
842+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
840843
wait (bool): Whether the call should wait until the job completes (default: True).
841844
logs (bool): Whether to show the logs produced by the job.
842845
Only meaningful when ``wait`` is True (default: True).

src/sagemaker/model_monitor/clarify_model_monitoring.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker import image_uris, s3
2727
from sagemaker.session import Session
2828
from sagemaker.utils import name_from_base
29-
from sagemaker.clarify import SageMakerClarifyProcessor
29+
from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig
3030

3131
_LOGGER = logging.getLogger(__name__)
3232

@@ -833,9 +833,10 @@ def suggest_baseline(
833833
specific explainability method. Currently, only SHAP is supported.
834834
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835835
endpoint to be created.
836-
model_scores (int or str): Index or JSONPath location in the model output for the
837-
predicted scores to be explained. This is not required if the model output is
838-
a single score.
836+
model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837+
Index or JSONPath to locate the predicted scores in the model output. This is not
838+
required if the model output is a single score. Alternatively, it can be an instance
839+
of ModelPredictedLabelConfig to provide more parameters like label_headers.
839840
wait (bool): Whether the call should wait until the job completes (default: False).
840841
logs (bool): Whether to show the logs produced by the job.
841842
Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
865866
headers = copy.deepcopy(data_config.headers)
866867
if headers and data_config.label in headers:
867868
headers.remove(data_config.label)
869+
if model_scores is None:
870+
inference_attribute = None
871+
label_headers = None
872+
elif isinstance(model_scores, ModelPredictedLabelConfig):
873+
inference_attribute = str(model_scores.label)
874+
label_headers = model_scores.label_headers
875+
else:
876+
inference_attribute = str(model_scores)
877+
label_headers = None
868878
self.latest_baselining_job_config = ClarifyBaseliningConfig(
869879
analysis_config=ExplainabilityAnalysisConfig(
870880
explainability_config=explainability_config,
871881
model_config=model_config,
872882
headers=headers,
883+
label_headers=label_headers,
873884
),
874885
features_attribute=data_config.features,
875-
inference_attribute=model_scores if model_scores is None else str(model_scores),
886+
inference_attribute=inference_attribute,
876887
)
877888
self.latest_baselining_job_name = baselining_job_name
878889
self.latest_baselining_job = ClarifyBaseliningJob(
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
11661177
class ExplainabilityAnalysisConfig:
11671178
"""Analysis configuration for ModelExplainabilityMonitor."""
11681179

1169-
def __init__(self, explainability_config, model_config, headers=None):
1180+
def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
11701181
"""Creates an analysis config dictionary.
11711182
11721183
Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
11751186
model_config (sagemaker.clarify.ModelConfig): Config object related to bias
11761187
configurations.
11771188
headers (list[str]): A list of feature names (without label) of model/endpoint input.
1189+
label_headers (list[str]): List of headers, each for a predicted score in model output.
1190+
It is used to beautify the analysis report by replacing placeholders like "label0".
1191+
11781192
"""
1193+
predictor_config = model_config.get_predictor_config()
11791194
self.analysis_config = {
11801195
"methods": explainability_config.get_explainability_config(),
1181-
"predictor": model_config.get_predictor_config(),
1196+
"predictor": predictor_config,
11821197
}
11831198
if headers is not None:
11841199
self.analysis_config["headers"] = headers
1200+
if label_headers is not None:
1201+
predictor_config["label_headers"] = label_headers
11851202

11861203
def _to_dict(self):
11871204
"""Generates a request dictionary using the parameters provided to the class."""

tests/integ/test_clarify_model_monitor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
HEADER_OF_LABEL = "Label"
5454
HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"]
5555
ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL]
56+
HEADER_OF_PREDICTION = "Decision"
5657
DATASET_TYPE = "text/csv"
5758
CONTENT_TYPE = DATASET_TYPE
5859
ACCEPT_TYPE = DATASET_TYPE
@@ -325,7 +326,7 @@ def scheduled_explainability_monitor(
325326
):
326327
monitor_schedule_name = utils.unique_name_from_base("explainability-monitor")
327328
analysis_config = ExplainabilityAnalysisConfig(
328-
shap_config, model_config, headers=HEADERS_OF_FEATURES
329+
shap_config, model_config, headers=HEADERS_OF_FEATURES, label_headers=[HEADER_OF_PREDICTION]
329330
)
330331
s3_uri_monitoring_output = os.path.join(
331332
"s3://",

tests/unit/sagemaker/monitor/test_clarify_model_monitor.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@
279279
# for bias
280280
ANALYSIS_CONFIG_LABEL = "Label"
281281
ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1", "F2", "F3"]
282+
ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision"]
282283
ANALYSIS_CONFIG_ALL_HEADERS = [*ANALYSIS_CONFIG_HEADERS_OF_FEATURES, ANALYSIS_CONFIG_LABEL]
283284
ANALYSIS_CONFIG_LABEL_VALUES = [1]
284285
ANALYSIS_CONFIG_FACET_NAME = "F1"
@@ -330,6 +331,11 @@
330331
"content_type": CONTENT_TYPE,
331332
},
332333
}
334+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG)
335+
# noinspection PyTypeChecker
336+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS["predictor"][
337+
"label_headers"
338+
] = ANALYSIS_CONFIG_LABEL_HEADERS
333339

334340

335341
@pytest.fixture()
@@ -1048,25 +1054,44 @@ def test_explainability_analysis_config(shap_config, model_config):
10481054
explainability_config=shap_config,
10491055
model_config=model_config,
10501056
headers=ANALYSIS_CONFIG_HEADERS_OF_FEATURES,
1057+
label_headers=ANALYSIS_CONFIG_LABEL_HEADERS,
10511058
)
1052-
assert EXPLAINABILITY_ANALYSIS_CONFIG == config._to_dict()
1059+
assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config._to_dict()
10531060

10541061

1062+
@pytest.mark.parametrize(
1063+
"model_scores,explainability_analysis_config",
1064+
[
1065+
(INFERENCE_ATTRIBUTE, EXPLAINABILITY_ANALYSIS_CONFIG),
1066+
(
1067+
ModelPredictedLabelConfig(
1068+
label=INFERENCE_ATTRIBUTE, label_headers=ANALYSIS_CONFIG_LABEL_HEADERS
1069+
),
1070+
EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS,
1071+
),
1072+
],
1073+
)
10551074
def test_model_explainability_monitor_suggest_baseline(
1056-
model_explainability_monitor, sagemaker_session, data_config, shap_config, model_config
1075+
model_explainability_monitor,
1076+
sagemaker_session,
1077+
data_config,
1078+
shap_config,
1079+
model_config,
1080+
model_scores,
1081+
explainability_analysis_config,
10571082
):
10581083
clarify_model_monitor = model_explainability_monitor
10591084
# suggest baseline
10601085
clarify_model_monitor.suggest_baseline(
10611086
data_config=data_config,
10621087
explainability_config=shap_config,
10631088
model_config=model_config,
1064-
model_scores=INFERENCE_ATTRIBUTE,
1089+
model_scores=model_scores,
10651090
job_name=BASELINING_JOB_NAME,
10661091
)
10671092
assert isinstance(clarify_model_monitor.latest_baselining_job, ClarifyBaseliningJob)
10681093
assert (
1069-
EXPLAINABILITY_ANALYSIS_CONFIG
1094+
explainability_analysis_config
10701095
== clarify_model_monitor.latest_baselining_job_config.analysis_config._to_dict()
10711096
)
10721097
clarify_baselining_job = clarify_model_monitor.latest_baselining_job
@@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline(
10811106
analysis_config=None, # will pick up config from baselining job
10821107
baseline_job_name=BASELINING_JOB_NAME,
10831108
endpoint_input=ENDPOINT_NAME,
1109+
explainability_analysis_config=explainability_analysis_config,
10841110
# will pick up attributes from baselining job
10851111
)
10861112

@@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config(
11331159
sagemaker_session=sagemaker_session,
11341160
analysis_config=analysis_config,
11351161
constraints=CONSTRAINTS,
1162+
explainability_analysis_config=EXPLAINABILITY_ANALYSIS_CONFIG,
11361163
)
11371164

11381165
# update schedule
@@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule(
12631290
features_attribute=FEATURES_ATTRIBUTE,
12641291
inference_attribute=str(INFERENCE_ATTRIBUTE),
12651292
),
1293+
explainability_analysis_config=None,
12661294
):
12671295
# create schedule
12681296
with patch(
@@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule(
12781306
)
12791307
if not isinstance(analysis_config, str):
12801308
upload.assert_called_once()
1281-
assert json.loads(upload.call_args[0][0]) == EXPLAINABILITY_ANALYSIS_CONFIG
1309+
assert json.loads(upload.call_args[0][0]) == explainability_analysis_config
12821310

12831311
# validation
12841312
expected_arguments = {

0 commit comments

Comments
 (0)