Skip to content

Commit cf4b08e

Browse files
aws-byeldosnavinsoni
authored and committed
feature: extracted analysis config generation for bias post_training
1 parent 4fa12c1 commit cf4b08e

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,16 +1167,13 @@ def run_post_training_bias(
11671167
the Trial Component will be unassociated.
11681168
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11691169
""" # noqa E501 # pylint: disable=c0301
1170-
analysis_config = data_config.get_config()
1171-
analysis_config.update(data_bias_config.get_config())
1172-
(
1173-
probability_threshold,
1174-
predictor_config,
1175-
) = model_predicted_label_config.get_predictor_config()
1176-
predictor_config.update(model_config.get_predictor_config())
1177-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1178-
analysis_config["predictor"] = predictor_config
1179-
_set(probability_threshold, "probability_threshold", analysis_config)
1170+
analysis_config = _AnalysisConfigGenerator.bias_post_training(
1171+
data_config,
1172+
data_bias_config,
1173+
model_predicted_label_config,
1174+
methods,
1175+
model_config
1176+
)
11801177
if job_name is None:
11811178
if self.job_name_prefix:
11821179
job_name = utils.name_from_base(self.job_name_prefix)
@@ -1447,6 +1444,26 @@ def bias_pre_training(data_config, data_bias_config, methods):
14471444
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
14481445
return analysis_config
14491446

1447+
@staticmethod
1448+
def bias_post_training(
1449+
data_config,
1450+
data_bias_config,
1451+
model_predicted_label_config,
1452+
methods,
1453+
model_config
1454+
):
1455+
analysis_config = data_config.get_config()
1456+
analysis_config.update(data_bias_config.get_config())
1457+
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1458+
(
1459+
probability_threshold,
1460+
predictor_config,
1461+
) = model_predicted_label_config.get_predictor_config()
1462+
predictor_config.update(model_config.get_predictor_config())
1463+
analysis_config["predictor"] = predictor_config
1464+
_set(probability_threshold, "probability_threshold", analysis_config)
1465+
return analysis_config
1466+
14501467
@staticmethod
14511468
def _common(analysis_config):
14521469
analysis_config["methods"]["report"] = {

tests/unit/test_clarify.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,31 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
13191319
'label_values_or_threshold': [1],
13201320
'methods': {'pre_training_bias': {'methods': 'all'}}}
13211321
assert actual == expected
1322+
1323+
1324+
def test_analysis_config_generator_for_bias_post_training(data_config, data_bias_config, model_config):
1325+
model_predicted_label_config = ModelPredictedLabelConfig(
1326+
probability="pr",
1327+
label_headers=["success"],
1328+
)
1329+
actual = _AnalysisConfigGenerator.bias_post_training(
1330+
data_config,
1331+
data_bias_config,
1332+
model_predicted_label_config,
1333+
methods="all",
1334+
model_config=model_config,
1335+
)
1336+
expected = {'dataset_type': 'text/csv',
1337+
'facet': [{'name_or_index': 'F1'}],
1338+
'group_variable': 'F2',
1339+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1340+
'joinsource_name_or_index': 'F4',
1341+
'label': 'Label',
1342+
'label_values_or_threshold': [1],
1343+
'methods': {'post_training_bias': {'methods': 'all'}},
1344+
'predictor': {'initial_instance_count': 1,
1345+
'instance_type': 'ml.c5.xlarge',
1346+
'label_headers': ['success'],
1347+
'model_name': 'xgboost-model',
1348+
'probability': 'pr'}}
1349+
assert actual == expected

0 commit comments

Comments
 (0)