@@ -1367,6 +1367,9 @@ def run_explainability(
1367
1367
1368
1368
1369
1369
class _AnalysisConfigGenerator :
1370
+ """
1371
+ Creates analysis_config objects for different type of runs.
1372
+ """
1370
1373
@classmethod
1371
1374
def explainability (
1372
1375
cls ,
@@ -1397,15 +1400,15 @@ def explainability(
1397
1400
if not len (explainability_methods .keys ()) == len (explainability_config ):
1398
1401
raise ValueError ("Duplicate explainability configs are provided" )
1399
1402
if (
1400
- "shap" not in explainability_methods
1401
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1403
+ "shap" not in explainability_methods
1404
+ and explainability_methods ["pdp" ].get ("features" , None ) is None
1402
1405
):
1403
1406
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1404
1407
else :
1405
1408
if (
1406
- isinstance (explainability_config , PDPConfig )
1407
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1408
- is None
1409
+ isinstance (explainability_config , PDPConfig )
1410
+ and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1411
+ is None
1409
1412
):
1410
1413
raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1411
1414
explainability_methods = explainability_config .get_explainability_config ()
@@ -1415,9 +1418,11 @@ def explainability(
1415
1418
1416
1419
@classmethod
1417
1420
def bias_pre_training (cls , data_config , bias_config , methods ):
1418
- analysis_config = data_config .get_config ()
1419
- analysis_config .update (bias_config .get_config ())
1420
- analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
1421
+ analysis_config = {
1422
+ ** data_config .get_config (),
1423
+ ** bias_config .get_config (),
1424
+ "methods" : {"pre_training_bias" : {"methods" : methods }}
1425
+ }
1421
1426
return cls ._common (analysis_config )
1422
1427
1423
1428
@classmethod
@@ -1429,16 +1434,17 @@ def bias_post_training(
1429
1434
methods ,
1430
1435
model_config
1431
1436
):
1432
- analysis_config = data_config .get_config ()
1433
- analysis_config .update (bias_config .get_config ())
1434
- analysis_config ["methods" ] = {"post_training_bias" : {"methods" : methods }}
1435
- (
1436
- probability_threshold ,
1437
- predictor_config ,
1438
- ) = model_predicted_label_config .get_predictor_config ()
1439
- predictor_config .update (model_config .get_predictor_config ())
1440
- analysis_config ["predictor" ] = predictor_config
1441
- _set (probability_threshold , "probability_threshold" , analysis_config )
1437
+ analysis_config = {
1438
+ ** data_config .get_config (),
1439
+ ** bias_config .get_config (),
1440
+ "predictor" : {** model_config .get_predictor_config ()},
1441
+ "methods" : {"post_training_bias" : {"methods" : methods }},
1442
+ }
1443
+ if model_predicted_label_config :
1444
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1445
+ if predictor_config :
1446
+ analysis_config ["predictor" ].update (predictor_config )
1447
+ _set (probability_threshold , "probability_threshold" , analysis_config )
1442
1448
return cls ._common (analysis_config )
1443
1449
1444
1450
@classmethod
@@ -1451,23 +1457,20 @@ def bias(
1451
1457
pre_training_methods = "all" ,
1452
1458
post_training_methods = "all" ,
1453
1459
):
1454
- analysis_config = data_config .get_config ()
1455
- analysis_config .update (bias_config .get_config ())
1456
- analysis_config ["predictor" ] = model_config .get_predictor_config ()
1460
+ analysis_config = {
1461
+ ** data_config .get_config (),
1462
+ ** bias_config .get_config (),
1463
+ "predictor" : model_config .get_predictor_config (),
1464
+ "methods" : {
1465
+ "pre_training_bias" : {"methods" : pre_training_methods },
1466
+ "post_training_bias" : {"methods" : post_training_methods },
1467
+ }
1468
+ }
1457
1469
if model_predicted_label_config :
1458
- (
1459
- probability_threshold ,
1460
- predictor_config ,
1461
- ) = model_predicted_label_config .get_predictor_config ()
1470
+ probability_threshold , predictor_config = model_predicted_label_config .get_predictor_config ()
1462
1471
if predictor_config :
1463
1472
analysis_config ["predictor" ].update (predictor_config )
1464
- if probability_threshold is not None :
1465
- analysis_config ["probability_threshold" ] = probability_threshold
1466
-
1467
- analysis_config ["methods" ] = {
1468
- "pre_training_bias" : {"methods" : pre_training_methods },
1469
- "post_training_bias" : {"methods" : post_training_methods },
1470
- }
1473
+ _set (probability_threshold , "probability_threshold" , analysis_config )
1471
1474
return cls ._common (analysis_config )
1472
1475
1473
1476
@staticmethod
0 commit comments