Skip to content

Commit 1c9dd1d

Browse files
authored
change: make model_config optional when predicted labels are provided for bias detection (#3596)
1 parent 2bcd643 commit 1c9dd1d

File tree

3 files changed

+174
-8
lines changed

3 files changed

+174
-8
lines changed

src/sagemaker/clarify.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,8 +1423,8 @@ def run_post_training_bias(
14231423
self,
14241424
data_config: DataConfig,
14251425
data_bias_config: BiasConfig,
1426-
model_config: ModelConfig,
1427-
model_predicted_label_config: ModelPredictedLabelConfig,
1426+
model_config: Optional[ModelConfig] = None,
1427+
model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
14281428
methods: Union[str, List[str]] = "all",
14291429
wait: bool = True,
14301430
logs: bool = True,
@@ -1444,7 +1444,8 @@ def run_post_training_bias(
14441444
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
14451445
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
14461446
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1447-
endpoint to be created.
1447+
endpoint to be created. This is required unless``predicted_label_dataset_uri`` or
1448+
``predicted_label`` is provided in ``data_config``.
14481449
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
14491450
Config of how to extract the predicted label from the model output.
14501451
methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1508,7 +1509,7 @@ def run_bias(
15081509
self,
15091510
data_config: DataConfig,
15101511
bias_config: BiasConfig,
1511-
model_config: ModelConfig,
1512+
model_config: Optional[ModelConfig] = None,
15121513
model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None,
15131514
pre_training_methods: Union[str, List[str]] = "all",
15141515
post_training_methods: Union[str, List[str]] = "all",
@@ -1529,7 +1530,8 @@ def run_bias(
15291530
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
15301531
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
15311532
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1532-
endpoint to be created.
1533+
endpoint to be created. This is required unless``predicted_label_dataset_uri`` or
1534+
``predicted_label`` is provided in ``data_config``.
15331535
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
15341536
Config of how to extract the predicted label from the model output.
15351537
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
@@ -1930,16 +1932,30 @@ def _add_predictor(
19301932
):
19311933
"""Extends analysis config with predictor."""
19321934
analysis_config = {**analysis_config}
1933-
analysis_config["predictor"] = model_config.get_predictor_config()
1935+
if isinstance(model_config, ModelConfig):
1936+
analysis_config["predictor"] = model_config.get_predictor_config()
1937+
else:
1938+
if "shap" in analysis_config["methods"] or "pdp" in analysis_config["methods"]:
1939+
raise ValueError(
1940+
"model_config must be provided when explainability methods are selected."
1941+
)
1942+
if (
1943+
"predicted_label_dataset_uri" not in analysis_config
1944+
and "predicted_label" not in analysis_config
1945+
):
1946+
raise ValueError(
1947+
"model_config must be provided when `predicted_label_dataset_uri` or "
1948+
"`predicted_label` are not provided in data_config."
1949+
)
19341950
if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
19351951
(
19361952
probability_threshold,
19371953
predictor_config,
19381954
) = model_predicted_label_config.get_predictor_config()
1939-
if predictor_config:
1955+
if predictor_config and "predictor" in analysis_config:
19401956
analysis_config["predictor"].update(predictor_config)
19411957
_set(probability_threshold, "probability_threshold", analysis_config)
1942-
else:
1958+
elif "predictor" in analysis_config:
19431959
_set(model_predicted_label_config, "label", analysis_config["predictor"])
19441960
return analysis_config
19451961

tests/integ/test_clarify.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,37 @@ def data_config_facets_not_included_pred_labels(
474474
)
475475

476476

477+
@pytest.fixture
478+
def data_config_pred_labels(
479+
sagemaker_session,
480+
pred_data_path,
481+
data_path,
482+
headers,
483+
pred_label_headers,
484+
):
485+
test_run = utils.unique_name_from_base("test_run")
486+
output_path = "s3://{}/{}/{}".format(
487+
sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run
488+
)
489+
pred_label_data_s3_uri = "s3://{}/{}/{}/{}".format(
490+
sagemaker_session.default_bucket(),
491+
"linear_learner_analysis_resources",
492+
test_run,
493+
"predicted_labels.csv",
494+
)
495+
_upload_dataset(pred_data_path, pred_label_data_s3_uri, sagemaker_session)
496+
return DataConfig(
497+
s3_data_input_path=data_path,
498+
s3_output_path=output_path,
499+
label="Label",
500+
headers=headers,
501+
dataset_type="text/csv",
502+
predicted_label_dataset_uri=pred_label_data_s3_uri,
503+
predicted_label_headers=pred_label_headers,
504+
predicted_label="PredictedLabel",
505+
)
506+
507+
477508
@pytest.fixture(scope="module")
478509
def data_bias_config():
479510
return BiasConfig(
@@ -692,6 +723,39 @@ def test_post_training_bias_excluded_columns(
692723
check_analysis_config(data_config_excluded_columns, sagemaker_session, "post_training_bias")
693724

694725

726+
def test_post_training_bias_predicted_labels(
727+
clarify_processor,
728+
data_config_pred_labels,
729+
data_bias_config,
730+
model_predicted_label_config,
731+
sagemaker_session,
732+
):
733+
model_config = None
734+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
735+
clarify_processor.run_post_training_bias(
736+
data_config_pred_labels,
737+
data_bias_config,
738+
model_config,
739+
model_predicted_label_config,
740+
job_name=utils.unique_name_from_base("clarify-posttraining-bias-pred-labels"),
741+
wait=True,
742+
)
743+
analysis_result_json = s3.S3Downloader.read_file(
744+
data_config_pred_labels.s3_output_path + "/analysis.json",
745+
sagemaker_session,
746+
)
747+
analysis_result = json.loads(analysis_result_json)
748+
assert (
749+
math.fabs(
750+
analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
751+
"value"
752+
]
753+
)
754+
<= 1.0
755+
)
756+
check_analysis_config(data_config_pred_labels, sagemaker_session, "post_training_bias")
757+
758+
695759
def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session):
696760
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
697761
clarify_processor.run_explainability(

tests/unit/test_clarify.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,3 +1575,89 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
15751575
},
15761576
}
15771577
assert actual == expected
1578+
1579+
1580+
def test_analysis_config_for_bias_no_model_config(data_bias_config):
1581+
s3_data_input_path = "s3://path/to/input.csv"
1582+
s3_output_path = "s3://path/to/output"
1583+
predicted_labels_uri = "s3://path/to/predicted_labels.csv"
1584+
label_name = "Label"
1585+
headers = [
1586+
"Label",
1587+
"F1",
1588+
"F2",
1589+
"F3",
1590+
"F4",
1591+
]
1592+
dataset_type = "text/csv"
1593+
data_config = DataConfig(
1594+
s3_data_input_path=s3_data_input_path,
1595+
s3_output_path=s3_output_path,
1596+
label=label_name,
1597+
headers=headers,
1598+
dataset_type=dataset_type,
1599+
predicted_label_dataset_uri=predicted_labels_uri,
1600+
predicted_label_headers=["PredictedLabel"],
1601+
predicted_label="PredictedLabel",
1602+
)
1603+
model_config = None
1604+
model_predicted_label_config = ModelPredictedLabelConfig(
1605+
probability="pr",
1606+
probability_threshold=0.8,
1607+
label_headers=["success"],
1608+
)
1609+
actual = _AnalysisConfigGenerator.bias(
1610+
data_config,
1611+
data_bias_config,
1612+
model_config,
1613+
model_predicted_label_config,
1614+
pre_training_methods="all",
1615+
post_training_methods="all",
1616+
)
1617+
expected = {
1618+
"dataset_type": "text/csv",
1619+
"headers": ["Label", "F1", "F2", "F3", "F4"],
1620+
"label": "Label",
1621+
"predicted_label_dataset_uri": "s3://path/to/predicted_labels.csv",
1622+
"predicted_label_headers": ["PredictedLabel"],
1623+
"predicted_label": "PredictedLabel",
1624+
"label_values_or_threshold": [1],
1625+
"facet": [{"name_or_index": "F1"}],
1626+
"group_variable": "F2",
1627+
"methods": {
1628+
"report": {"name": "report", "title": "Analysis Report"},
1629+
"pre_training_bias": {"methods": "all"},
1630+
"post_training_bias": {"methods": "all"},
1631+
},
1632+
"probability_threshold": 0.8,
1633+
}
1634+
assert actual == expected
1635+
1636+
1637+
def test_invalid_analysis_config(data_config, data_bias_config, model_config):
1638+
with pytest.raises(
1639+
ValueError, match="model_config must be provided when explainability methods are selected."
1640+
):
1641+
_AnalysisConfigGenerator.bias_and_explainability(
1642+
data_config=data_config,
1643+
model_config=None,
1644+
model_predicted_label_config=ModelPredictedLabelConfig(),
1645+
explainability_config=SHAPConfig(),
1646+
bias_config=data_bias_config,
1647+
pre_training_methods="all",
1648+
post_training_methods="all",
1649+
)
1650+
1651+
with pytest.raises(
1652+
ValueError,
1653+
match="model_config must be provided when `predicted_label_dataset_uri` or "
1654+
"`predicted_label` are not provided in data_config.",
1655+
):
1656+
_AnalysisConfigGenerator.bias(
1657+
data_config=data_config,
1658+
model_config=None,
1659+
model_predicted_label_config=ModelPredictedLabelConfig(),
1660+
bias_config=data_bias_config,
1661+
pre_training_methods="all",
1662+
post_training_methods="all",
1663+
)

0 commit comments

Comments
 (0)