Skip to content

Commit 3b2df62

Browse files
aws-byeldos authored and navinsoni committed
feature: extracted analysis config generation for bias
1 parent cf4b08e commit 3b2df62

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

src/sagemaker/clarify.py

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1034,7 +1034,7 @@ def _run(
10341034
def run_pre_training_bias(
10351035
self,
10361036
data_config,
1037-
data_bias_config,
1037+
bias_config,
10381038
methods="all",
10391039
wait=True,
10401040
logs=True,
@@ -1049,7 +1049,7 @@ def run_pre_training_bias(
10491049
10501050
Args:
10511051
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1052-
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1052+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
10531053
methods (str or list[str]): Selects a subset of potential metrics:
10541054
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
10551055
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1085,7 +1085,7 @@ def run_pre_training_bias(
10851085
""" # noqa E501 # pylint: disable=c0301
10861086
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
10871087
data_config,
1088-
data_bias_config,
1088+
bias_config,
10891089
methods
10901090
)
10911091
if job_name is None:
@@ -1106,7 +1106,7 @@ def run_pre_training_bias(
11061106
def run_post_training_bias(
11071107
self,
11081108
data_config,
1109-
data_bias_config,
1109+
bias_config,
11101110
model_config,
11111111
model_predicted_label_config,
11121112
methods="all",
@@ -1126,7 +1126,7 @@ def run_post_training_bias(
11261126
11271127
Args:
11281128
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1129-
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1129+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
11301130
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
11311131
endpoint to be created.
11321132
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1169,7 +1169,7 @@ def run_post_training_bias(
11691169
""" # noqa E501 # pylint: disable=c0301
11701170
analysis_config = _AnalysisConfigGenerator.bias_post_training(
11711171
data_config,
1172-
data_bias_config,
1172+
bias_config,
11731173
model_predicted_label_config,
11741174
methods,
11751175
model_config
@@ -1263,23 +1263,14 @@ def run_bias(
12631263
the Trial Component will be unassociated.
12641264
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
12651265
""" # noqa E501 # pylint: disable=c0301
1266-
analysis_config = data_config.get_config()
1267-
analysis_config.update(bias_config.get_config())
1268-
analysis_config["predictor"] = model_config.get_predictor_config()
1269-
if model_predicted_label_config:
1270-
(
1271-
probability_threshold,
1272-
predictor_config,
1273-
) = model_predicted_label_config.get_predictor_config()
1274-
if predictor_config:
1275-
analysis_config["predictor"].update(predictor_config)
1276-
if probability_threshold is not None:
1277-
analysis_config["probability_threshold"] = probability_threshold
1278-
1279-
analysis_config["methods"] = {
1280-
"pre_training_bias": {"methods": pre_training_methods},
1281-
"post_training_bias": {"methods": post_training_methods},
1282-
}
1266+
analysis_config = _AnalysisConfigGenerator.bias(
1267+
data_config,
1268+
bias_config,
1269+
model_config,
1270+
model_predicted_label_config,
1271+
pre_training_methods,
1272+
post_training_methods,
1273+
)
12831274
if job_name is None:
12841275
if self.job_name_prefix:
12851276
job_name = utils.name_from_base(self.job_name_prefix)
@@ -1438,22 +1429,22 @@ def explainability(
14381429
return analysis_config
14391430

14401431
@staticmethod
1441-
def bias_pre_training(data_config, data_bias_config, methods):
1432+
def bias_pre_training(data_config, bias_config, methods):
14421433
analysis_config = data_config.get_config()
1443-
analysis_config.update(data_bias_config.get_config())
1434+
analysis_config.update(bias_config.get_config())
14441435
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
14451436
return analysis_config
14461437

14471438
@staticmethod
14481439
def bias_post_training(
14491440
data_config,
1450-
data_bias_config,
1441+
bias_config,
14511442
model_predicted_label_config,
14521443
methods,
14531444
model_config
14541445
):
14551446
analysis_config = data_config.get_config()
1456-
analysis_config.update(data_bias_config.get_config())
1447+
analysis_config.update(bias_config.get_config())
14571448
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
14581449
(
14591450
probability_threshold,
@@ -1464,6 +1455,34 @@ def bias_post_training(
14641455
_set(probability_threshold, "probability_threshold", analysis_config)
14651456
return analysis_config
14661457

1458+
@staticmethod
1459+
def bias(
1460+
data_config,
1461+
bias_config,
1462+
model_config,
1463+
model_predicted_label_config,
1464+
pre_training_methods="all",
1465+
post_training_methods="all",
1466+
):
1467+
analysis_config = data_config.get_config()
1468+
analysis_config.update(bias_config.get_config())
1469+
analysis_config["predictor"] = model_config.get_predictor_config()
1470+
if model_predicted_label_config:
1471+
(
1472+
probability_threshold,
1473+
predictor_config,
1474+
) = model_predicted_label_config.get_predictor_config()
1475+
if predictor_config:
1476+
analysis_config["predictor"].update(predictor_config)
1477+
if probability_threshold is not None:
1478+
analysis_config["probability_threshold"] = probability_threshold
1479+
1480+
analysis_config["methods"] = {
1481+
"pre_training_bias": {"methods": pre_training_methods},
1482+
"post_training_bias": {"methods": post_training_methods},
1483+
}
1484+
return analysis_config
1485+
14671486
@staticmethod
14681487
def _common(analysis_config):
14691488
analysis_config["methods"]["report"] = {

tests/unit/test_clarify.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1347,3 +1347,33 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
13471347
'model_name': 'xgboost-model',
13481348
'probability': 'pr'}}
13491349
assert actual == expected
1350+
1351+
1352+
def test_analysis_config_generator_for_bias(data_config, data_bias_config, model_config):
1353+
model_predicted_label_config = ModelPredictedLabelConfig(
1354+
probability="pr",
1355+
label_headers=["success"],
1356+
)
1357+
actual = _AnalysisConfigGenerator.bias(
1358+
data_config,
1359+
data_bias_config,
1360+
model_config,
1361+
model_predicted_label_config,
1362+
pre_training_methods="all",
1363+
post_training_methods="all",
1364+
)
1365+
expected = {'dataset_type': 'text/csv',
1366+
'facet': [{'name_or_index': 'F1'}],
1367+
'group_variable': 'F2',
1368+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1369+
'joinsource_name_or_index': 'F4',
1370+
'label': 'Label',
1371+
'label_values_or_threshold': [1],
1372+
'methods': {'post_training_bias': {'methods': 'all'},
1373+
'pre_training_bias': {'methods': 'all'}},
1374+
'predictor': {'initial_instance_count': 1,
1375+
'instance_type': 'ml.c5.xlarge',
1376+
'label_headers': ['success'],
1377+
'model_name': 'xgboost-model',
1378+
'probability': 'pr'}}
1379+
assert actual == expected

0 commit comments

Comments
 (0)