diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 6590d30514..3bc2071330 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -25,6 +25,8 @@
 import tempfile
 
 from abc import ABC, abstractmethod
 
+from typing import List, Union
+
 from sagemaker import image_uris, s3, utils
 from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
@@ -922,6 +924,7 @@ def __init__(
             version (str): Clarify version to use.
         """  # noqa E501  # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
+        self._last_analysis_config = None
         self.job_name_prefix = job_name_prefix
         super(SageMakerClarifyProcessor, self).__init__(
             role,
@@ -983,10 +986,10 @@ def _run(
                 the Trial Component will be unassociated.
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
+        # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
+        self._last_analysis_config = analysis_config
+        logger.info("Analysis Config: %s", analysis_config)
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
             with open(analysis_config_file, "w") as f:
@@ -1083,14 +1086,13 @@ def run_pre_training_bias(
                 the Trial Component will be unassociated.
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
-        analysis_config = data_config.get_config()
-        analysis_config.update(data_bias_config.get_config())
-        analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Pretraining-Bias")
+        analysis_config = _AnalysisConfigGenerator.bias_pre_training(
+            data_config, data_bias_config, methods
+        )
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Pretraining-Bias"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1165,21 +1167,13 @@ def run_post_training_bias(
                 the Trial Component will be unassociated.
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(data_bias_config.get_config()) - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - predictor_config.update(model_config.get_predictor_config()) - analysis_config["methods"] = {"post_training_bias": {"methods": methods}} - analysis_config["predictor"] = predictor_config - _set(probability_threshold, "probability_threshold", analysis_config) - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Posttraining-Bias") + analysis_config = _AnalysisConfigGenerator.bias_post_training( + data_config, data_bias_config, model_predicted_label_config, methods, model_config + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Posttraining-Bias" + ) return self._run( data_config, analysis_config, @@ -1264,28 +1258,16 @@ def run_bias( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["predictor"] = model_config.get_predictor_config() - if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - if predictor_config: - analysis_config["predictor"].update(predictor_config) - if probability_threshold is not None: - analysis_config["probability_threshold"] = probability_threshold - - analysis_config["methods"] = { - "pre_training_bias": {"methods": pre_training_methods}, - "post_training_bias": {"methods": post_training_methods}, - } - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Bias") + analysis_config = _AnalysisConfigGenerator.bias( + data_config, + bias_config, + model_config, + model_predicted_label_config, + pre_training_methods, + post_training_methods, + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias") return self._run( data_config, analysis_config, @@ -1370,6 +1352,36 @@ def run_explainability( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. 
""" # noqa E501 # pylint: disable=c0301 + analysis_config = _AnalysisConfigGenerator.explainability( + data_config, model_config, model_scores, explainability_config + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Explainability" + ) + return self._run( + data_config, + analysis_config, + wait, + logs, + job_name, + kms_key, + experiment_config, + ) + + +class _AnalysisConfigGenerator: + """Creates analysis_config objects for different type of runs.""" + + @classmethod + def explainability( + cls, + data_config: DataConfig, + model_config: ModelConfig, + model_scores: ModelPredictedLabelConfig, + explainability_config: ExplainabilityConfig, + ): + """Generates a config for Explainability""" analysis_config = data_config.get_config() predictor_config = model_config.get_predictor_config() if isinstance(model_scores, ModelPredictedLabelConfig): @@ -1406,20 +1418,84 @@ def run_explainability( explainability_methods = explainability_config.get_explainability_config() analysis_config["methods"] = explainability_methods analysis_config["predictor"] = predictor_config - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Explainability") - return self._run( - data_config, - analysis_config, - wait, - logs, - job_name, - kms_key, - experiment_config, - ) + return cls._common(analysis_config) + + @classmethod + def bias_pre_training( + cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] + ): + """Generates a config for Bias Pre Training""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "methods": {"pre_training_bias": {"methods": methods}}, + } + return cls._common(analysis_config) + + @classmethod + def bias_post_training( + cls, + data_config: DataConfig, + bias_config: BiasConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + methods: Union[str, List[str]], + model_config: ModelConfig, + ): + """Generates a config for Bias Post Training""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": {**model_config.get_predictor_config()}, + "methods": {"post_training_bias": {"methods": methods}}, + } + if model_predicted_label_config: + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + _set(probability_threshold, "probability_threshold", analysis_config) + return cls._common(analysis_config) + + @classmethod + def bias( + cls, + data_config: DataConfig, + bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + pre_training_methods: Union[str, List[str]] = "all", + post_training_methods: Union[str, List[str]] = "all", + ): + """Generates a config for Bias""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": model_config.get_predictor_config(), + "methods": { + "pre_training_bias": {"methods": pre_training_methods}, + "post_training_bias": {"methods": post_training_methods}, + }, + } + if model_predicted_label_config: + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + _set(probability_threshold, 
"probability_threshold", analysis_config) + return cls._common(analysis_config) + + @staticmethod + def _common(analysis_config): + """Extends analysis config with common values""" + analysis_config["methods"]["report"] = { + "name": "report", + "title": "Analysis Report", + } + return analysis_config def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index fa437573f0..7375657944 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -29,6 +29,7 @@ SHAPConfig, TextConfig, ImageConfig, + _AnalysisConfigGenerator, ) JOB_NAME_PREFIX = "my-prefix" @@ -764,7 +765,10 @@ def test_pre_training_bias( "label_values_or_threshold": [1], "facet": [{"name_or_index": "F1"}], "group_variable": "F2", - "methods": {"pre_training_bias": {"methods": "all"}}, + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, + }, } mock_method.assert_called_with( data_config, @@ -827,7 +831,10 @@ def test_post_training_bias( "joinsource_name_or_index": "F4", "facet": [{"name_or_index": "F1"}], "group_variable": "F2", - "methods": {"post_training_bias": {"methods": "all"}}, + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + }, "predictor": { "model_name": "xgboost-model", "instance_type": "ml.c5.xlarge", @@ -985,7 +992,10 @@ def _run_test_explain( "grid_resolution": 20, "top_k_features": 10, } - expected_analysis_config["methods"] = expected_explanation_configs + expected_analysis_config["methods"] = { + "report": {"name": "report", "title": "Analysis Report"}, + **expected_explanation_configs, + } mock_method.assert_called_with( data_config, expected_analysis_config, @@ -1277,3 +1287,128 @@ def test_shap_with_image_config( expected_predictor_config, expected_image_config=expected_image_config, ) + + +def test_analysis_config_generator_for_explainability(data_config, model_config): + model_scores = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + SHAPConfig(), + ) + expected = { + "dataset_type": "text/csv", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config): + actual = _AnalysisConfigGenerator.bias_pre_training( + data_config, data_bias_config, methods="all" + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias_post_training( + data_config, data_bias_config, model_config +): + model_predicted_label_config = ModelPredictedLabelConfig( + 
probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias_post_training( + data_config, + data_bias_config, + model_predicted_label_config, + methods="all", + model_config=model_config, + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias(data_config, data_bias_config, model_config): + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias( + data_config, + data_bias_config, + model_config, + model_predicted_label_config, + pre_training_methods="all", + post_training_methods="all", + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + "pre_training_bias": {"methods": "all"}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected