Skip to content

Commit a1d47b7

Browse files
committed
feature: extracted analysis config generation for bias post_training
1 parent 2ca03ef commit a1d47b7

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,16 +1104,13 @@ def run_post_training_bias(
11041104
the Trial Component will be unassociated.
11051105
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11061106
""" # noqa E501 # pylint: disable=c0301
1107-
analysis_config = data_config.get_config()
1108-
analysis_config.update(data_bias_config.get_config())
1109-
(
1110-
probability_threshold,
1111-
predictor_config,
1112-
) = model_predicted_label_config.get_predictor_config()
1113-
predictor_config.update(model_config.get_predictor_config())
1114-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1115-
analysis_config["predictor"] = predictor_config
1116-
_set(probability_threshold, "probability_threshold", analysis_config)
1107+
analysis_config = _AnalysisConfigGenerator.bias_post_training(
1108+
data_config,
1109+
data_bias_config,
1110+
model_predicted_label_config,
1111+
methods,
1112+
model_config
1113+
)
11171114
if job_name is None:
11181115
if self.job_name_prefix:
11191116
job_name = utils.name_from_base(self.job_name_prefix)
@@ -1384,6 +1381,26 @@ def bias_pre_training(data_config, data_bias_config, methods):
13841381
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
13851382
return analysis_config
13861383

1384+
@staticmethod
def bias_post_training(
    data_config,
    data_bias_config,
    model_predicted_label_config,
    methods,
    model_config,
):
    """Build the analysis config for a post-training bias job.

    Starts from the dict returned by ``data_config.get_config()``, merges
    in ``data_bias_config.get_config()``, then adds the ``"methods"`` and
    ``"predictor"`` sections.

    Args:
        data_config: Object whose ``get_config()`` supplies the base dict.
        data_bias_config: Object whose ``get_config()`` is merged on top.
        model_predicted_label_config: Object whose ``get_predictor_config()``
            returns a ``(probability_threshold, predictor_config)`` pair.
        methods: Value stored under
            ``["methods"]["post_training_bias"]["methods"]``.
        model_config: Object whose ``get_predictor_config()`` is merged into
            the predictor section.

    Returns:
        The dict obtained from ``data_config.get_config()``, updated in place.
    """
    config = data_config.get_config()
    config.update(data_bias_config.get_config())
    bias_methods = {"methods": methods}
    config["methods"] = {"post_training_bias": bias_methods}

    threshold, predictor = model_predicted_label_config.get_predictor_config()
    predictor.update(model_config.get_predictor_config())
    config["predictor"] = predictor

    # The threshold lives at the top level of the config, next to
    # "predictor"; _set is a module-level helper defined elsewhere.
    _set(threshold, "probability_threshold", config)
    return config
1403+
13871404
@staticmethod
13881405
def _common(analysis_config):
13891406
analysis_config["methods"]["report"] = {

tests/unit/test_clarify.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,31 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
13191319
'label_values_or_threshold': [1],
13201320
'methods': {'pre_training_bias': {'methods': 'all'}}}
13211321
assert actual == expected
1322+
1323+
1324+
def test_analysis_config_generator_for_bias_post_training(
    data_config, data_bias_config, model_config
):
    """Check the analysis config generated for post-training bias.

    The expected values below depend on the ``data_config``,
    ``data_bias_config`` and ``model_config`` fixtures defined elsewhere in
    this test module.
    """
    predicted_label_config = ModelPredictedLabelConfig(
        probability="pr",
        label_headers=["success"],
    )

    expected_predictor = {
        "initial_instance_count": 1,
        "instance_type": "ml.c5.xlarge",
        "label_headers": ["success"],
        "model_name": "xgboost-model",
        "probability": "pr",
    }
    expected = {
        "dataset_type": "text/csv",
        "facet": [{"name_or_index": "F1"}],
        "group_variable": "F2",
        "headers": ["Label", "F1", "F2", "F3", "F4"],
        "joinsource_name_or_index": "F4",
        "label": "Label",
        "label_values_or_threshold": [1],
        "methods": {"post_training_bias": {"methods": "all"}},
        "predictor": expected_predictor,
    }

    actual = _AnalysisConfigGenerator.bias_post_training(
        data_config,
        data_bias_config,
        predicted_label_config,
        methods="all",
        model_config=model_config,
    )

    assert actual == expected

0 commit comments

Comments
 (0)