@@ -1034,7 +1034,7 @@ def _run(
1034
1034
def run_pre_training_bias (
1035
1035
self ,
1036
1036
data_config ,
1037
- data_bias_config ,
1037
+ bias_config ,
1038
1038
methods = "all" ,
1039
1039
wait = True ,
1040
1040
logs = True ,
@@ -1049,7 +1049,7 @@ def run_pre_training_bias(
1049
1049
1050
1050
Args:
1051
1051
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1052
- data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1052
+ bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1053
1053
methods (str or list[str]): Selects a subset of potential metrics:
1054
1054
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
1055
1055
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
@@ -1085,7 +1085,7 @@ def run_pre_training_bias(
1085
1085
""" # noqa E501 # pylint: disable=c0301
1086
1086
analysis_config = _AnalysisConfigGenerator .bias_pre_training (
1087
1087
data_config ,
1088
- data_bias_config ,
1088
+ bias_config ,
1089
1089
methods
1090
1090
)
1091
1091
if job_name is None :
@@ -1106,7 +1106,7 @@ def run_pre_training_bias(
1106
1106
def run_post_training_bias (
1107
1107
self ,
1108
1108
data_config ,
1109
- data_bias_config ,
1109
+ bias_config ,
1110
1110
model_config ,
1111
1111
model_predicted_label_config ,
1112
1112
methods = "all" ,
@@ -1126,7 +1126,7 @@ def run_post_training_bias(
1126
1126
1127
1127
Args:
1128
1128
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1129
- data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1129
+ bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1130
1130
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1131
1131
endpoint to be created.
1132
1132
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
@@ -1169,7 +1169,7 @@ def run_post_training_bias(
1169
1169
""" # noqa E501 # pylint: disable=c0301
1170
1170
analysis_config = _AnalysisConfigGenerator .bias_post_training (
1171
1171
data_config ,
1172
- data_bias_config ,
1172
+ bias_config ,
1173
1173
model_predicted_label_config ,
1174
1174
methods ,
1175
1175
model_config
@@ -1263,23 +1263,14 @@ def run_bias(
1263
1263
the Trial Component will be unassociated.
1264
1264
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1265
1265
""" # noqa E501 # pylint: disable=c0301
1266
- analysis_config = data_config .get_config ()
1267
- analysis_config .update (bias_config .get_config ())
1268
- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1269
- if model_predicted_label_config :
1270
- (
1271
- probability_threshold ,
1272
- predictor_config ,
1273
- ) = model_predicted_label_config .get_predictor_config ()
1274
- if predictor_config :
1275
- analysis_config ["predictor" ].update (predictor_config )
1276
- if probability_threshold is not None :
1277
- analysis_config ["probability_threshold" ] = probability_threshold
1278
-
1279
- analysis_config ["methods" ] = {
1280
- "pre_training_bias" : {"methods" : pre_training_methods },
1281
- "post_training_bias" : {"methods" : post_training_methods },
1282
- }
1266
+ analysis_config = _AnalysisConfigGenerator .bias (
1267
+ data_config ,
1268
+ bias_config ,
1269
+ model_config ,
1270
+ model_predicted_label_config ,
1271
+ pre_training_methods ,
1272
+ post_training_methods ,
1273
+ )
1283
1274
if job_name is None :
1284
1275
if self .job_name_prefix :
1285
1276
job_name = utils .name_from_base (self .job_name_prefix )
@@ -1438,22 +1429,22 @@ def explainability(
1438
1429
return analysis_config
1439
1430
1440
1431
@staticmethod
1441
- def bias_pre_training (data_config , data_bias_config , methods ):
1432
+ def bias_pre_training (data_config , bias_config , methods ):
1442
1433
analysis_config = data_config .get_config ()
1443
- analysis_config .update (data_bias_config .get_config ())
1434
+ analysis_config .update (bias_config .get_config ())
1444
1435
analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1445
1436
return analysis_config
1446
1437
1447
1438
@staticmethod
1448
1439
def bias_post_training (
1449
1440
data_config ,
1450
- data_bias_config ,
1441
+ bias_config ,
1451
1442
model_predicted_label_config ,
1452
1443
methods ,
1453
1444
model_config
1454
1445
):
1455
1446
analysis_config = data_config .get_config ()
1456
- analysis_config .update (data_bias_config .get_config ())
1447
+ analysis_config .update (bias_config .get_config ())
1457
1448
analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1458
1449
(
1459
1450
probability_threshold ,
@@ -1464,6 +1455,34 @@ def bias_post_training(
1464
1455
_set (probability_threshold , "probability_threshold" , analysis_config )
1465
1456
return analysis_config
1466
1457
1458
+ @staticmethod
1459
+ def bias (
1460
+ data_config ,
1461
+ bias_config ,
1462
+ model_config ,
1463
+ model_predicted_label_config ,
1464
+ pre_training_methods = "all" ,
1465
+ post_training_methods = "all" ,
1466
+ ):
1467
+ analysis_config = data_config .get_config ()
1468
+ analysis_config .update (bias_config .get_config ())
1469
+ analysis_config ["predictor" ] = model_config .get_predictor_config ()
1470
+ if model_predicted_label_config :
1471
+ (
1472
+ probability_threshold ,
1473
+ predictor_config ,
1474
+ ) = model_predicted_label_config .get_predictor_config ()
1475
+ if predictor_config :
1476
+ analysis_config ["predictor" ].update (predictor_config )
1477
+ if probability_threshold is not None :
1478
+ analysis_config ["probability_threshold" ] = probability_threshold
1479
+
1480
+ analysis_config ["methods" ] = {
1481
+ "pre_training_bias" : {"methods" : pre_training_methods },
1482
+ "post_training_bias" : {"methods" : post_training_methods },
1483
+ }
1484
+ return analysis_config
1485
+
1467
1486
@staticmethod
1468
1487
def _common (analysis_config ):
1469
1488
analysis_config ["methods" ]["report" ] = {
0 commit comments