From c0f1d794a922f497622ce730d74f56c2f7f78897 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 11:49:44 +0200
Subject: [PATCH 01/10] feature: extracted analysis config generation for explainability

---
 src/sagemaker/clarify.py   | 55 +++++++++++++++++++++++++-------------
 tests/unit/test_clarify.py | 26 ++++++++++++++++++
 2 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 6590d30514..b2637b9d94 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1370,6 +1370,36 @@ def run_explainability(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            explainability_config
+        )
+        if job_name is None:
+            if self.job_name_prefix:
+                job_name = utils.name_from_base(self.job_name_prefix)
+            else:
+                job_name = utils.name_from_base("Clarify-Explainability")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
+
+class _AnalysisConfigGenerator:
+    @staticmethod
+    def explainability(
+        data_config,
+        model_config,
+        model_scores,
+        explainability_config
+    ):
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
         if isinstance(model_scores, ModelPredictedLabelConfig):
@@ -1392,34 +1422,21 @@ def run_explainability(
             if not len(explainability_methods.keys()) == len(explainability_config):
                 raise ValueError("Duplicate explainability configs are provided")
             if (
-                "shap" not in explainability_methods
-                and explainability_methods["pdp"].get("features", None) is None
+                    "shap" not in explainability_methods
+                    and explainability_methods["pdp"].get("features", None) is None
             ):
                 raise ValueError("PDP features must be provided when ShapConfig is not provided")
         else:
             if (
-                isinstance(explainability_config, PDPConfig)
-                and explainability_config.get_explainability_config()["pdp"].get("features", None)
-                is None
+                    isinstance(explainability_config, PDPConfig)
+                    and explainability_config.get_explainability_config()["pdp"].get("features", None)
+                    is None
             ):
                 raise ValueError("PDP features must be provided when ShapConfig is not provided")
             explainability_methods = explainability_config.get_explainability_config()
         analysis_config["methods"] = explainability_methods
         analysis_config["predictor"] = predictor_config
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Explainability")
-        return self._run(
-            data_config,
-            analysis_config,
-            wait,
-            logs,
-            job_name,
-            kms_key,
-            experiment_config,
-        )
+        return analysis_config
 
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index fa437573f0..e440ad6eb0 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -29,6 +29,7 @@
     SHAPConfig,
     TextConfig,
     ImageConfig,
+    _AnalysisConfigGenerator,
 )
 
 JOB_NAME_PREFIX = "my-prefix"
@@ -1277,3 +1278,28 @@ def test_shap_with_image_config(
         expected_predictor_config,
         expected_image_config=expected_image_config,
     )
+
+
+def test_analysis_config_generator_for_explainability(data_config, model_config):
+    model_scores = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.explainability(
+        data_config,
+        model_config,
+        model_scores,
+        SHAPConfig(),
+    )
+    expected = {'dataset_type': 'text/csv',
+                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                'joinsource_name_or_index': 'F4',
+                'label': 'Label',
+                'methods': {'shap': {'save_local_shap_values': True, 'use_logit': False}},
+                'predictor': {'initial_instance_count': 1,
+                              'instance_type': 'ml.c5.xlarge',
+                              'label_headers': ['success'],
+                              'model_name': 'xgboost-model',
+                              'probability': 'pr'}}
+    assert actual == expected

From 4fa12c125b326a2db613a24aedb37bc4e8783d88 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 12:16:20 +0200
Subject: [PATCH 02/10] feature: extracted analysis config generation for bias pre_training

---
 src/sagemaker/clarify.py   | 23 ++++++++++++++++++++---
 tests/unit/test_clarify.py | 16 ++++++++++++++++
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index b2637b9d94..94eb234238 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1083,9 +1083,11 @@ def run_pre_training_bias(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
-        analysis_config = data_config.get_config()
-        analysis_config.update(data_bias_config.get_config())
-        analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
+        analysis_config = _AnalysisConfigGenerator.bias_pre_training(
+            data_config,
+            data_bias_config,
+            methods
+        )
         if job_name is None:
             if self.job_name_prefix:
                 job_name = utils.name_from_base(self.job_name_prefix)
@@ -1438,6 +1440,21 @@ def explainability(
         analysis_config["methods"] = explainability_methods
         analysis_config["predictor"] = predictor_config
         return analysis_config
 
+    @staticmethod
+    def bias_pre_training(data_config, data_bias_config, methods):
+        analysis_config = data_config.get_config()
+        analysis_config.update(data_bias_config.get_config())
+        analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
+        return analysis_config
+
+    @staticmethod
+    def _common(analysis_config):
+        analysis_config["methods"]["report"] = {
+            "name": "report",
+            "title": "Analysis Report",
+        }
+        return analysis_config
+
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
     """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index e440ad6eb0..ed91eb9037 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -1303,3 +1303,19 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
                               'probability': 'pr'}}
     assert actual == expected
+
+def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
+    actual = _AnalysisConfigGenerator.bias_pre_training(
+        data_config,
+        data_bias_config,
+        methods="all"
+    )
+    expected = {'dataset_type': 'text/csv',
+                'facet': [{'name_or_index': 'F1'}],
+                'group_variable': 'F2',
+                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                'joinsource_name_or_index': 'F4',
+                'label': 'Label',
+                'label_values_or_threshold': [1],
+                'methods': {'pre_training_bias': {'methods': 'all'}}}
+    assert actual == expected

From cf4b08e70282de073295bc07b9fd2957fcf71bfd Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 12:20:20 +0200
Subject: [PATCH 03/10] feature: extracted analysis config generation for bias post_training

---
 src/sagemaker/clarify.py   | 37 +++++++++++++++++++++++++++----------
 tests/unit/test_clarify.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 94eb234238..a2c99b9833 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1167,16 +1167,13 @@ def run_post_training_bias(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
-        analysis_config = data_config.get_config()
-        analysis_config.update(data_bias_config.get_config())
-        (
-            probability_threshold,
-            predictor_config,
-        ) = model_predicted_label_config.get_predictor_config()
-        predictor_config.update(model_config.get_predictor_config())
-        analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
-        analysis_config["predictor"] = predictor_config
-        _set(probability_threshold, "probability_threshold", analysis_config)
+        analysis_config = _AnalysisConfigGenerator.bias_post_training(
+            data_config,
+            data_bias_config,
+            model_predicted_label_config,
+            methods,
+            model_config
+        )
         if job_name is None:
             if self.job_name_prefix:
                 job_name = utils.name_from_base(self.job_name_prefix)
@@ -1447,6 +1444,26 @@ def bias_pre_training(data_config, data_bias_config, methods):
         analysis_config.update(data_bias_config.get_config())
         analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
         return analysis_config
 
+    @staticmethod
+    def bias_post_training(
+        data_config,
+        data_bias_config,
+        model_predicted_label_config,
+        methods,
+        model_config
+    ):
+        analysis_config = data_config.get_config()
+        analysis_config.update(data_bias_config.get_config())
+        analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
+        (
+            probability_threshold,
+            predictor_config,
+        ) = model_predicted_label_config.get_predictor_config()
+        predictor_config.update(model_config.get_predictor_config())
+        analysis_config["predictor"] = predictor_config
+        _set(probability_threshold, "probability_threshold", analysis_config)
+        return analysis_config
+
     @staticmethod
     def _common(analysis_config):
         analysis_config["methods"]["report"] = {
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index ed91eb9037..7430791da1 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -1319,3 +1319,31 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
                 'label_values_or_threshold': [1],
                 'methods': {'pre_training_bias': {'methods': 'all'}}}
     assert actual == expected
+
+
+def test_analysis_config_generator_for_bias_post_training(data_config, data_bias_config, model_config):
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias_post_training(
+        data_config,
+        data_bias_config,
+        model_predicted_label_config,
+        methods="all",
+        model_config=model_config,
+    )
+    expected = {'dataset_type': 'text/csv',
+                'facet': [{'name_or_index': 'F1'}],
+                'group_variable': 'F2',
+                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                'joinsource_name_or_index': 'F4',
+                'label': 'Label',
+                'label_values_or_threshold': [1],
+                'methods': {'post_training_bias': {'methods': 'all'}},
+                'predictor': {'initial_instance_count': 1,
+                              'instance_type': 'ml.c5.xlarge',
+                              'label_headers': ['success'],
+                              'model_name': 'xgboost-model',
+                              'probability': 'pr'}}
+    assert actual == expected

From 3b2df62ac2f71ed47fb4b1b71af20d85e9894571 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 12:29:06 +0200
Subject: [PATCH 04/10] feature: extracted analysis config generation for bias

---
 src/sagemaker/clarify.py   | 73 ++++++++++++++++++++++++--------------
 tests/unit/test_clarify.py | 30 ++++++++++++++++
 2 files changed, 76 insertions(+), 27 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index a2c99b9833..5f6ca6ba59 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1034,7 +1034,7 @@ def _run(
     def run_pre_training_bias(
         self,
         data_config,
-        data_bias_config,
+        bias_config,
         methods="all",
         wait=True,
         logs=True,
@@ -1049,7 +1049,7 @@ def run_pre_training_bias(
 
         Args:
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
-            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             methods (str or list[str]): Selects a subset of potential metrics:
                 ["`CI `_",
                 "`DPL `_",
@@ -1085,7 +1085,7 @@ def run_pre_training_bias(
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_pre_training(
             data_config,
-            data_bias_config,
+            bias_config,
             methods
         )
         if job_name is None:
@@ -1106,7 +1106,7 @@ def run_pre_training_bias(
     def run_post_training_bias(
         self,
         data_config,
-        data_bias_config,
+        bias_config,
         model_config,
         model_predicted_label_config,
         methods="all",
@@ -1126,7 +1126,7 @@ def run_post_training_bias(
 
         Args:
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
-            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                 endpoint to be created.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1169,7 +1169,7 @@ def run_post_training_bias(
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_post_training(
             data_config,
-            data_bias_config,
+            bias_config,
             model_predicted_label_config,
             methods,
             model_config
@@ -1263,23 +1263,14 @@ def run_bias(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
""" # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["predictor"] = model_config.get_predictor_config() - if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - if predictor_config: - analysis_config["predictor"].update(predictor_config) - if probability_threshold is not None: - analysis_config["probability_threshold"] = probability_threshold - - analysis_config["methods"] = { - "pre_training_bias": {"methods": pre_training_methods}, - "post_training_bias": {"methods": post_training_methods}, - } + analysis_config = _AnalysisConfigGenerator.bias( + data_config, + bias_config, + model_config, + model_predicted_label_config, + pre_training_methods, + post_training_methods, + ) if job_name is None: if self.job_name_prefix: job_name = utils.name_from_base(self.job_name_prefix) @@ -1438,22 +1429,22 @@ def explainability( return analysis_config @staticmethod - def bias_pre_training(data_config, data_bias_config, methods): + def bias_pre_training(data_config, bias_config, methods): analysis_config = data_config.get_config() - analysis_config.update(data_bias_config.get_config()) + analysis_config.update(bias_config.get_config()) analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} return analysis_config @staticmethod def bias_post_training( data_config, - data_bias_config, + bias_config, model_predicted_label_config, methods, model_config ): analysis_config = data_config.get_config() - analysis_config.update(data_bias_config.get_config()) + analysis_config.update(bias_config.get_config()) analysis_config["methods"] = {"post_training_bias": {"methods": methods}} ( probability_threshold, @@ -1464,6 +1455,34 @@ def bias_post_training( _set(probability_threshold, "probability_threshold", analysis_config) return analysis_config + @staticmethod + def bias( + data_config, + bias_config, + model_config, + model_predicted_label_config, + pre_training_methods="all", + post_training_methods="all", + ): + analysis_config = data_config.get_config() + analysis_config.update(bias_config.get_config()) + analysis_config["predictor"] = model_config.get_predictor_config() + if model_predicted_label_config: + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + if probability_threshold is not None: + analysis_config["probability_threshold"] = probability_threshold + + analysis_config["methods"] = { + "pre_training_bias": {"methods": pre_training_methods}, + "post_training_bias": {"methods": post_training_methods}, + } + return analysis_config + @staticmethod def _common(analysis_config): analysis_config["methods"]["report"] = { diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 7430791da1..f5929e5763 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -1347,3 +1347,33 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias 'model_name': 'xgboost-model', 'probability': 'pr'}} assert actual == expected + + +def test_analysis_config_generator_for_bias(data_config, data_bias_config, model_config): + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias( + data_config, + data_bias_config, + model_config, + 
+        model_predicted_label_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {'dataset_type': 'text/csv',
+                'facet': [{'name_or_index': 'F1'}],
+                'group_variable': 'F2',
+                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                'joinsource_name_or_index': 'F4',
+                'label': 'Label',
+                'label_values_or_threshold': [1],
+                'methods': {'post_training_bias': {'methods': 'all'},
+                            'pre_training_bias': {'methods': 'all'}},
+                'predictor': {'initial_instance_count': 1,
+                              'instance_type': 'ml.c5.xlarge',
+                              'label_headers': ['success'],
+                              'model_name': 'xgboost-model',
+                              'probability': 'pr'}}
+    assert actual == expected

From 32650ee0362a470d9da48fa31a6466636c835c05 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 12:30:43 +0200
Subject: [PATCH 05/10] feature: simplified job_name creation

---
 src/sagemaker/clarify.py | 28 ++++++++--------------------
 1 file changed, 8 insertions(+), 20 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 5f6ca6ba59..40b40ac508 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1088,11 +1088,8 @@ def run_pre_training_bias(
             bias_config,
             methods
         )
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Pretraining-Bias")
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Pretraining-Bias")
         return self._run(
             data_config,
             analysis_config,
@@ -1174,11 +1171,8 @@ def run_post_training_bias(
             methods,
             model_config
         )
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Posttraining-Bias")
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Posttraining-Bias")
         return self._run(
             data_config,
             analysis_config,
@@ -1271,11 +1265,8 @@ def run_bias(
             pre_training_methods,
             post_training_methods,
         )
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Bias")
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias")
         return self._run(
             data_config,
             analysis_config,
@@ -1366,11 +1357,8 @@ def run_explainability(
             model_scores,
             explainability_config
         )
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Explainability")
+        # when name is either not provided (is None) or an empty string ("")
+        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Explainability")
         return self._run(
             data_config,
             analysis_config,

From 88b1f4d00b3b583b4f5d8cb12018f3963cf84fd9 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 13:29:12 +0200
Subject: [PATCH 06/10] feature: extended analysis config generator methods with common logic

---
 src/sagemaker/clarify.py   | 35 +++++++++++++++++------------------
 tests/unit/test_clarify.py | 36 ++++++++++++++++++++++++++++--------
 2 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 40b40ac508..6486a7cf8a 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -983,10 +983,6 @@ def _run(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
             with open(analysis_config_file, "w") as f:
@@ -1371,8 +1367,9 @@ def run_explainability(
 
 
 class _AnalysisConfigGenerator:
-    @staticmethod
+    @classmethod
     def explainability(
+        cls,
         data_config,
         model_config,
         model_scores,
@@ -1414,22 +1411,23 @@ def explainability(
             explainability_methods = explainability_config.get_explainability_config()
         analysis_config["methods"] = explainability_methods
         analysis_config["predictor"] = predictor_config
-        return analysis_config
+        return cls._common(analysis_config)
 
-    @staticmethod
-    def bias_pre_training(data_config, bias_config, methods):
+    @classmethod
+    def bias_pre_training(cls, data_config, bias_config, methods):
         analysis_config = data_config.get_config()
         analysis_config.update(bias_config.get_config())
         analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
-        return analysis_config
+        return cls._common(analysis_config)
 
-    @staticmethod
+    @classmethod
     def bias_post_training(
-        data_config,
-        bias_config,
-        model_predicted_label_config,
-        methods,
-        model_config
+            cls,
+            data_config,
+            bias_config,
+            model_predicted_label_config,
+            methods,
+            model_config
     ):
         analysis_config = data_config.get_config()
         analysis_config.update(bias_config.get_config())
@@ -1441,10 +1439,11 @@ def bias_post_training(
         predictor_config.update(model_config.get_predictor_config())
         analysis_config["predictor"] = predictor_config
         _set(probability_threshold, "probability_threshold", analysis_config)
-        return analysis_config
+        return cls._common(analysis_config)
 
-    @staticmethod
+    @classmethod
     def bias(
+        cls,
         data_config,
         bias_config,
         model_config,
@@ -1469,7 +1468,7 @@ def bias(
             "pre_training_bias": {"methods": pre_training_methods},
             "post_training_bias": {"methods": post_training_methods},
         }
-        return analysis_config
+        return cls._common(analysis_config)
 
     @staticmethod
     def _common(analysis_config):
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index f5929e5763..6e2e0e98ec 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -765,7 +765,10 @@ def test_pre_training_bias(
         "label_values_or_threshold": [1],
         "facet": [{"name_or_index": "F1"}],
         "group_variable": "F2",
-        "methods": {"pre_training_bias": {"methods": "all"}},
+        "methods": {
+            'report': {'name': 'report', 'title': 'Analysis Report'},
+            "pre_training_bias": {"methods": "all"}
+        },
     }
     mock_method.assert_called_with(
         data_config,
@@ -828,7 +831,10 @@ def test_post_training_bias(
         "joinsource_name_or_index": "F4",
         "facet": [{"name_or_index": "F1"}],
         "group_variable": "F2",
-        "methods": {"post_training_bias": {"methods": "all"}},
+        "methods": {
+            'report': {'name': 'report', 'title': 'Analysis Report'},
+            "post_training_bias": {"methods": "all"}
+        },
         "predictor": {
             "model_name": "xgboost-model",
             "instance_type": "ml.c5.xlarge",
@@ -986,7 +992,10 @@ def _run_test_explain(
         "grid_resolution": 20,
         "top_k_features": 10,
     }
-    expected_analysis_config["methods"] = expected_explanation_configs
+    expected_analysis_config["methods"] = {
+        'report': {'name': 'report', 'title': 'Analysis Report'},
+        **expected_explanation_configs,
+    }
     mock_method.assert_called_with(
         data_config,
         expected_analysis_config,
@@ -1295,7 +1304,10 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
                 'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
                 'joinsource_name_or_index': 'F4',
                 'label': 'Label',
-                'methods': {'shap': {'save_local_shap_values': True, 'use_logit': False}},
+                'methods': {
+                    'report': {'name': 'report', 'title': 'Analysis Report'},
+                    'shap': {'save_local_shap_values': True, 'use_logit': False}
+                },
                 'predictor': {'initial_instance_count': 1,
                               'instance_type': 'ml.c5.xlarge',
                               'label_headers': ['success'],
@@ -1317,7 +1329,10 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
                 'joinsource_name_or_index': 'F4',
                 'label': 'Label',
                 'label_values_or_threshold': [1],
-                'methods': {'pre_training_bias': {'methods': 'all'}}}
+                'methods': {
+                    'report': {'name': 'report', 'title': 'Analysis Report'},
+                    'pre_training_bias': {'methods': 'all'}}
+                }
     assert actual == expected
 
 
@@ -1340,7 +1355,10 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
                 'joinsource_name_or_index': 'F4',
                 'label': 'Label',
                 'label_values_or_threshold': [1],
-                'methods': {'post_training_bias': {'methods': 'all'}},
+                'methods': {
+                    'report': {'name': 'report', 'title': 'Analysis Report'},
+                    'post_training_bias': {'methods': 'all'}
+                },
                 'predictor': {'initial_instance_count': 1,
                               'instance_type': 'ml.c5.xlarge',
                               'label_headers': ['success'],
@@ -1369,8 +1387,10 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
                 'joinsource_name_or_index': 'F4',
                 'label': 'Label',
                 'label_values_or_threshold': [1],
-                'methods': {'post_training_bias': {'methods': 'all'},
-                            'pre_training_bias': {'methods': 'all'}},
+                'methods': {
+                    'report': {'name': 'report', 'title': 'Analysis Report'},
+                    'post_training_bias': {'methods': 'all'},
+                    'pre_training_bias': {'methods': 'all'}},
                 'predictor': {'initial_instance_count': 1,
                               'instance_type': 'ml.c5.xlarge',
                               'label_headers': ['success'],

From 71989f9d13a67df8eacb36ba1c90dae252eafe06 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 13:37:48 +0200
Subject: [PATCH 07/10] feature: refactored _AnalysisConfigGenerator methods

---
 src/sagemaker/clarify.py | 67 +++++++++++++++++++++------------------
 1 file changed, 35 insertions(+), 32 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 6486a7cf8a..24a99d2773 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1367,6 +1367,9 @@ def run_explainability(
 
 
 class _AnalysisConfigGenerator:
+    """
+    Creates analysis_config objects for different type of runs.
+ """ @classmethod def explainability( cls, @@ -1397,15 +1400,15 @@ def explainability( if not len(explainability_methods.keys()) == len(explainability_config): raise ValueError("Duplicate explainability configs are provided") if ( - "shap" not in explainability_methods - and explainability_methods["pdp"].get("features", None) is None + "shap" not in explainability_methods + and explainability_methods["pdp"].get("features", None) is None ): raise ValueError("PDP features must be provided when ShapConfig is not provided") else: if ( - isinstance(explainability_config, PDPConfig) - and explainability_config.get_explainability_config()["pdp"].get("features", None) - is None + isinstance(explainability_config, PDPConfig) + and explainability_config.get_explainability_config()["pdp"].get("features", None) + is None ): raise ValueError("PDP features must be provided when ShapConfig is not provided") explainability_methods = explainability_config.get_explainability_config() @@ -1415,9 +1418,11 @@ def explainability( @classmethod def bias_pre_training(cls, data_config, bias_config, methods): - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "methods": {"pre_training_bias": {"methods": methods}} + } return cls._common(analysis_config) @classmethod @@ -1429,16 +1434,17 @@ def bias_post_training( methods, model_config ): - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["methods"] = {"post_training_bias": {"methods": methods}} - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - predictor_config.update(model_config.get_predictor_config()) - analysis_config["predictor"] = predictor_config - _set(probability_threshold, "probability_threshold", analysis_config) + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": {**model_config.get_predictor_config()}, + "methods": {"post_training_bias": {"methods": methods}}, + } + if model_predicted_label_config: + probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + _set(probability_threshold, "probability_threshold", analysis_config) return cls._common(analysis_config) @classmethod @@ -1451,23 +1457,20 @@ def bias( pre_training_methods="all", post_training_methods="all", ): - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["predictor"] = model_config.get_predictor_config() + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": model_config.get_predictor_config(), + "methods": { + "pre_training_bias": {"methods": pre_training_methods}, + "post_training_bias": {"methods": post_training_methods}, + } + } if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() + probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config() if predictor_config: analysis_config["predictor"].update(predictor_config) - if probability_threshold is not None: - analysis_config["probability_threshold"] = probability_threshold - - analysis_config["methods"] = { - "pre_training_bias": {"methods": 
-            "pre_training_bias": {"methods": pre_training_methods},
-            "post_training_bias": {"methods": post_training_methods},
-        }
+        _set(probability_threshold, "probability_threshold", analysis_config)
         return cls._common(analysis_config)
 
     @staticmethod

From e1fb7a2fbd241c3b84fcc0becaadd98f4818472a Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 1 Aug 2022 14:09:28 +0200
Subject: [PATCH 08/10] feature: added _last_analysis_config in SageMakerClarifyProcessor

---
 src/sagemaker/clarify.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 24a99d2773..40e56dd5ca 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -983,6 +983,10 @@ def _run(
                 the Trial Component will be unassociated.
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """
+        # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
+        self._last_analysis_config = analysis_config
+        logger.info("Analysis Config: ", analysis_config)
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
             with open(analysis_config_file, "w") as f:

From ee04c97ad8be26a0446345c1298a7fd6316cc25a Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Wed, 3 Aug 2022 11:11:31 +0200
Subject: [PATCH 09/10] added data types in _AnalysisConfigGenerator methods

---
 src/sagemaker/clarify.py   | 46 ++++++++++++++++++++------------------
 tests/unit/test_clarify.py | 28 +++++++++++++-------------
 2 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 40e56dd5ca..d55c8dab2a 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -25,6 +25,8 @@
 import tempfile
 from abc import ABC, abstractmethod
 
+from typing import List, Union
+
 from sagemaker import image_uris, s3, utils
 from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
@@ -1034,7 +1036,7 @@ def _run(
     def run_pre_training_bias(
         self,
         data_config,
-        bias_config,
+        data_bias_config,
         methods="all",
         wait=True,
         logs=True,
@@ -1049,7 +1051,7 @@ def run_pre_training_bias(
 
         Args:
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
-            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             methods (str or list[str]): Selects a subset of potential metrics:
                 ["`CI `_",
                 "`DPL `_",
@@ -1085,7 +1087,7 @@ def run_pre_training_bias(
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_pre_training(
             data_config,
-            bias_config,
+            data_bias_config,
             methods
         )
         # when name is either not provided (is None) or an empty string ("")
@@ -1103,7 +1105,7 @@ def run_pre_training_bias(
     def run_post_training_bias(
         self,
         data_config,
-        bias_config,
+        data_bias_config,
         model_config,
         model_predicted_label_config,
         methods="all",
@@ -1123,7 +1125,7 @@ def run_post_training_bias(
 
         Args:
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
-            bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
+            data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
                 endpoint to be created.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1166,7 +1168,7 @@ def run_post_training_bias(
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_post_training(
             data_config,
-            bias_config,
+            data_bias_config,
             model_predicted_label_config,
             methods,
             model_config
@@ -1377,10 +1379,10 @@ class _AnalysisConfigGenerator:
     @classmethod
     def explainability(
         cls,
-        data_config,
-        model_config,
-        model_scores,
-        explainability_config
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_scores: ModelPredictedLabelConfig,
+        explainability_config: ExplainabilityConfig,
     ):
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
@@ -1421,7 +1423,7 @@ def explainability(
         return cls._common(analysis_config)
 
     @classmethod
-    def bias_pre_training(cls, data_config, bias_config, methods):
+    def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]):
         analysis_config = {
             **data_config.get_config(),
             **bias_config.get_config(),
@@ -1432,11 +1434,11 @@ def bias_pre_training(cls, data_config, bias_config, methods):
     @classmethod
     def bias_post_training(
             cls,
-            data_config,
-            bias_config,
-            model_predicted_label_config,
-            methods,
-            model_config
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        methods: Union[str, List[str]],
+        model_config: ModelConfig,
     ):
         analysis_config = {
             **data_config.get_config(),
@@ -1454,12 +1456,12 @@ def bias_post_training(
     @classmethod
     def bias(
         cls,
-        data_config,
-        bias_config,
-        model_config,
-        model_predicted_label_config,
-        pre_training_methods="all",
-        post_training_methods="all",
+        data_config: DataConfig,
+        bias_config: BiasConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
     ):
         analysis_config = {
             **data_config.get_config(),
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index 6e2e0e98ec..ebfebee1ad 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -1382,18 +1382,18 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
     )
     expected = {'dataset_type': 'text/csv',
                 'facet': [{'name_or_index': 'F1'}],
-                'group_variable': 'F2',
-                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                'joinsource_name_or_index': 'F4',
-                'label': 'Label',
-                'label_values_or_threshold': [1],
-                'methods': {
-                    'report': {'name': 'report', 'title': 'Analysis Report'},
-                    'post_training_bias': {'methods': 'all'},
-                    'pre_training_bias': {'methods': 'all'}},
-                'predictor': {'initial_instance_count': 1,
-                              'instance_type': 'ml.c5.xlarge',
-                              'label_headers': ['success'],
-                              'model_name': 'xgboost-model',
-                              'probability': 'pr'}}
+                    'group_variable': 'F2',
+                    'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                    'joinsource_name_or_index': 'F4',
+                    'label': 'Label',
+                    'label_values_or_threshold': [1],
+                    'methods': {
+                        'report': {'name': 'report', 'title': 'Analysis Report'},
+                        'post_training_bias': {'methods': 'all'},
+                        'pre_training_bias': {'methods': 'all'}},
+                    'predictor': {'initial_instance_count': 1,
+                                  'instance_type': 'ml.c5.xlarge',
+                                  'label_headers': ['success'],
+                                  'model_name': 'xgboost-model',
+                                  'probability': 'pr'}}
     assert actual == expected

From b50a44b6051e53682bc7b5f01376cf508b81a8c3 Mon Sep 17 00:00:00 2001
From: Yeldos Balgabekov
Date: Mon, 8 Aug 2022 12:02:55 +0200
Subject: [PATCH 10/10] applied style formatting to fix build issues

---
 src/sagemaker/clarify.py   |  58 +++++++++------
 tests/unit/test_clarify.py | 145 ++++++++++++++++++++-----------------
 2 files changed, 114 insertions(+), 89 deletions(-)

diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index d55c8dab2a..3bc2071330 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -924,6 +924,7 @@ def __init__(
             version (str): Clarify version to use.
         """  # noqa E501  # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
+        self._last_analysis_config = None
         self.job_name_prefix = job_name_prefix
         super(SageMakerClarifyProcessor, self).__init__(
             role,
@@ -987,7 +988,7 @@ def _run(
         """
         # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
         self._last_analysis_config = analysis_config
-        logger.info("Analysis Config: ", analysis_config)
+        logger.info("Analysis Config: %s", analysis_config)
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
@@ -1086,12 +1087,12 @@ def run_pre_training_bias(
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_pre_training(
-            data_config,
-            data_bias_config,
-            methods
+            data_config, data_bias_config, methods
         )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Pretraining-Bias")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Pretraining-Bias"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1167,14 +1168,12 @@ def run_post_training_bias(
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_post_training(
-            data_config,
-            data_bias_config,
-            model_predicted_label_config,
-            methods,
-            model_config
+            data_config, data_bias_config, model_predicted_label_config, methods, model_config
         )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Posttraining-Bias")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Posttraining-Bias"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1354,13 +1353,12 @@ def run_explainability(
             * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.explainability(
-            data_config,
-            model_config,
-            model_scores,
-            explainability_config
+            data_config, model_config, model_scores, explainability_config
         )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Explainability")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Explainability"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1373,9 +1371,8 @@ def run_explainability(
 
 
 class _AnalysisConfigGenerator:
-    """
-    Creates analysis_config objects for different type of runs.
- """ + """Creates analysis_config objects for different type of runs.""" + @classmethod def explainability( cls, @@ -1384,6 +1381,7 @@ def explainability( model_scores: ModelPredictedLabelConfig, explainability_config: ExplainabilityConfig, ): + """Generates a config for Explainability""" analysis_config = data_config.get_config() predictor_config = model_config.get_predictor_config() if isinstance(model_scores, ModelPredictedLabelConfig): @@ -1423,11 +1421,14 @@ def explainability( return cls._common(analysis_config) @classmethod - def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]): + def bias_pre_training( + cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] + ): + """Generates a config for Bias Pre Training""" analysis_config = { **data_config.get_config(), **bias_config.get_config(), - "methods": {"pre_training_bias": {"methods": methods}} + "methods": {"pre_training_bias": {"methods": methods}}, } return cls._common(analysis_config) @@ -1440,6 +1441,7 @@ def bias_post_training( methods: Union[str, List[str]], model_config: ModelConfig, ): + """Generates a config for Bias Post Training""" analysis_config = { **data_config.get_config(), **bias_config.get_config(), @@ -1447,7 +1449,10 @@ def bias_post_training( "methods": {"post_training_bias": {"methods": methods}}, } if model_predicted_label_config: - probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config() + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() if predictor_config: analysis_config["predictor"].update(predictor_config) _set(probability_threshold, "probability_threshold", analysis_config) @@ -1463,6 +1468,7 @@ def bias( pre_training_methods: Union[str, List[str]] = "all", post_training_methods: Union[str, List[str]] = "all", ): + """Generates a config for Bias""" analysis_config = { **data_config.get_config(), **bias_config.get_config(), @@ -1470,10 +1476,13 @@ def bias( "methods": { "pre_training_bias": {"methods": pre_training_methods}, "post_training_bias": {"methods": post_training_methods}, - } + }, } if model_predicted_label_config: - probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config() + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() if predictor_config: analysis_config["predictor"].update(predictor_config) _set(probability_threshold, "probability_threshold", analysis_config) @@ -1481,6 +1490,7 @@ def bias( @staticmethod def _common(analysis_config): + """Extends analysis config with common values""" analysis_config["methods"]["report"] = { "name": "report", "title": "Analysis Report", diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index ebfebee1ad..7375657944 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -766,8 +766,8 @@ def test_pre_training_bias( "facet": [{"name_or_index": "F1"}], "group_variable": "F2", "methods": { - 'report': {'name': 'report', 'title': 'Analysis Report'}, - "pre_training_bias": {"methods": "all"} + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, }, } mock_method.assert_called_with( @@ -832,8 +832,8 @@ def test_post_training_bias( "facet": [{"name_or_index": "F1"}], "group_variable": "F2", "methods": { - 'report': {'name': 'report', 'title': 'Analysis Report'}, - "post_training_bias": {"methods": "all"} + 
"report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, }, "predictor": { "model_name": "xgboost-model", @@ -993,7 +993,7 @@ def _run_test_explain( "top_k_features": 10, } expected_analysis_config["methods"] = { - 'report': {'name': 'report', 'title': 'Analysis Report'}, + "report": {"name": "report", "title": "Analysis Report"}, **expected_explanation_configs, } mock_method.assert_called_with( @@ -1300,43 +1300,49 @@ def test_analysis_config_generator_for_explainability(data_config, model_config) model_scores, SHAPConfig(), ) - expected = {'dataset_type': 'text/csv', - 'headers': ['Label', 'F1', 'F2', 'F3', 'F4'], - 'joinsource_name_or_index': 'F4', - 'label': 'Label', - 'methods': { - 'report': {'name': 'report', 'title': 'Analysis Report'}, - 'shap': {'save_local_shap_values': True, 'use_logit': False} - }, - 'predictor': {'initial_instance_count': 1, - 'instance_type': 'ml.c5.xlarge', - 'label_headers': ['success'], - 'model_name': 'xgboost-model', - 'probability': 'pr'}} + expected = { + "dataset_type": "text/csv", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } assert actual == expected def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config): actual = _AnalysisConfigGenerator.bias_pre_training( - data_config, - data_bias_config, - methods="all" + data_config, data_bias_config, methods="all" ) - expected = {'dataset_type': 'text/csv', - 'facet': [{'name_or_index': 'F1'}], - 'group_variable': 'F2', - 'headers': ['Label', 'F1', 'F2', 'F3', 'F4'], - 'joinsource_name_or_index': 'F4', - 'label': 'Label', - 'label_values_or_threshold': [1], - 'methods': { - 'report': {'name': 'report', 'title': 'Analysis Report'}, - 'pre_training_bias': {'methods': 'all'}} - } + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, + }, + } assert actual == expected -def test_analysis_config_generator_for_bias_post_training(data_config, data_bias_config, model_config): +def test_analysis_config_generator_for_bias_post_training( + data_config, data_bias_config, model_config +): model_predicted_label_config = ModelPredictedLabelConfig( probability="pr", label_headers=["success"], @@ -1348,22 +1354,26 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias methods="all", model_config=model_config, ) - expected = {'dataset_type': 'text/csv', - 'facet': [{'name_or_index': 'F1'}], - 'group_variable': 'F2', - 'headers': ['Label', 'F1', 'F2', 'F3', 'F4'], - 'joinsource_name_or_index': 'F4', - 'label': 'Label', - 'label_values_or_threshold': [1], - 'methods': { - 'report': {'name': 'report', 'title': 'Analysis Report'}, - 'post_training_bias': {'methods': 'all'} - }, - 'predictor': {'initial_instance_count': 1, - 'instance_type': 'ml.c5.xlarge', - 'label_headers': ['success'], - 'model_name': 'xgboost-model', - 'probability': 'pr'}} + 
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "post_training_bias": {"methods": "all"},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
     assert actual == expected
 
 
@@ -1380,20 +1390,25 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
         pre_training_methods="all",
         post_training_methods="all",
     )
-    expected = {'dataset_type': 'text/csv',
-                'facet': [{'name_or_index': 'F1'}],
-                    'group_variable': 'F2',
-                    'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                    'joinsource_name_or_index': 'F4',
-                    'label': 'Label',
-                    'label_values_or_threshold': [1],
-                    'methods': {
-                        'report': {'name': 'report', 'title': 'Analysis Report'},
-                        'post_training_bias': {'methods': 'all'},
-                        'pre_training_bias': {'methods': 'all'}},
-                    'predictor': {'initial_instance_count': 1,
-                                  'instance_type': 'ml.c5.xlarge',
-                                  'label_headers': ['success'],
-                                  'model_name': 'xgboost-model',
-                                  'probability': 'pr'}}
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "post_training_bias": {"methods": "all"},
+            "pre_training_bias": {"methods": "all"},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
    assert actual == expected
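
A minimal usage sketch of the generator this series extracts, in its final state
after PATCH 10/10. The DataConfig/BiasConfig/ModelConfig constructor arguments
below are illustrative assumptions: the unit tests above receive equivalent
objects from pytest fixtures (data_config, data_bias_config, model_config) that
are not shown in this series. _AnalysisConfigGenerator is private API; the
public entry points such as SageMakerClarifyProcessor.run_bias() call it
internally and, after PATCH 08/10, also keep the result in
self._last_analysis_config for debugging.

    from sagemaker.clarify import (
        BiasConfig,
        DataConfig,
        ModelConfig,
        ModelPredictedLabelConfig,
        _AnalysisConfigGenerator,
    )

    # Illustrative inputs (assumed bucket, paths, and column names);
    # real values depend on your dataset and deployed model.
    data_config = DataConfig(
        s3_data_input_path="s3://my-bucket/input.csv",
        s3_output_path="s3://my-bucket/output",
        label="Label",
        headers=["Label", "F1", "F2", "F3", "F4"],
        dataset_type="text/csv",
    )
    bias_config = BiasConfig(label_values_or_threshold=[1], facet_name="F1")
    model_config = ModelConfig(
        model_name="xgboost-model",
        instance_count=1,
        instance_type="ml.c5.xlarge",
    )
    predicted_label_config = ModelPredictedLabelConfig(probability="pr")

    # Pure function: builds the analysis_config dict that run_bias() would
    # upload to S3, including the "report" section merged in by _common(),
    # without creating any processing job.
    analysis_config = _AnalysisConfigGenerator.bias(
        data_config,
        bias_config,
        model_config,
        predicted_label_config,
        pre_training_methods="all",
        post_training_methods="all",
    )
    assert "report" in analysis_config["methods"]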