Skip to content

Commit 2ca03ef

Browse files
committed
feature: extracted analysis config generation for bias pre_training
1 parent 76b6ae7 commit 2ca03ef

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/sagemaker/clarify.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,9 +1020,11 @@ def run_pre_training_bias(
10201020
the Trial Component will be unassociated.
10211021
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
10221022
""" # noqa E501 # pylint: disable=c0301
1023-
analysis_config = data_config.get_config()
1024-
analysis_config.update(data_bias_config.get_config())
1025-
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1023+
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
1024+
data_config,
1025+
data_bias_config,
1026+
methods
1027+
)
10261028
if job_name is None:
10271029
if self.job_name_prefix:
10281030
job_name = utils.name_from_base(self.job_name_prefix)
@@ -1375,6 +1377,21 @@ def explainability(
13751377
analysis_config["predictor"] = predictor_config
13761378
return analysis_config
13771379

1380+
@staticmethod
1381+
def bias_pre_training(data_config, data_bias_config, methods):
1382+
analysis_config = data_config.get_config()
1383+
analysis_config.update(data_bias_config.get_config())
1384+
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1385+
return analysis_config
1386+
1387+
@staticmethod
1388+
def _common(analysis_config):
1389+
analysis_config["methods"]["report"] = {
1390+
"name": "report",
1391+
"title": "Analysis Report",
1392+
}
1393+
return analysis_config
1394+
13781395

13791396
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
13801397
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

tests/unit/test_clarify.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,3 +1303,19 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
13031303
'probability': 'pr'}}
13041304
assert actual == expected
13051305

1306+
1307+
def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
1308+
actual = _AnalysisConfigGenerator.bias_pre_training(
1309+
data_config,
1310+
data_bias_config,
1311+
methods="all"
1312+
)
1313+
expected = {'dataset_type': 'text/csv',
1314+
'facet': [{'name_or_index': 'F1'}],
1315+
'group_variable': 'F2',
1316+
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
1317+
'joinsource_name_or_index': 'F4',
1318+
'label': 'Label',
1319+
'label_values_or_threshold': [1],
1320+
'methods': {'pre_training_bias': {'methods': 'all'}}}
1321+
assert actual == expected

0 commit comments

Comments
 (0)