diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 7f00a78268..e433d6e104 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -25,7 +25,7 @@
 import tempfile
 
 from abc import ABC, abstractmethod
-from typing import List, Union
+from typing import List, Union, Dict
 
 from sagemaker import image_uris, s3, utils
 from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
@@ -172,7 +172,11 @@ def __init__(
         _set(joinsource, "joinsource_name_or_index", self.analysis_config)
         _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
         _set(facet_headers, "facet_headers", self.analysis_config)
-        _set(predicted_label_dataset_uri, "predicted_label_dataset_uri", self.analysis_config)
+        _set(
+            predicted_label_dataset_uri,
+            "predicted_label_dataset_uri",
+            self.analysis_config,
+        )
         _set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
         _set(predicted_label, "predicted_label", self.analysis_config)
         _set(excluded_columns, "excluded_columns", self.analysis_config)
@@ -491,7 +495,10 @@ def __init__(self, features=None, grid_resolution=15, top_k_features=10):
             top_k_features (int): Sets the number of top SHAP attributes used to compute
                 partial dependence plots.
         """  # noqa E501
-        self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
+        self.pdp_config = {
+            "grid_resolution": grid_resolution,
+            "top_k_features": top_k_features,
+        }
         if features is not None:
             self.pdp_config["features"] = features
 
@@ -824,7 +831,11 @@ def __init__(
             image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image features.
                 Default is None.
         """  # noqa E501  # pylint: disable=c0301
-        if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
+        if agg_method is not None and agg_method not in [
+            "mean_abs",
+            "median",
+            "mean_sq",
+        ]:
             raise ValueError(
                 f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq."
             )
@@ -1167,7 +1178,11 @@ def run_post_training_bias(
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_post_training(
-            data_config, data_bias_config, model_predicted_label_config, methods, model_config
+            data_config,
+            data_bias_config,
+            model_predicted_label_config,
+            methods,
+            model_config,
         )
         # when name is either not provided (is None) or an empty string ("")
         job_name = job_name or utils.name_from_base(
@@ -1368,68 +1383,198 @@ def run_explainability(
             experiment_config,
         )
 
+    def run_bias_and_explainability(
+        self,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+        model_predicted_label_config: ModelPredictedLabelConfig = None,
+        wait=True,
+        logs=True,
+        job_name=None,
+        kms_key=None,
+        experiment_config=None,
+    ):
+        """Runs a :class:`~sagemaker.processing.ProcessingJob` computing bias metrics and feature attributions.
+
+        For bias:
+        Computes metrics for both the pre-training and the post-training methods.
+        To calculate post-training metrics, it spins up a model endpoint and runs inference over the
+        input examples in ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`)
+        to obtain predicted labels.
+
+        For Explainability:
+        Spins up a model endpoint.
+
+        Currently, only SHAP and Partial Dependence Plots (PDP) are supported
+        as explainability methods.
+        You can request both methods or one at a time with the
+        ``explainability_config`` parameter.
+
+        When SHAP is requested in the ``explainability_config``,
+        the SHAP algorithm calculates the feature importance for each input example
+        in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
+        by creating ``num_samples`` copies of the example with a subset of features
+        replaced with values from the ``baseline``.
+        It then runs model inference to see how the model's prediction changes with the replaced
+        features. If the model output returns multiple scores, importance is computed for each score.
+        Across examples, feature importance is aggregated using ``agg_method``.
+
+        When PDP is requested in the ``explainability_config``,
+        the PDP algorithm calculates the dependence of the target response
+        on the input features and marginalizes over the values of all other input features.
+        The Partial Dependence Plots are included in the output
+        `report `__
+        and the corresponding values are included in the analysis output.
+
+        Args:
+            data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
+            model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
+                endpoint to be created.
+            explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
+                Config of the specific explainability method or a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
+                Currently, SHAP and PDP are the two methods supported.
+                You can request multiple methods at once by passing in a list of
+                :class:`~sagemaker.clarify.ExplainabilityConfig`.
+            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`CI `_",
+                "`DPL `_",
+                "`KL `_",
+                "`JS `_",
+                "`LP `_",
+                "`TVD `_",
+                "`KS `_",
+                "`CDDL `_"].
+                Defaults to ``"all"`` to run all metrics if left unspecified.
+            post_training_methods (str or list[str]): Selector of a subset of potential metrics:
+                ["`DPPL `_",
+                "`DI `_",
+                "`DCA `_",
+                "`DCR `_",
+                "`RD `_",
+                "`DAR `_",
+                "`DRR `_",
+                "`AD `_",
+                "`CDDPL `_",
+                "`TE `_",
+                "`FT `_"].
+                Defaults to ``"all"`` to run all metrics if left unspecified.
+            model_predicted_label_config (
+                int or
+                str or
+                :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
+            ):
+                Index or JSONPath to locate the predicted scores in the model output. This is not
+                required if the model output is a single score. Alternatively, it can be an instance
+                of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
+                to provide more parameters like ``label_headers``.
+            wait (bool): Whether the call should wait until the job completes (default: True).
+            logs (bool): Whether to show the logs produced by the job.
+                Only meaningful when ``wait`` is True (default: True).
+            job_name (str): Processing job name. When ``job_name`` is not specified,
+                if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
+                is specified, the job name will be composed of ``job_name_prefix`` and current
+                timestamp; otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
+            kms_key (str): The ARN of the KMS key that is used to encrypt the
+                user code file (default: None).
+            experiment_config (dict[str, str]): Experiment management configuration.
+ Optionally, the dict can contain three keys: + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + + The behavior of setting these keys is as follows: + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 + analysis_config = _AnalysisConfigGenerator.bias_and_explainability( + data_config, + model_config, + model_predicted_label_config, + explainability_config, + bias_config, + pre_training_methods, + post_training_methods, + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Bias-And-Explainability" + ) + return self._run( + data_config, + analysis_config, + wait, + logs, + job_name, + kms_key, + experiment_config, + ) + class _AnalysisConfigGenerator: """Creates analysis_config objects for different type of runs.""" + @classmethod + def bias_and_explainability( + cls, + data_config: DataConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], + bias_config: BiasConfig, + pre_training_methods: Union[str, List[str]] = "all", + post_training_methods: Union[str, List[str]] = "all", + ): + """Generates a config for Bias and Explainability""" + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods( + analysis_config, + pre_training_methods=pre_training_methods, + post_training_methods=post_training_methods, + explainability_config=explainability_config, + ) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config + @classmethod def explainability( cls, data_config: DataConfig, model_config: ModelConfig, - model_scores: ModelPredictedLabelConfig, - explainability_config: ExplainabilityConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Generates a config for Explainability""" - analysis_config = data_config.get_config() - predictor_config = model_config.get_predictor_config() - if isinstance(model_scores, ModelPredictedLabelConfig): - ( - probability_threshold, - predicted_label_config, - ) = model_scores.get_predictor_config() - _set(probability_threshold, "probability_threshold", analysis_config) - predictor_config.update(predicted_label_config) - else: - _set(model_scores, "label", predictor_config) - - explainability_methods = {} - if isinstance(explainability_config, list): - if len(explainability_config) == 0: - raise ValueError("Please provide at least one explainability config.") - for config in explainability_config: - explain_config = config.get_explainability_config() - explainability_methods.update(explain_config) - if not len(explainability_methods.keys()) == len(explainability_config): - raise ValueError("Duplicate explainability configs are provided") - if ( - "shap" not in explainability_methods - and 
explainability_methods["pdp"].get("features", None) is None - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") - else: - if ( - isinstance(explainability_config, PDPConfig) - and explainability_config.get_explainability_config()["pdp"].get("features", None) - is None - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") - explainability_methods = explainability_config.get_explainability_config() - analysis_config["methods"] = explainability_methods - analysis_config["predictor"] = predictor_config - return cls._common(analysis_config) + analysis_config = data_config.analysis_config + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + analysis_config = cls._add_methods( + analysis_config, explainability_config=explainability_config + ) + return analysis_config @classmethod def bias_pre_training( - cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] + cls, + data_config: DataConfig, + bias_config: BiasConfig, + methods: Union[str, List[str]], ): """Generates a config for Bias Pre Training""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "methods": {"pre_training_bias": {"methods": methods}}, - } - return cls._common(analysis_config) + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods) + return analysis_config @classmethod def bias_post_training( @@ -1441,21 +1586,12 @@ def bias_post_training( model_config: ModelConfig, ): """Generates a config for Bias Post Training""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "predictor": {**model_config.get_predictor_config()}, - "methods": {"post_training_bias": {"methods": methods}}, - } - if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - if predictor_config: - analysis_config["predictor"].update(predictor_config) - _set(probability_threshold, "probability_threshold", analysis_config) - return cls._common(analysis_config) + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods(analysis_config, post_training_methods=methods) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config @classmethod def bias( @@ -1468,16 +1604,28 @@ def bias( post_training_methods: Union[str, List[str]] = "all", ): """Generates a config for Bias""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "predictor": model_config.get_predictor_config(), - "methods": { - "pre_training_bias": {"methods": pre_training_methods}, - "post_training_bias": {"methods": post_training_methods}, - }, - } - if model_predicted_label_config: + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods( + analysis_config, + pre_training_methods=pre_training_methods, + post_training_methods=post_training_methods, + ) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config + + @classmethod + def _add_predictor( + cls, + analysis_config: Dict, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + ): + """Extends analysis config with predictor.""" 
+        analysis_config = {**analysis_config}
+        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
@@ -1485,17 +1633,82 @@
             ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
-
-    @staticmethod
-    def _common(analysis_config):
-        """Extends analysis config with common values"""
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
+        else:
+            _set(model_predicted_label_config, "label", analysis_config["predictor"])
+        return analysis_config
+
+    @classmethod
+    def _add_methods(
+        cls,
+        analysis_config: Dict,
+        pre_training_methods: Union[str, List[str]] = None,
+        post_training_methods: Union[str, List[str]] = None,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None,
+        report=True,
+    ):
+        """Extends analysis config with methods."""
+        # validate
+        params = [pre_training_methods, post_training_methods, explainability_config]
+        if not any(params):
+            raise AttributeError(
+                "analysis_config must have at least one working method: "
+                "One of the "
+                "`pre_training_methods`, `post_training_methods`, `explainability_config`."
+            )
+
+        # main logic
+        analysis_config = {**analysis_config}
+        if "methods" not in analysis_config:
+            analysis_config["methods"] = {}
+
+        if report:
+            analysis_config["methods"]["report"] = {
+                "name": "report",
+                "title": "Analysis Report",
+            }
+
+        if pre_training_methods:
+            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
+
+        if post_training_methods:
+            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
+
+        if explainability_config is not None:
+            explainability_methods = cls._merge_explainability_configs(explainability_config)
+            analysis_config["methods"] = {
+                **analysis_config["methods"],
+                **explainability_methods,
+            }
         return analysis_config
 
+    @classmethod
+    def _merge_explainability_configs(
+        cls,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+    ):
+        """Merges explainability configs when more than one is provided."""
+        if isinstance(explainability_config, list):
+            explainability_methods = {}
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if len(explainability_methods) != len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and "features" not in explainability_methods["pdp"]
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            return explainability_methods
+        if (
+            isinstance(explainability_config, PDPConfig)
+            and "features" not in explainability_config.get_explainability_config()["pdp"]
+        ):
+            raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        return explainability_config.get_explainability_config()
+
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
     """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
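
For reference, a minimal sketch of how the new pieces fit together. It mirrors test_analysis_config_generator_for_bias_explainability in the unit tests below; the S3 URIs, dataset headers, and model name are placeholders, and the private _AnalysisConfigGenerator is called directly only so the composed config can be inspected without launching a job:

from sagemaker.clarify import (
    BiasConfig,
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    PDPConfig,
    SHAPConfig,
    _AnalysisConfigGenerator,
)

# Dataset location, output location, and label column (placeholder values).
data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/input/train.csv",
    s3_output_path="s3://my-bucket/clarify-output",
    label="Label",
    headers=["Label", "F1", "F2", "F3", "F4"],
    dataset_type="text/csv",
)

# Sensitive group: facet column F1, positive label value 1, group variable F2.
bias_config = BiasConfig(
    label_values_or_threshold=[1],
    facet_name="F1",
    group_name="F2",
)

# Shadow endpoint that Clarify spins up to obtain predictions (placeholder model).
model_config = ModelConfig(
    model_name="xgboost-model",
    instance_count=1,
    instance_type="ml.c5.xlarge",
)

# Where the predicted scores live in the model output, plus display headers.
model_predicted_label_config = ModelPredictedLabelConfig(
    probability="pr",
    label_headers=["success"],
)

# One analysis config covering SHAP, PDP, and both bias method groups;
# _add_methods also injects the default "report" entry.
analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
    data_config,
    model_config,
    model_predicted_label_config,
    [SHAPConfig(), PDPConfig()],
    bias_config,
    pre_training_methods="all",
    post_training_methods="all",
)
assert set(analysis_config["methods"]) == {
    "shap", "pdp", "pre_training_bias", "post_training_bias", "report"
}

In real use, SageMakerClarifyProcessor.run_bias_and_explainability builds this config internally and launches the processing job; the integ test test_bias_and_explainability below shows that flow end to end.
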
diff --git a/tests/integ/test_clarify.py b/tests/integ/test_clarify.py index a107c00859..eaa75bce64 100644 --- a/tests/integ/test_clarify.py +++ b/tests/integ/test_clarify.py @@ -138,7 +138,9 @@ def data_path_no_label_index(training_set_no_label): def data_path_label_index(training_set_label_index): features, label, index = training_set_label_index data = pd.concat( - [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], axis=1, sort=False + [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], + axis=1, + sort=False, ) with tempfile.TemporaryDirectory() as tmpdirname: filename = os.path.join(tmpdirname, "train_label_index.csv") @@ -151,7 +153,12 @@ def data_path_label_index(training_set_label_index): def data_path_label_index_6col(training_set_label_index): features, label, index = training_set_label_index data = pd.concat( - [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(features), pd.DataFrame(index)], + [ + pd.DataFrame(label), + pd.DataFrame(features), + pd.DataFrame(features), + pd.DataFrame(index), + ], axis=1, sort=False, ) @@ -551,7 +558,10 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag def test_pre_training_bias_facets_not_included( - clarify_processor, data_config_facets_not_included, data_bias_config, sagemaker_session + clarify_processor, + data_config_facets_not_included, + data_bias_config, + sagemaker_session, ): with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): clarify_processor.run_pre_training_bias( @@ -643,7 +653,9 @@ def test_post_training_bias_facets_not_included_excluded_columns( <= 1.0 ) check_analysis_config( - data_config_facets_not_included_multiple_files, sagemaker_session, "post_training_bias" + data_config_facets_not_included_multiple_files, + sagemaker_session, + "post_training_bias", ) @@ -704,6 +716,50 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak check_analysis_config(data_config, sagemaker_session, "shap") +def test_bias_and_explainability( + clarify_processor, + data_config, + model_config, + shap_config, + data_bias_config, + sagemaker_session, +): + with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): + clarify_processor.run_bias_and_explainability( + data_config, + model_config, + shap_config, + data_bias_config, + pre_training_methods="all", + post_training_methods="all", + model_predicted_label_config="score", + job_name=utils.unique_name_from_base("clarify-bias-and-explainability"), + wait=True, + ) + analysis_result_json = s3.S3Downloader.read_file( + data_config.s3_output_path + "/analysis.json", + sagemaker_session, + ) + analysis_result = json.loads(analysis_result_json) + assert ( + math.fabs( + analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"] + ) + <= 1 + ) + check_analysis_config(data_config, sagemaker_session, "shap") + + assert ( + math.fabs( + analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][ + "value" + ] + ) + <= 1.0 + ) + check_analysis_config(data_config, sagemaker_session, "post_training_bias") + + def check_analysis_config(data_config, sagemaker_session, method): analysis_config_json = s3.S3Downloader.read_file( data_config.s3_output_path + "/analysis_config.json", diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 7375657944..00bf036b5a 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -232,7 +232,8 @@ def test_invalid_bias_config(): # Two facets but only one value with 
pytest.raises( - ValueError, match="The number of facet names doesn't match the number of facet values" + ValueError, + match="The number of facet names doesn't match the number of facet values", ): BiasConfig( label_values_or_threshold=[1], @@ -295,7 +296,10 @@ def test_invalid_bias_config(): { "facet": [ {"name_or_index": "Feature1", "value_or_threshold": [1]}, - {"name_or_index": 1, "value_or_threshold": ["category1, category2"]}, + { + "name_or_index": 1, + "value_or_threshold": ["category1, category2"], + }, {"name_or_index": "Feature3", "value_or_threshold": [0.5]}, ], }, @@ -1094,7 +1098,9 @@ def test_explainability_with_invalid_config( "initial_instance_count": 1, } with pytest.raises( - AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'" + AttributeError, + match="analysis_config must have at least one working method: " + "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`.", ): _run_test_explain( name_from_base, @@ -1320,6 +1326,86 @@ def test_analysis_config_generator_for_explainability(data_config, model_config) assert actual == expected +def test_analysis_config_generator_for_explainability_failing(data_config, model_config): + model_scores = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + with pytest.raises( + ValueError, + match="PDP features must be provided when ShapConfig is not provided", + ): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + PDPConfig(), + ) + + with pytest.raises(ValueError, match="Duplicate explainability configs are provided"): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + [SHAPConfig(), SHAPConfig()], + ) + + with pytest.raises( + AttributeError, + match="analysis_config must have at least one working method: " + "One of the " + "`pre_training_methods`, `post_training_methods`, `explainability_config`.", + ): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + [], + ) + + +def test_analysis_config_generator_for_bias_explainability( + data_config, data_bias_config, model_config +): + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias_and_explainability( + data_config, + model_config, + model_predicted_label_config, + [SHAPConfig(), PDPConfig()], + data_bias_config, + pre_training_methods="all", + post_training_methods="all", + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "pdp": {"grid_resolution": 15, "top_k_features": 10}, + "post_training_bias": {"methods": "all"}, + "pre_training_bias": {"methods": "all"}, + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config): actual = _AnalysisConfigGenerator.bias_pre_training( data_config, data_bias_config, methods="all"