@@ -1104,16 +1104,13 @@ def run_post_training_bias(
1104
1104
the Trial Component will be unassociated.
1105
1105
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1106
1106
""" # noqa E501 # pylint: disable=c0301
1107
- analysis_config = data_config .get_config ()
1108
- analysis_config .update (data_bias_config .get_config ())
1109
- (
1110
- probability_threshold ,
1111
- predictor_config ,
1112
- ) = model_predicted_label_config .get_predictor_config ()
1113
- predictor_config .update (model_config .get_predictor_config ())
1114
- analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1115
- analysis_config ["predictor" ] = predictor_config
1116
- _set (probability_threshold , "probability_threshold" , analysis_config )
1107
+ analysis_config = _AnalysisConfigGenerator .bias_post_training (
1108
+ data_config ,
1109
+ data_bias_config ,
1110
+ model_predicted_label_config ,
1111
+ methods ,
1112
+ model_config
1113
+ )
1117
1114
if job_name is None :
1118
1115
if self .job_name_prefix :
1119
1116
job_name = utils .name_from_base (self .job_name_prefix )
@@ -1384,6 +1381,26 @@ def bias_pre_training(data_config, data_bias_config, methods):
1384
1381
analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1385
1382
return analysis_config
1386
1383
1384
+ @staticmethod
1385
+ def bias_post_training (
1386
+ data_config ,
1387
+ data_bias_config ,
1388
+ model_predicted_label_config ,
1389
+ methods ,
1390
+ model_config
1391
+ ):
1392
+ analysis_config = data_config .get_config ()
1393
+ analysis_config .update (data_bias_config .get_config ())
1394
+ analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1395
+ (
1396
+ probability_threshold ,
1397
+ predictor_config ,
1398
+ ) = model_predicted_label_config .get_predictor_config ()
1399
+ predictor_config .update (model_config .get_predictor_config ())
1400
+ analysis_config ["predictor" ] = predictor_config
1401
+ _set (probability_threshold , "probability_threshold" , analysis_config )
1402
+ return analysis_config
1403
+
1387
1404
@staticmethod
1388
1405
def _common (analysis_config ):
1389
1406
analysis_config ["methods" ]["report" ] = {
0 commit comments