@@ -971,7 +971,7 @@ def _run(
971
971
def run_pre_training_bias (
972
972
self ,
973
973
data_config ,
974
- data_bias_config ,
974
+ bias_config ,
975
975
methods = "all" ,
976
976
wait = True ,
977
977
logs = True ,
@@ -986,7 +986,7 @@ def run_pre_training_bias(
986
986
987
987
Args:
988
988
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
989
- data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
989
+ bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
990
990
methods (str or list[str]): Selects a subset of potential metrics:
991
991
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
992
992
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1022,7 +1022,7 @@ def run_pre_training_bias(
1022
1022
""" # noqa E501 # pylint: disable=c0301
1023
1023
analysis_config = _AnalysisConfigGenerator .bias_pre_training (
1024
1024
data_config ,
1025
- data_bias_config ,
1025
+ bias_config ,
1026
1026
methods
1027
1027
)
1028
1028
if job_name is None :
@@ -1043,7 +1043,7 @@ def run_pre_training_bias(
1043
1043
def run_post_training_bias (
1044
1044
self ,
1045
1045
data_config ,
1046
- data_bias_config ,
1046
+ bias_config ,
1047
1047
model_config ,
1048
1048
model_predicted_label_config ,
1049
1049
methods = "all" ,
@@ -1063,7 +1063,7 @@ def run_post_training_bias(
1063
1063
1064
1064
Args:
1065
1065
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1066
- data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1066
+ bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1067
1067
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1068
1068
endpoint to be created.
1069
1069
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1106,7 +1106,7 @@ def run_post_training_bias(
1106
1106
""" # noqa E501 # pylint: disable=c0301
1107
1107
analysis_config = _AnalysisConfigGenerator .bias_post_training (
1108
1108
data_config ,
1109
- data_bias_config ,
1109
+ bias_config ,
1110
1110
model_predicted_label_config ,
1111
1111
methods ,
1112
1112
model_config
@@ -1200,23 +1200,14 @@ def run_bias(
1200
1200
the Trial Component will be unassociated.
1201
1201
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1202
1202
""" # noqa E501 # pylint: disable=c0301
1203
- analysis_config = data_config .get_config ()
1204
- analysis_config .update (bias_config .get_config ())
1205
- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1206
- if model_predicted_label_config :
1207
- (
1208
- probability_threshold ,
1209
- predictor_config ,
1210
- ) = model_predicted_label_config .get_predictor_config ()
1211
- if predictor_config :
1212
- analysis_config ["predictor" ].update (predictor_config )
1213
- if probability_threshold is not None :
1214
- analysis_config ["probability_threshold" ] = probability_threshold
1215
-
1216
- analysis_config ["methods" ] = {
1217
- "pre_training_bias" : {"methods" : pre_training_methods },
1218
- "post_training_bias" : {"methods" : post_training_methods },
1219
- }
1203
+ analysis_config = _AnalysisConfigGenerator .bias (
1204
+ data_config ,
1205
+ bias_config ,
1206
+ model_config ,
1207
+ model_predicted_label_config ,
1208
+ pre_training_methods ,
1209
+ post_training_methods ,
1210
+ )
1220
1211
if job_name is None :
1221
1212
if self .job_name_prefix :
1222
1213
job_name = utils .name_from_base (self .job_name_prefix )
@@ -1375,22 +1366,22 @@ def explainability(
1375
1366
return analysis_config
1376
1367
1377
1368
@staticmethod
1378
- def bias_pre_training (data_config , data_bias_config , methods ):
1369
+ def bias_pre_training (data_config , bias_config , methods ):
1379
1370
analysis_config = data_config .get_config ()
1380
- analysis_config .update (data_bias_config .get_config ())
1371
+ analysis_config .update (bias_config .get_config ())
1381
1372
analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1382
1373
return analysis_config
1383
1374
1384
1375
@staticmethod
1385
1376
def bias_post_training (
1386
1377
data_config ,
1387
- data_bias_config ,
1378
+ bias_config ,
1388
1379
model_predicted_label_config ,
1389
1380
methods ,
1390
1381
model_config
1391
1382
):
1392
1383
analysis_config = data_config .get_config ()
1393
- analysis_config .update (data_bias_config .get_config ())
1384
+ analysis_config .update (bias_config .get_config ())
1394
1385
analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1395
1386
(
1396
1387
probability_threshold ,
@@ -1401,6 +1392,34 @@ def bias_post_training(
1401
1392
_set (probability_threshold , "probability_threshold" , analysis_config )
1402
1393
return analysis_config
1403
1394
1395
+ @staticmethod
1396
+ def bias (
1397
+ data_config ,
1398
+ bias_config ,
1399
+ model_config ,
1400
+ model_predicted_label_config ,
1401
+ pre_training_methods = "all" ,
1402
+ post_training_methods = "all" ,
1403
+ ):
1404
+ analysis_config = data_config .get_config ()
1405
+ analysis_config .update (bias_config .get_config ())
1406
+ analysis_config ["predictor" ] = model_config .get_predictor_config ()
1407
+ if model_predicted_label_config :
1408
+ (
1409
+ probability_threshold ,
1410
+ predictor_config ,
1411
+ ) = model_predicted_label_config .get_predictor_config ()
1412
+ if predictor_config :
1413
+ analysis_config ["predictor" ].update (predictor_config )
1414
+ if probability_threshold is not None :
1415
+ analysis_config ["probability_threshold" ] = probability_threshold
1416
+
1417
+ analysis_config ["methods" ] = {
1418
+ "pre_training_bias" : {"methods" : pre_training_methods },
1419
+ "post_training_bias" : {"methods" : post_training_methods },
1420
+ }
1421
+ return analysis_config
1422
+
1404
1423
@staticmethod
1405
1424
def _common (analysis_config ):
1406
1425
analysis_config ["methods" ]["report" ] = {
0 commit comments