diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 18fed12042..7d8c156410 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -1423,8 +1423,8 @@ def run_post_training_bias(
         self,
         data_config: DataConfig,
         data_bias_config: BiasConfig,
-        model_config: ModelConfig,
-        model_predicted_label_config: ModelPredictedLabelConfig,
+        model_config: Optional[ModelConfig] = None,
+        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
         methods: Union[str, List[str]] = "all",
         wait: bool = True,
         logs: bool = True,
@@ -1444,7 +1444,8 @@ def run_post_training_bias(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
-                endpoint to be created.
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Config of how to extract the predicted label from the model output.
             methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1508,7 +1509,7 @@ def run_bias(
         self,
         data_config: DataConfig,
         bias_config: BiasConfig,
-        model_config: ModelConfig,
+        model_config: Optional[ModelConfig] = None,
         model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
         pre_training_methods: Union[str, List[str]] = "all",
         post_training_methods: Union[str, List[str]] = "all",
@@ -1529,7 +1530,8 @@ def run_bias(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
-                endpoint to be created.
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Config of how to extract the predicted label from the model output.
             pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1930,16 +1932,30 @@ def _add_predictor(
     ):
         """Extends analysis config with predictor."""
         analysis_config = {**analysis_config}
-        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_config, ModelConfig):
+            analysis_config["predictor"] = model_config.get_predictor_config()
+        else:
+            if "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"]:
+                raise ValueError(
+                    "model_config must be provided when explainability methods are selected."
+                )
+            if (
+                "predicted_label_dataset_uri" not in analysis_config
+                and "predicted_label" not in analysis_config
+            ):
+                raise ValueError(
+                    "model_config must be provided when `predicted_label_dataset_uri` or "
+                    "`predicted_label` are not provided in data_config."
+                )
         if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
             ) = model_predicted_label_config.get_predictor_config()
-            if predictor_config:
+            if predictor_config and "predictor" in analysis_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        else:
+        elif "predictor" in analysis_config:
             _set(model_predicted_label_config, "label", analysis_config["predictor"])
         return analysis_config
 
diff --git a/tests/integ/test_clarify.py b/tests/integ/test_clarify.py
index eaa75bce64..bbef52c488 100644
--- a/tests/integ/test_clarify.py
+++ b/tests/integ/test_clarify.py
@@ -474,6 +474,37 @@ def data_config_facets_not_included_pred_labels(
     )
 
 
+@pytest.fixture
+def data_config_pred_labels(
+    sagemaker_session,
+    pred_data_path,
+    data_path,
+    headers,
+    pred_label_headers,
+):
+    test_run = utils.unique_name_from_base("test_run")
+    output_path = "s3://{}/{}/{}".format(
+        sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run
+    )
+    pred_label_data_s3_uri = "s3://{}/{}/{}/{}".format(
+        sagemaker_session.default_bucket(),
+        "linear_learner_analysis_resources",
+        test_run,
+        "predicted_labels.csv",
+    )
+    _upload_dataset(pred_data_path, pred_label_data_s3_uri, sagemaker_session)
+    return DataConfig(
+        s3_data_input_path=data_path,
+        s3_output_path=output_path,
+        label="Label",
+        headers=headers,
+        dataset_type="text/csv",
+        predicted_label_dataset_uri=pred_label_data_s3_uri,
+        predicted_label_headers=pred_label_headers,
+        predicted_label="PredictedLabel",
+    )
+
+
 @pytest.fixture(scope="module")
 def data_bias_config():
     return BiasConfig(
@@ -692,6 +723,39 @@ def test_post_training_bias_excluded_columns(
     check_analysis_config(data_config_excluded_columns, sagemaker_session, "post_training_bias")
 
 
+def test_post_training_bias_predicted_labels(
+    clarify_processor,
+    data_config_pred_labels,
+    data_bias_config,
+    model_predicted_label_config,
+    sagemaker_session,
+):
+    model_config = None
+    with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
+        clarify_processor.run_post_training_bias(
+            data_config_pred_labels,
+            data_bias_config,
+            model_config,
+            model_predicted_label_config,
+            job_name=utils.unique_name_from_base("clarify-posttraining-bias-pred-labels"),
+            wait=True,
+        )
+        analysis_result_json = s3.S3Downloader.read_file(
+            data_config_pred_labels.s3_output_path + "/analysis.json",
+            sagemaker_session,
+        )
+        analysis_result = json.loads(analysis_result_json)
+        assert (
+            math.fabs(
+                analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
+                    "value"
+                ]
+            )
+            <= 1.0
+        )
+        check_analysis_config(data_config_pred_labels, sagemaker_session, "post_training_bias")
+
+
 def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session):
     with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
         clarify_processor.run_explainability(
diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py
index de482997ef..d1d05c41c3 100644
--- a/tests/unit/test_clarify.py
+++ b/tests/unit/test_clarify.py
@@ -1575,3 +1575,89 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
         },
     }
     assert actual == expected
+
+
+def test_analysis_config_for_bias_no_model_config(data_bias_config):
+    s3_data_input_path = "s3://path/to/input.csv"
+    s3_output_path = "s3://path/to/output"
+    predicted_labels_uri = "s3://path/to/predicted_labels.csv"
+    label_name = "Label"
+    headers = [
+        "Label",
+        "F1",
+        "F2",
+        "F3",
+        "F4",
+    ]
+    dataset_type = "text/csv"
+    data_config = DataConfig(
+        s3_data_input_path=s3_data_input_path,
+        s3_output_path=s3_output_path,
+        label=label_name,
+        headers=headers,
+        dataset_type=dataset_type,
+        predicted_label_dataset_uri=predicted_labels_uri,
+        predicted_label_headers=["PredictedLabel"],
+        predicted_label="PredictedLabel",
+    )
+    model_config = None
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        probability_threshold=0.8,
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias(
+        data_config,
+        data_bias_config,
+        model_config,
+        model_predicted_label_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {
+        "dataset_type": "text/csv",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "label": "Label",
+        "predicted_label_dataset_uri": "s3://path/to/predicted_labels.csv",
+        "predicted_label_headers": ["PredictedLabel"],
+        "predicted_label": "PredictedLabel",
+        "label_values_or_threshold": [1],
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "pre_training_bias": {"methods": "all"},
+            "post_training_bias": {"methods": "all"},
+        },
+        "probability_threshold": 0.8,
+    }
+    assert actual == expected
+
+
+def test_invalid_analysis_config(data_config, data_bias_config, model_config):
+    with pytest.raises(
+        ValueError, match="model_config must be provided when explainability methods are selected."
+    ):
+        _AnalysisConfigGenerator.bias_and_explainability(
+            data_config=data_config,
+            model_config=None,
+            model_predicted_label_config=ModelPredictedLabelConfig(),
+            explainability_config=SHAPConfig(),
+            bias_config=data_bias_config,
+            pre_training_methods="all",
+            post_training_methods="all",
+        )
+
+    with pytest.raises(
+        ValueError,
+        match="model_config must be provided when `predicted_label_dataset_uri` or "
+        "`predicted_label` are not provided in data_config.",
+    ):
+        _AnalysisConfigGenerator.bias(
+            data_config=data_config,
+            model_config=None,
+            model_predicted_label_config=ModelPredictedLabelConfig(),
+            bias_config=data_bias_config,
+            pre_training_methods="all",
+            post_training_methods="all",
+        )