Skip to content

Commit 229784b

Browse files
committed
feature: extracted analysis config generation for bias
1 parent a1d47b7 commit 229784b

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

src/sagemaker/clarify.py

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -971,7 +971,7 @@ def _run(
971971
def run_pre_training_bias(
972972
self,
973973
data_config,
974-
data_bias_config,
974+
bias_config,
975975
methods="all",
976976
wait=True,
977977
logs=True,
@@ -986,7 +986,7 @@ def run_pre_training_bias(
986986
987987
Args:
988988
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
989-
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
989+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
990990
methods (str or list[str]): Selects a subset of potential metrics:
991991
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
992992
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1022,7 +1022,7 @@ def run_pre_training_bias(
10221022
""" # noqa E501 # pylint: disable=c0301
10231023
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
10241024
data_config,
1025-
data_bias_config,
1025+
bias_config,
10261026
methods
10271027
)
10281028
if job_name is None:
@@ -1043,7 +1043,7 @@ def run_pre_training_bias(
10431043
def run_post_training_bias(
10441044
self,
10451045
data_config,
1046-
data_bias_config,
1046+
bias_config,
10471047
model_config,
10481048
model_predicted_label_config,
10491049
methods="all",
@@ -1063,7 +1063,7 @@ def run_post_training_bias(
10631063
10641064
Args:
10651065
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1066-
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1066+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
10671067
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
10681068
endpoint to be created.
10691069
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1106,7 +1106,7 @@ def run_post_training_bias(
11061106
""" # noqa E501 # pylint: disable=c0301
11071107
analysis_config = _AnalysisConfigGenerator.bias_post_training(
11081108
data_config,
1109-
data_bias_config,
1109+
bias_config,
11101110
model_predicted_label_config,
11111111
methods,
11121112
model_config
@@ -1200,23 +1200,14 @@ def run_bias(
12001200
the Trial Component will be unassociated.
12011201
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
12021202
""" # noqa E501 # pylint: disable=c0301
1203-
analysis_config = data_config.get_config()
1204-
analysis_config.update(bias_config.get_config())
1205-
analysis_config["predictor"] = model_config.get_predictor_config()
1206-
if model_predicted_label_config:
1207-
(
1208-
probability_threshold,
1209-
predictor_config,
1210-
) = model_predicted_label_config.get_predictor_config()
1211-
if predictor_config:
1212-
analysis_config["predictor"].update(predictor_config)
1213-
if probability_threshold is not None:
1214-
analysis_config["probability_threshold"] = probability_threshold
1215-
1216-
analysis_config["methods"] = {
1217-
"pre_training_bias": {"methods": pre_training_methods},
1218-
"post_training_bias": {"methods": post_training_methods},
1219-
}
1203+
analysis_config = _AnalysisConfigGenerator.bias(
1204+
data_config,
1205+
bias_config,
1206+
model_config,
1207+
model_predicted_label_config,
1208+
pre_training_methods,
1209+
post_training_methods,
1210+
)
12201211
if job_name is None:
12211212
if self.job_name_prefix:
12221213
job_name = utils.name_from_base(self.job_name_prefix)
@@ -1375,22 +1366,22 @@ def explainability(
13751366
return analysis_config
13761367

13771368
@staticmethod
1378-
def bias_pre_training(data_config, data_bias_config, methods):
1369+
def bias_pre_training(data_config, bias_config, methods):
13791370
analysis_config = data_config.get_config()
1380-
analysis_config.update(data_bias_config.get_config())
1371+
analysis_config.update(bias_config.get_config())
13811372
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
13821373
return analysis_config
13831374

13841375
@staticmethod
13851376
def bias_post_training(
13861377
data_config,
1387-
data_bias_config,
1378+
bias_config,
13881379
model_predicted_label_config,
13891380
methods,
13901381
model_config
13911382
):
13921383
analysis_config = data_config.get_config()
1393-
analysis_config.update(data_bias_config.get_config())
1384+
analysis_config.update(bias_config.get_config())
13941385
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
13951386
(
13961387
probability_threshold,
@@ -1401,6 +1392,34 @@ def bias_post_training(
14011392
_set(probability_threshold, "probability_threshold", analysis_config)
14021393
return analysis_config
14031394

1395+
@staticmethod
1396+
def bias(
1397+
data_config,
1398+
bias_config,
1399+
model_config,
1400+
model_predicted_label_config,
1401+
pre_training_methods="all",
1402+
post_training_methods="all",
1403+
):
1404+
analysis_config = data_config.get_config()
1405+
analysis_config.update(bias_config.get_config())
1406+
analysis_config["predictor"] = model_config.get_predictor_config()
1407+
if model_predicted_label_config:
1408+
(
1409+
probability_threshold,
1410+
predictor_config,
1411+
) = model_predicted_label_config.get_predictor_config()
1412+
if predictor_config:
1413+
analysis_config["predictor"].update(predictor_config)
1414+
if probability_threshold is not None:
1415+
analysis_config["probability_threshold"] = probability_threshold
1416+
1417+
analysis_config["methods"] = {
1418+
"pre_training_bias": {"methods": pre_training_methods},
1419+
"post_training_bias": {"methods": post_training_methods},
1420+
}
1421+
return analysis_config
1422+
14041423
@staticmethod
14051424
def _common(analysis_config):
14061425
analysis_config["methods"]["report"] = {

tests/unit/test_clarify.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1347,3 +1347,33 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
13471347
'model_name': 'xgboost-model',
13481348
'probability': 'pr'}}
13491349
assert actual == expected
1350+
1351+
1352+
def test_analysis_config_generator_for_bias(data_config, data_bias_config, model_config):
1353+
model_predicted_label_config = ModelPredictedLabelConfig(
1354+
probability="pr",
1355+
label_headers=["success"],
1356+
)
1357+
actual = _AnalysisConfigGenerator.bias(
1358+
data_config,
1359+
data_bias_config,
1360+
model_config,
1361+
model_predicted_label_config,
1362+
pre_training_methods="all",
1363+
post_training_methods="all",
1364+
)
1365+
expected = {'dataset_type': 'text/csv',
1366+
'facet': [{'name_or_index': 'F1'}],
1367+
'group_variable': 'F2',
1368+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1369+
'joinsource_name_or_index': 'F4',
1370+
'label': 'Label',
1371+
'label_values_or_threshold': [1],
1372+
'methods': {'post_training_bias': {'methods': 'all'},
1373+
'pre_training_bias': {'methods': 'all'}},
1374+
'predictor': {'initial_instance_count': 1,
1375+
'instance_type': 'ml.c5.xlarge',
1376+
'label_headers': ['success'],
1377+
'model_name': 'xgboost-model',
1378+
'probability': 'pr'}}
1379+
assert actual == expected

0 commit comments

Comments
 (0)