26
26
from sagemaker import image_uris , s3
27
27
from sagemaker .session import Session
28
28
from sagemaker .utils import name_from_base
29
- from sagemaker .clarify import SageMakerClarifyProcessor
29
+ from sagemaker .clarify import SageMakerClarifyProcessor , ModelPredictedLabelConfig
30
30
31
31
_LOGGER = logging .getLogger (__name__ )
32
32
@@ -833,9 +833,10 @@ def suggest_baseline(
833
833
specific explainability method. Currently, only SHAP is supported.
834
834
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835
835
endpoint to be created.
836
- model_scores (int or str): Index or JSONPath location in the model output for the
837
- predicted scores to be explained. This is not required if the model output is
838
- a single score.
836
+ model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837
+ Index or JSONPath to locate the predicted scores in the model output. This is not
838
+ required if the model output is a single score. Alternatively, it can be an instance
839
+ of ModelPredictedLabelConfig to provide more parameters like label_headers.
839
840
wait (bool): Whether the call should wait until the job completes (default: False).
840
841
logs (bool): Whether to show the logs produced by the job.
841
842
Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
865
866
headers = copy .deepcopy (data_config .headers )
866
867
if headers and data_config .label in headers :
867
868
headers .remove (data_config .label )
869
+ if model_scores is None :
870
+ inference_attribute = None
871
+ label_headers = None
872
+ elif isinstance (model_scores , ModelPredictedLabelConfig ):
873
+ inference_attribute = str (model_scores .label )
874
+ label_headers = model_scores .label_headers
875
+ else :
876
+ inference_attribute = str (model_scores )
877
+ label_headers = None
868
878
self .latest_baselining_job_config = ClarifyBaseliningConfig (
869
879
analysis_config = ExplainabilityAnalysisConfig (
870
880
explainability_config = explainability_config ,
871
881
model_config = model_config ,
872
882
headers = headers ,
883
+ label_headers = label_headers ,
873
884
),
874
885
features_attribute = data_config .features ,
875
- inference_attribute = model_scores if model_scores is None else str ( model_scores ) ,
886
+ inference_attribute = inference_attribute ,
876
887
)
877
888
self .latest_baselining_job_name = baselining_job_name
878
889
self .latest_baselining_job = ClarifyBaseliningJob (
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
1166
1177
class ExplainabilityAnalysisConfig :
1167
1178
"""Analysis configuration for ModelExplainabilityMonitor."""
1168
1179
1169
- def __init__ (self , explainability_config , model_config , headers = None ):
1180
+ def __init__ (self , explainability_config , model_config , headers = None , label_headers = None ):
1170
1181
"""Creates an analysis config dictionary.
1171
1182
1172
1183
Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
1175
1186
model_config (sagemaker.clarify.ModelConfig): Config object related to bias
1176
1187
configurations.
1177
1188
headers (list[str]): A list of feature names (without label) of model/endpint input.
1189
+ label_headers (list[str]): List of headers, each for a predicted score in model output.
1190
+ It is used to beautify the analysis report by replacing placeholders like "label0".
1191
+
1178
1192
"""
1193
+ predictor_config = model_config .get_predictor_config ()
1179
1194
self .analysis_config = {
1180
1195
"methods" : explainability_config .get_explainability_config (),
1181
- "predictor" : model_config . get_predictor_config () ,
1196
+ "predictor" : predictor_config ,
1182
1197
}
1183
1198
if headers is not None :
1184
1199
self .analysis_config ["headers" ] = headers
1200
+ if label_headers is not None :
1201
+ predictor_config ["label_headers" ] = label_headers
1185
1202
1186
1203
def _to_dict (self ):
1187
1204
"""Generates a request dictionary using the parameters provided to the class."""
0 commit comments