279
279
# for bias
280
280
ANALYSIS_CONFIG_LABEL = "Label"
281
281
ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1" , "F2" , "F3" ]
282
+ ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision" ]
282
283
ANALYSIS_CONFIG_ALL_HEADERS = [* ANALYSIS_CONFIG_HEADERS_OF_FEATURES , ANALYSIS_CONFIG_LABEL ]
283
284
ANALYSIS_CONFIG_LABEL_VALUES = [1 ]
284
285
ANALYSIS_CONFIG_FACET_NAME = "F1"
330
331
"content_type" : CONTENT_TYPE ,
331
332
},
332
333
}
334
+ EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy .deepcopy (EXPLAINABILITY_ANALYSIS_CONFIG )
335
+ # noinspection PyTypeChecker
336
+ EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS ["predictor" ][
337
+ "label_headers"
338
+ ] = ANALYSIS_CONFIG_LABEL_HEADERS
333
339
334
340
335
341
@pytest .fixture ()
@@ -1048,25 +1054,44 @@ def test_explainability_analysis_config(shap_config, model_config):
1048
1054
explainability_config = shap_config ,
1049
1055
model_config = model_config ,
1050
1056
headers = ANALYSIS_CONFIG_HEADERS_OF_FEATURES ,
1057
+ label_headers = ANALYSIS_CONFIG_LABEL_HEADERS ,
1051
1058
)
1052
- assert EXPLAINABILITY_ANALYSIS_CONFIG == config ._to_dict ()
1059
+ assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config ._to_dict ()
1053
1060
1054
1061
1062
+ @pytest .mark .parametrize (
1063
+ "model_scores,explainability_analysis_config" ,
1064
+ [
1065
+ (INFERENCE_ATTRIBUTE , EXPLAINABILITY_ANALYSIS_CONFIG ),
1066
+ (
1067
+ ModelPredictedLabelConfig (
1068
+ label = INFERENCE_ATTRIBUTE , label_headers = ANALYSIS_CONFIG_LABEL_HEADERS
1069
+ ),
1070
+ EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS ,
1071
+ ),
1072
+ ],
1073
+ )
1055
1074
def test_model_explainability_monitor_suggest_baseline (
1056
- model_explainability_monitor , sagemaker_session , data_config , shap_config , model_config
1075
+ model_explainability_monitor ,
1076
+ sagemaker_session ,
1077
+ data_config ,
1078
+ shap_config ,
1079
+ model_config ,
1080
+ model_scores ,
1081
+ explainability_analysis_config ,
1057
1082
):
1058
1083
clarify_model_monitor = model_explainability_monitor
1059
1084
# suggest baseline
1060
1085
clarify_model_monitor .suggest_baseline (
1061
1086
data_config = data_config ,
1062
1087
explainability_config = shap_config ,
1063
1088
model_config = model_config ,
1064
- model_scores = INFERENCE_ATTRIBUTE ,
1089
+ model_scores = model_scores ,
1065
1090
job_name = BASELINING_JOB_NAME ,
1066
1091
)
1067
1092
assert isinstance (clarify_model_monitor .latest_baselining_job , ClarifyBaseliningJob )
1068
1093
assert (
1069
- EXPLAINABILITY_ANALYSIS_CONFIG
1094
+ explainability_analysis_config
1070
1095
== clarify_model_monitor .latest_baselining_job_config .analysis_config ._to_dict ()
1071
1096
)
1072
1097
clarify_baselining_job = clarify_model_monitor .latest_baselining_job
@@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline(
1081
1106
analysis_config = None , # will pick up config from baselining job
1082
1107
baseline_job_name = BASELINING_JOB_NAME ,
1083
1108
endpoint_input = ENDPOINT_NAME ,
1109
+ explainability_analysis_config = explainability_analysis_config ,
1084
1110
# will pick up attributes from baselining job
1085
1111
)
1086
1112
@@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config(
1133
1159
sagemaker_session = sagemaker_session ,
1134
1160
analysis_config = analysis_config ,
1135
1161
constraints = CONSTRAINTS ,
1162
+ explainability_analysis_config = EXPLAINABILITY_ANALYSIS_CONFIG ,
1136
1163
)
1137
1164
1138
1165
# update schedule
@@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule(
1263
1290
features_attribute = FEATURES_ATTRIBUTE ,
1264
1291
inference_attribute = str (INFERENCE_ATTRIBUTE ),
1265
1292
),
1293
+ explainability_analysis_config = None ,
1266
1294
):
1267
1295
# create schedule
1268
1296
with patch (
@@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule(
1278
1306
)
1279
1307
if not isinstance (analysis_config , str ):
1280
1308
upload .assert_called_once ()
1281
- assert json .loads (upload .call_args [0 ][0 ]) == EXPLAINABILITY_ANALYSIS_CONFIG
1309
+ assert json .loads (upload .call_args [0 ][0 ]) == explainability_analysis_config
1282
1310
1283
1311
# validation
1284
1312
expected_arguments = {
0 commit comments