change: make model_config optional when predicted labels are provided for bias detection #3596

Merged · 4 commits · Feb 10, 2023 · diff below shows changes from 1 commit

32 changes: 24 additions & 8 deletions src/sagemaker/clarify.py
@@ -1423,8 +1423,8 @@ def run_post_training_bias(
         self,
         data_config: DataConfig,
         data_bias_config: BiasConfig,
-        model_config: ModelConfig,
-        model_predicted_label_config: ModelPredictedLabelConfig,
+        model_config: Optional[ModelConfig] = None,
+        model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
         methods: Union[str, List[str]] = "all",
         wait: bool = True,
         logs: bool = True,
@@ -1444,7 +1444,8 @@ def run_post_training_bias(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
-                endpoint to be created.
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Config of how to extract the predicted label from the model output.
             methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1508,7 +1509,7 @@ def run_bias(
         self,
         data_config: DataConfig,
         bias_config: BiasConfig,
-        model_config: ModelConfig,
+        model_config: Optional[ModelConfig] = None,
         model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
         pre_training_methods: Union[str, List[str]] = "all",
         post_training_methods: Union[str, List[str]] = "all",
@@ -1529,7 +1530,8 @@ def run_bias(
             data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
             bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
             model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
-                endpoint to be created.
+                endpoint to be created. This is required unless ``predicted_label_dataset_uri`` or
+                ``predicted_label`` is provided in ``data_config``.
             model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
                 Config of how to extract the predicted label from the model output.
             pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1930,16 +1932,30 @@ def _add_predictor(
     ):
         """Extends analysis config with predictor."""
         analysis_config = {**analysis_config}
-        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_config, ModelConfig):
+            analysis_config["predictor"] = model_config.get_predictor_config()
+        else:
+            if "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"]:
+                raise ValueError(
+                    "model_config must be provided when explainability methods are selected."
+                )
+            if (
+                "predicted_label_dataset_uri" not in analysis_config
+                and "predicted_label" not in analysis_config
+            ):
+                raise ValueError(
+                    "model_config must be provided when `predicted_label_dataset_uri` or "
+                    "`predicted_label` are not provided in data_config."
+                )
         if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
             ) = model_predicted_label_config.get_predictor_config()
-            if predictor_config:
+            if predictor_config and "predictor" in analysis_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        else:
+        elif "predictor" in analysis_config:
             _set(model_predicted_label_config, "label", analysis_config["predictor"])
         return analysis_config

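With these changes, a bias job can run entirely off pre-computed predictions, with no model endpoint created. Below is a minimal usage sketch of the relaxed API; the role ARN, bucket paths, instance settings, and column names are illustrative placeholders, not values taken from this PR.

from sagemaker.clarify import (
    BiasConfig,
    DataConfig,
    SageMakerClarifyProcessor,
)

# Placeholder processor settings; any valid role/instance combination works.
processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# The predictions already live in S3, so no endpoint needs to be stood up.
data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/clarify/dataset.csv",
    s3_output_path="s3://my-bucket/clarify/output",
    label="Label",
    headers=["Label", "F1", "F2", "F3", "F4"],
    dataset_type="text/csv",
    predicted_label_dataset_uri="s3://my-bucket/clarify/predicted_labels.csv",
    predicted_label_headers=["PredictedLabel"],
    predicted_label="PredictedLabel",
)

bias_config = BiasConfig(label_values_or_threshold=[1], facet_name="F1")

# model_config (and model_predicted_label_config) may now be omitted because
# data_config already supplies the predicted labels.
processor.run_post_training_bias(
    data_config=data_config,
    data_bias_config=bias_config,
    model_config=None,
)

Internally, _add_predictor now only injects a predictor section when a ModelConfig instance is passed; otherwise it verifies that the analysis config carries predicted_label_dataset_uri or predicted_label and that no explainability methods were requested.
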
64 changes: 64 additions & 0 deletions tests/integ/test_clarify.py
@@ -474,6 +474,37 @@ def data_config_facets_not_included_pred_labels(
     )
 
 
+@pytest.fixture
+def data_config_pred_labels(
+    sagemaker_session,
+    pred_data_path,
+    data_path,
+    headers,
+    pred_label_headers,
+):
+    test_run = utils.unique_name_from_base("test_run")
+    output_path = "s3://{}/{}/{}".format(
+        sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run
+    )
+    pred_label_data_s3_uri = "s3://{}/{}/{}/{}".format(
+        sagemaker_session.default_bucket(),
+        "linear_learner_analysis_resources",
+        test_run,
+        "predicted_labels.csv",
+    )
+    _upload_dataset(pred_data_path, pred_label_data_s3_uri, sagemaker_session)
+    return DataConfig(
+        s3_data_input_path=data_path,
+        s3_output_path=output_path,
+        label="Label",
+        headers=headers,
+        dataset_type="text/csv",
+        predicted_label_dataset_uri=pred_label_data_s3_uri,
+        predicted_label_headers=pred_label_headers,
+        predicted_label="PredictedLabel",
+    )
+
+
 @pytest.fixture(scope="module")
 def data_bias_config():
     return BiasConfig(
@@ -692,6 +723,39 @@ def test_post_training_bias_excluded_columns(
         check_analysis_config(data_config_excluded_columns, sagemaker_session, "post_training_bias")
 
 
+def test_post_training_bias_predicted_labels(
+    clarify_processor,
+    data_config_pred_labels,
+    data_bias_config,
+    model_predicted_label_config,
+    sagemaker_session,
+):
+    model_config = None
+    with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
+        clarify_processor.run_post_training_bias(
+            data_config_pred_labels,
+            data_bias_config,
+            model_config,
+            model_predicted_label_config,
+            job_name=utils.unique_name_from_base("clarify-posttraining-bias-pred-labels"),
+            wait=True,
+        )
+        analysis_result_json = s3.S3Downloader.read_file(
+            data_config_pred_labels.s3_output_path + "/analysis.json",
+            sagemaker_session,
+        )
+        analysis_result = json.loads(analysis_result_json)
+        assert (
+            math.fabs(
+                analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
+                    "value"
+                ]
+            )
+            <= 1.0
+        )
+        check_analysis_config(data_config_pred_labels, sagemaker_session, "post_training_bias")
+
+
 def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session):
     with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
         clarify_processor.run_explainability(
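The assertion in the new integration test walks a nested structure inside the generated analysis.json. A rough sketch of the path it reads is shown below; the bracketed keys come from the test itself, while the metric name, the extra nesting details, and the numeric value are assumptions for illustration only.

# Hypothetical excerpt of analysis.json, trimmed to the path the test reads.
analysis_result = {
    "post_training_bias_metrics": {
        "facets": {
            "F1": [  # one entry per analyzed facet value
                {
                    "metrics": [
                        {"name": "DPPL", "value": 0.12},  # assumed metric entry
                    ],
                },
            ],
        },
    },
}

# The test takes the first metric of facet "F1" and checks its magnitude is at most 1.0.
value = analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0]["value"]
assert abs(value) <= 1.0
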
86 changes: 86 additions & 0 deletions tests/unit/test_clarify.py
@@ -1575,3 +1575,89 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
         },
     }
     assert actual == expected
+
+
+def test_analysis_config_for_bias_no_model_config(data_bias_config):
+    s3_data_input_path = "s3://path/to/input.csv"
+    s3_output_path = "s3://path/to/output"
+    predicted_labels_uri = "s3://path/to/predicted_labels.csv"
+    label_name = "Label"
+    headers = [
+        "Label",
+        "F1",
+        "F2",
+        "F3",
+        "F4",
+    ]
+    dataset_type = "text/csv"
+    data_config = DataConfig(
+        s3_data_input_path=s3_data_input_path,
+        s3_output_path=s3_output_path,
+        label=label_name,
+        headers=headers,
+        dataset_type=dataset_type,
+        predicted_label_dataset_uri=predicted_labels_uri,
+        predicted_label_headers=["PredictedLabel"],
+        predicted_label="PredictedLabel",
+    )
+    model_config = None
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        probability_threshold=0.8,
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias(
+        data_config,
+        data_bias_config,
+        model_config,
+        model_predicted_label_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {
+        "dataset_type": "text/csv",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "label": "Label",
+        "predicted_label_dataset_uri": "s3://path/to/predicted_labels.csv",
+        "predicted_label_headers": ["PredictedLabel"],
+        "predicted_label": "PredictedLabel",
+        "label_values_or_threshold": [1],
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "pre_training_bias": {"methods": "all"},
+            "post_training_bias": {"methods": "all"},
+        },
+        "probability_threshold": 0.8,
+    }
+    assert actual == expected
+
+
+def test_invalid_analysis_config(data_config, data_bias_config, model_config):
+    with pytest.raises(
+        ValueError, match="model_config must be provided when explainability methods are selected."
+    ):
+        _AnalysisConfigGenerator.bias_and_explainability(
+            data_config=data_config,
+            model_config=None,
+            model_predicted_label_config=ModelPredictedLabelConfig(),
+            explainability_config=SHAPConfig(),
+            bias_config=data_bias_config,
+            pre_training_methods="all",
+            post_training_methods="all",
+        )
+
+    with pytest.raises(
+        ValueError,
+        match="model_config must be provided when `predicted_label_dataset_uri` or "
+        "`predicted_label` are not provided in data_config.",
+    ):
+        _AnalysisConfigGenerator.bias(
+            data_config=data_config,
+            model_config=None,
+            model_predicted_label_config=ModelPredictedLabelConfig(),
+            bias_config=data_bias_config,
+            pre_training_methods="all",
+            post_training_methods="all",
+        )
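
Translated to the public entry points, the second unit case above is the error a caller would hit when neither a model nor pre-computed predictions are available. A hedged sketch follows, reusing the placeholder processor and bias_config from the earlier example; plain_data_config is a hypothetical DataConfig with no predicted_label fields.

# Sketch only: plain_data_config is assumed to lack predicted_label_dataset_uri
# and predicted_label, and no model endpoint is configured.
try:
    processor.run_bias(
        data_config=plain_data_config,
        bias_config=bias_config,
        model_config=None,
    )
except ValueError as err:
    # Expected to report that model_config is required when data_config does not
    # provide predicted_label_dataset_uri or predicted_label.
    print(err)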