@@ -1304,6 +1304,9 @@ def run_explainability(
1304
1304
1305
1305
1306
1306
class _AnalysisConfigGenerator :
1307
+ """
1308
+ Creates analysis_config objects for different type of runs.
1309
+ """
1307
1310
@classmethod
1308
1311
def explainability (
1309
1312
cls ,
@@ -1334,15 +1337,15 @@ def explainability(
1334
1337
if not len (explainability_methods .keys ()) == len (explainability_config ):
1335
1338
raise ValueError ("Duplicate explainability configs are provided" )
1336
1339
if (
1337
- "shap" not in explainability_methods
1338
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1340
+ "shap" not in explainability_methods
1341
+ and explainability_methods ["pdp" ].get ("features" , None ) is None
1339
1342
):
1340
1343
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1341
1344
else :
1342
1345
if (
1343
- isinstance (explainability_config , PDPConfig )
1344
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1345
- is None
1346
+ isinstance (explainability_config , PDPConfig )
1347
+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1348
+ is None
1346
1349
):
1347
1350
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1348
1351
explainability_methods = explainability_config .get_explainability_config ()
@@ -1352,9 +1355,11 @@ def explainability(
1352
1355
1353
1356
@classmethod
1354
1357
def bias_pre_training (cls , data_config , bias_config , methods ):
1355
- analysis_config = data_config .get_config ()
1356
- analysis_config .update (bias_config .get_config ())
1357
- analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1358
+ analysis_config = {
1359
+ ** data_config .get_config (),
1360
+ ** bias_config .get_config (),
1361
+ "methods" : {"pre_training_bias" : {"methods" : methods }}
1362
+ }
1358
1363
return cls ._common (analysis_config )
1359
1364
1360
1365
@classmethod
@@ -1366,16 +1371,17 @@ def bias_post_training(
1366
1371
methods ,
1367
1372
model_config
1368
1373
):
1369
- analysis_config = data_config .get_config ()
1370
- analysis_config .update (bias_config .get_config ())
1371
- analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1372
- (
1373
- probability_threshold ,
1374
- predictor_config ,
1375
- ) = model_predicted_label_config .get_predictor_config ()
1376
- predictor_config .update (model_config .get_predictor_config ())
1377
- analysis_config ["predictor" ] = predictor_config
1378
- _set (probability_threshold , "probability_threshold" , analysis_config )
1374
+ analysis_config = {
1375
+ ** data_config .get_config (),
1376
+ ** bias_config .get_config (),
1377
+ "predictor" : {** model_config .get_predictor_config ()},
1378
+ "methods" : {"post_training_bias" : {"methods" : methods }},
1379
+ }
1380
+ if model_predicted_label_config :
1381
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1382
+ if predictor_config :
1383
+ analysis_config ["predictor" ].update (predictor_config )
1384
+ _set (probability_threshold , "probability_threshold" , analysis_config )
1379
1385
return cls ._common (analysis_config )
1380
1386
1381
1387
@classmethod
@@ -1388,23 +1394,20 @@ def bias(
1388
1394
pre_training_methods = "all" ,
1389
1395
post_training_methods = "all" ,
1390
1396
):
1391
- analysis_config = data_config .get_config ()
1392
- analysis_config .update (bias_config .get_config ())
1393
- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1397
+ analysis_config = {
1398
+ ** data_config .get_config (),
1399
+ ** bias_config .get_config (),
1400
+ "predictor" : model_config .get_predictor_config (),
1401
+ "methods" : {
1402
+ "pre_training_bias" : {"methods" : pre_training_methods },
1403
+ "post_training_bias" : {"methods" : post_training_methods },
1404
+ }
1405
+ }
1394
1406
if model_predicted_label_config :
1395
- (
1396
- probability_threshold ,
1397
- predictor_config ,
1398
- ) = model_predicted_label_config .get_predictor_config ()
1407
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1399
1408
if predictor_config :
1400
1409
analysis_config ["predictor" ].update (predictor_config )
1401
- if probability_threshold is not None :
1402
- analysis_config ["probability_threshold" ] = probability_threshold
1403
-
1404
- analysis_config ["methods" ] = {
1405
- "pre_training_bias" : {"methods" : pre_training_methods },
1406
- "post_training_bias" : {"methods" : post_training_methods },
1407
- }
1410
+ _set (probability_threshold , "probability_threshold" , analysis_config )
1408
1411
return cls ._common (analysis_config )
1409
1412
1410
1413
@staticmethod
0 commit comments