diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 0829e25f4b..006cc4846c 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -290,11 +290,15 @@ def __init__(
             probability_threshold (float): An optional value for binary prediction tasks in which
                 the model returns a probability, to indicate the threshold to convert the
                 prediction to a boolean value. Default is 0.5.
-            label_headers (list): List of label values - one for each score of the ``probability``.
+            label_headers (list[str]): List of headers, each for a predicted score in model output.
+                For bias analysis, it is used to extract the label value with the highest score as
+                the predicted label. For an explainability job, it is used to beautify the analysis
+                report by replacing placeholders like "label0".
         """
         self.label = label
         self.probability = probability
         self.probability_threshold = probability_threshold
+        self.label_headers = label_headers
         if probability_threshold is not None:
             try:
                 float(probability_threshold)
@@ -1060,10 +1064,10 @@ def run_explainability(
             explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
                 Config of the specific explainability method or a list of ExplainabilityConfig
                 objects. Currently, SHAP and PDP are the two methods supported.
-            model_scores(str|int|ModelPredictedLabelConfig): Index or JSONPath location in the
-                model output for the predicted scores to be explained. This is not required if the
-                model output is a single score. Alternatively, an instance of
-                ModelPredictedLabelConfig can be provided.
+            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JSONPath to locate the predicted scores in the model output. This is not
+                required if the model output is a single score. Alternatively, it can be an instance
+                of ModelPredictedLabelConfig to provide more parameters like label_headers.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when ``wait`` is True (default: True).
diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py
index 10da0bf6c9..09de7b5c05 100644
--- a/src/sagemaker/model_monitor/clarify_model_monitoring.py
+++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py
@@ -26,7 +26,7 @@
 from sagemaker import image_uris, s3
 from sagemaker.session import Session
 from sagemaker.utils import name_from_base
-from sagemaker.clarify import SageMakerClarifyProcessor
+from sagemaker.clarify import SageMakerClarifyProcessor, ModelPredictedLabelConfig

 _LOGGER = logging.getLogger(__name__)

@@ -833,9 +833,10 @@ def suggest_baseline(
                 specific explainability method. Currently, only SHAP is supported.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                 endpoint to be created.
-            model_scores (int or str): Index or JSONPath location in the model output for the
-                predicted scores to be explained. This is not required if the model output is
-                a single score.
+            model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
+                Index or JSONPath to locate the predicted scores in the model output. This is not
+                required if the model output is a single score. Alternatively, it can be an instance
+                of ModelPredictedLabelConfig to provide more parameters like label_headers.
             wait (bool): Whether the call should wait until the job completes (default: False).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
         headers = copy.deepcopy(data_config.headers)
         if headers and data_config.label in headers:
             headers.remove(data_config.label)
+        if model_scores is None:
+            inference_attribute = None
+            label_headers = None
+        elif isinstance(model_scores, ModelPredictedLabelConfig):
+            inference_attribute = str(model_scores.label)
+            label_headers = model_scores.label_headers
+        else:
+            inference_attribute = str(model_scores)
+            label_headers = None
         self.latest_baselining_job_config = ClarifyBaseliningConfig(
             analysis_config=ExplainabilityAnalysisConfig(
                 explainability_config=explainability_config,
                 model_config=model_config,
                 headers=headers,
+                label_headers=label_headers,
             ),
             features_attribute=data_config.features,
-            inference_attribute=model_scores if model_scores is None else str(model_scores),
+            inference_attribute=inference_attribute,
         )
         self.latest_baselining_job_name = baselining_job_name
         self.latest_baselining_job = ClarifyBaseliningJob(
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
 class ExplainabilityAnalysisConfig:
     """Analysis configuration for ModelExplainabilityMonitor."""

-    def __init__(self, explainability_config, model_config, headers=None):
+    def __init__(self, explainability_config, model_config, headers=None, label_headers=None):
         """Creates an analysis config dictionary.

         Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
             model_config (sagemaker.clarify.ModelConfig): Config object related to bias
                 configurations.
             headers (list[str]): A list of feature names (without label) of model/endpint input.
+            label_headers (list[str]): List of headers, each for a predicted score in model output.
+                It is used to beautify the analysis report by replacing placeholders like "label0".
+ """ + predictor_config = model_config.get_predictor_config() self.analysis_config = { "methods": explainability_config.get_explainability_config(), - "predictor": model_config.get_predictor_config(), + "predictor": predictor_config, } if headers is not None: self.analysis_config["headers"] = headers + if label_headers is not None: + predictor_config["label_headers"] = label_headers def _to_dict(self): """Generates a request dictionary using the parameters provided to the class.""" diff --git a/tests/integ/test_clarify_model_monitor.py b/tests/integ/test_clarify_model_monitor.py index 6891082285..3f48fa1032 100644 --- a/tests/integ/test_clarify_model_monitor.py +++ b/tests/integ/test_clarify_model_monitor.py @@ -53,6 +53,7 @@ HEADER_OF_LABEL = "Label" HEADERS_OF_FEATURES = ["F1", "F2", "F3", "F4", "F5", "F6", "F7"] ALL_HEADERS = [*HEADERS_OF_FEATURES, HEADER_OF_LABEL] +HEADER_OF_PREDICTION = "Decision" DATASET_TYPE = "text/csv" CONTENT_TYPE = DATASET_TYPE ACCEPT_TYPE = DATASET_TYPE @@ -325,7 +326,7 @@ def scheduled_explainability_monitor( ): monitor_schedule_name = utils.unique_name_from_base("explainability-monitor") analysis_config = ExplainabilityAnalysisConfig( - shap_config, model_config, headers=HEADERS_OF_FEATURES + shap_config, model_config, headers=HEADERS_OF_FEATURES, label_headers=[HEADER_OF_PREDICTION] ) s3_uri_monitoring_output = os.path.join( "s3://", diff --git a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py index e13755f208..7c1d497d64 100644 --- a/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py +++ b/tests/unit/sagemaker/monitor/test_clarify_model_monitor.py @@ -279,6 +279,7 @@ # for bias ANALYSIS_CONFIG_LABEL = "Label" ANALYSIS_CONFIG_HEADERS_OF_FEATURES = ["F1", "F2", "F3"] +ANALYSIS_CONFIG_LABEL_HEADERS = ["Decision"] ANALYSIS_CONFIG_ALL_HEADERS = [*ANALYSIS_CONFIG_HEADERS_OF_FEATURES, ANALYSIS_CONFIG_LABEL] ANALYSIS_CONFIG_LABEL_VALUES = [1] ANALYSIS_CONFIG_FACET_NAME = "F1" @@ -330,6 +331,11 @@ "content_type": CONTENT_TYPE, }, } +EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS = copy.deepcopy(EXPLAINABILITY_ANALYSIS_CONFIG) +# noinspection PyTypeChecker +EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS["predictor"][ + "label_headers" +] = ANALYSIS_CONFIG_LABEL_HEADERS @pytest.fixture() @@ -1048,12 +1054,31 @@ def test_explainability_analysis_config(shap_config, model_config): explainability_config=shap_config, model_config=model_config, headers=ANALYSIS_CONFIG_HEADERS_OF_FEATURES, + label_headers=ANALYSIS_CONFIG_LABEL_HEADERS, ) - assert EXPLAINABILITY_ANALYSIS_CONFIG == config._to_dict() + assert EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS == config._to_dict() +@pytest.mark.parametrize( + "model_scores,explainability_analysis_config", + [ + (INFERENCE_ATTRIBUTE, EXPLAINABILITY_ANALYSIS_CONFIG), + ( + ModelPredictedLabelConfig( + label=INFERENCE_ATTRIBUTE, label_headers=ANALYSIS_CONFIG_LABEL_HEADERS + ), + EXPLAINABILITY_ANALYSIS_CONFIG_WITH_LABEL_HEADERS, + ), + ], +) def test_model_explainability_monitor_suggest_baseline( - model_explainability_monitor, sagemaker_session, data_config, shap_config, model_config + model_explainability_monitor, + sagemaker_session, + data_config, + shap_config, + model_config, + model_scores, + explainability_analysis_config, ): clarify_model_monitor = model_explainability_monitor # suggest baseline @@ -1061,12 +1086,12 @@ def test_model_explainability_monitor_suggest_baseline( data_config=data_config, explainability_config=shap_config, 
         model_config=model_config,
-        model_scores=INFERENCE_ATTRIBUTE,
+        model_scores=model_scores,
         job_name=BASELINING_JOB_NAME,
     )
     assert isinstance(clarify_model_monitor.latest_baselining_job, ClarifyBaseliningJob)
     assert (
-        EXPLAINABILITY_ANALYSIS_CONFIG
+        explainability_analysis_config
         == clarify_model_monitor.latest_baselining_job_config.analysis_config._to_dict()
     )
     clarify_baselining_job = clarify_model_monitor.latest_baselining_job
@@ -1081,6 +1106,7 @@ def test_model_explainability_monitor_suggest_baseline(
         analysis_config=None,  # will pick up config from baselining job
         baseline_job_name=BASELINING_JOB_NAME,
         endpoint_input=ENDPOINT_NAME,
+        explainability_analysis_config=explainability_analysis_config,
         # will pick up attributes from baselining job
     )

@@ -1133,6 +1159,7 @@ def test_model_explainability_monitor_created_with_config(
         sagemaker_session=sagemaker_session,
         analysis_config=analysis_config,
         constraints=CONSTRAINTS,
+        explainability_analysis_config=EXPLAINABILITY_ANALYSIS_CONFIG,
     )

     # update schedule
@@ -1263,6 +1290,7 @@ def _test_model_explainability_monitor_create_schedule(
         features_attribute=FEATURES_ATTRIBUTE,
         inference_attribute=str(INFERENCE_ATTRIBUTE),
     ),
+    explainability_analysis_config=None,
 ):
     # create schedule
     with patch(
@@ -1278,7 +1306,7 @@ def _test_model_explainability_monitor_create_schedule(
     )
     if not isinstance(analysis_config, str):
         upload.assert_called_once()
-        assert json.loads(upload.call_args[0][0]) == EXPLAINABILITY_ANALYSIS_CONFIG
+        assert json.loads(upload.call_args[0][0]) == explainability_analysis_config

     # validation
     expected_arguments = {
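Note (not part of the diff): a minimal usage sketch of the `label_headers` plumbing added above. The role ARN, S3 URIs, model name, headers, and baseline values below are placeholders invented for illustration; only the `ModelPredictedLabelConfig` / `suggest_baseline` wiring reflects the behavior introduced by this change.

```python
from sagemaker.clarify import (
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    SHAPConfig,
)
from sagemaker.model_monitor import ModelExplainabilityMonitor

# Placeholder resources for illustration only.
monitor = ModelExplainabilityMonitor(
    role="arn:aws:iam::111122223333:role/ClarifyMonitorRole",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/baselining/dataset.csv",
    s3_output_path="s3://my-bucket/baselining/output",
    label="Label",
    headers=["F1", "F2", "F3", "Label"],
    dataset_type="text/csv",
)
model_config = ModelConfig(
    model_name="my-model",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    content_type="text/csv",
    accept_type="text/csv",
)
shap_config = SHAPConfig(
    baseline=[[0.1, 0.2, 0.3]],
    num_samples=100,
    agg_method="mean_abs",
)

# New in this change: model_scores may be a ModelPredictedLabelConfig carrying
# label_headers, which is written into the "predictor" section of the analysis config.
model_scores = ModelPredictedLabelConfig(label=0, label_headers=["Decision"])

monitor.suggest_baseline(
    data_config=data_config,
    explainability_config=shap_config,
    model_config=model_config,
    model_scores=model_scores,
)
```

With a plain `int` or `str` for `model_scores`, behavior is unchanged; passing a `ModelPredictedLabelConfig` additionally propagates `label_headers` into the generated analysis config's `predictor` section, as exercised by the unit tests above.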