@@ -1304,6 +1304,9 @@ def run_explainability(
1304
1304
1305
1305
1306
1306
class _AnalysisConfigGenerator :
1307
+ """
1308
+ Creates analysis_config objects for different type of runs.
1309
+ """
1307
1310
@classmethod
1308
1311
def explainability (
1309
1312
cls ,
@@ -1334,15 +1337,15 @@ def explainability(
1334
1337
if not len (explainability_methods .keys ()) == len (explainability_config ):
1335
1338
raise ValueError ("Duplicate explainability configs are provided" )
1336
1339
if (
1337
- "shap" not in explainability_methods
1338
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1340
+ "shap" not in explainability_methods
1341
+ and explainability_methods ["pdp" ].get ("features" , None ) is None
1339
1342
):
1340
1343
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1341
1344
else :
1342
1345
if (
1343
- isinstance (explainability_config , PDPConfig )
1344
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1345
- is None
1346
+ isinstance (explainability_config , PDPConfig )
1347
+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1348
+ is None
1346
1349
):
1347
1350
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1348
1351
explainability_methods = explainability_config .get_explainability_config ()
@@ -1352,9 +1355,11 @@ def explainability(
1352
1355
1353
1356
@classmethod
1354
1357
def bias_pre_training (cls , data_config , bias_config , methods ):
1355
- analysis_config = data_config .get_config ()
1356
- analysis_config .update (bias_config .get_config ())
1357
- analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1358
+ analysis_config = {
1359
+ ** data_config .get_config (),
1360
+ ** bias_config .get_config (),
1361
+ "methods" : {"pre_training_bias" : {"methods" : methods }}
1362
+ }
1358
1363
return cls ._common (analysis_config )
1359
1364
1360
1365
@classmethod
@@ -1366,15 +1371,13 @@ def bias_post_training(
1366
1371
methods ,
1367
1372
model_config
1368
1373
):
1369
- analysis_config = data_config .get_config ()
1370
- analysis_config .update (bias_config .get_config ())
1371
- analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1372
- (
1373
- probability_threshold ,
1374
- predictor_config ,
1375
- ) = model_predicted_label_config .get_predictor_config ()
1376
- predictor_config .update (model_config .get_predictor_config ())
1377
- analysis_config ["predictor" ] = predictor_config
1374
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1375
+ analysis_config = {
1376
+ ** data_config .get_config (),
1377
+ ** bias_config .get_config (),
1378
+ "methods" : {"post_training_bias" : {"methods" : methods }},
1379
+ "predictor" : {** predictor_config , ** model_config .get_predictor_config ()},
1380
+ }
1378
1381
_set (probability_threshold , "probability_threshold" , analysis_config )
1379
1382
return cls ._common (analysis_config )
1380
1383
@@ -1388,23 +1391,21 @@ def bias(
1388
1391
pre_training_methods = "all" ,
1389
1392
post_training_methods = "all" ,
1390
1393
):
1391
- analysis_config = data_config .get_config ()
1392
- analysis_config .update (bias_config .get_config ())
1393
- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1394
+ analysis_config = {
1395
+ ** data_config .get_config (),
1396
+ ** bias_config .get_config (),
1397
+ "predictor" : model_config .get_predictor_config (),
1398
+ "methods" : {
1399
+ "pre_training_bias" : {"methods" : pre_training_methods },
1400
+ "post_training_bias" : {"methods" : post_training_methods },
1401
+ }
1402
+ }
1394
1403
if model_predicted_label_config :
1395
- (
1396
- probability_threshold ,
1397
- predictor_config ,
1398
- ) = model_predicted_label_config .get_predictor_config ()
1404
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1399
1405
if predictor_config :
1400
1406
analysis_config ["predictor" ].update (predictor_config )
1401
1407
if probability_threshold is not None :
1402
1408
analysis_config ["probability_threshold" ] = probability_threshold
1403
-
1404
- analysis_config ["methods" ] = {
1405
- "pre_training_bias" : {"methods" : pre_training_methods },
1406
- "post_training_bias" : {"methods" : post_training_methods },
1407
- }
1408
1409
return cls ._common (analysis_config )
1409
1410
1410
1411
@staticmethod
0 commit comments