@@ -1369,70 +1369,72 @@ def run_explainability(
1369
1369
experiment_config ,
1370
1370
)
1371
1371
1372
+ def run_bias_and_explainability (self ):
1373
+ """
1374
+ TODO:
1375
+ - add doc string
1376
+ - add logic
1377
+ - add tests
1378
+ """
1379
+ raise NotImplementedError (
1380
+ "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1381
+ )
1382
+
1372
1383
1373
1384
class _AnalysisConfigGenerator :
1374
1385
"""
1375
1386
Creates analysis_config objects for different type of runs.
1376
1387
"""
1377
1388
1389
+ @classmethod
1390
+ def bias_and_explainability (
1391
+ cls ,
1392
+ data_config : DataConfig ,
1393
+ model_config : ModelConfig ,
1394
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1395
+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1396
+ bias_config : BiasConfig ,
1397
+ pre_training_methods : Union [str , List [str ]] = "all" ,
1398
+ post_training_methods : Union [str , List [str ]] = "all" ,
1399
+ ):
1400
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1401
+ analysis_config = cls ._add_methods (
1402
+ analysis_config ,
1403
+ pre_training_methods = pre_training_methods ,
1404
+ post_training_methods = post_training_methods ,
1405
+ explainability_config = explainability_config ,
1406
+ )
1407
+ analysis_config = cls ._add_predictor (
1408
+ analysis_config , model_config , model_predicted_label_config
1409
+ )
1410
+ return analysis_config
1411
+
1378
1412
@classmethod
1379
1413
def explainability (
1380
1414
cls ,
1381
1415
data_config : DataConfig ,
1382
1416
model_config : ModelConfig ,
1383
- model_scores : ModelPredictedLabelConfig ,
1384
- explainability_config : ExplainabilityConfig ,
1417
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1418
+ explainability_config : Union [ ExplainabilityConfig , List [ ExplainabilityConfig ]] ,
1385
1419
):
1386
1420
"""Generates a config for Explainability"""
1387
- analysis_config = data_config .get_config ()
1388
- predictor_config = model_config .get_predictor_config ()
1389
- if isinstance (model_scores , ModelPredictedLabelConfig ):
1390
- (
1391
- probability_threshold ,
1392
- predicted_label_config ,
1393
- ) = model_scores .get_predictor_config ()
1394
- _set (probability_threshold , "probability_threshold" , analysis_config )
1395
- predictor_config .update (predicted_label_config )
1396
- else :
1397
- _set (model_scores , "label" , predictor_config )
1398
-
1399
- explainability_methods = {}
1400
- if isinstance (explainability_config , list ):
1401
- if len (explainability_config ) == 0 :
1402
- raise ValueError ("Please provide at least one explainability config." )
1403
- for config in explainability_config :
1404
- explain_config = config .get_explainability_config ()
1405
- explainability_methods .update (explain_config )
1406
- if not len (explainability_methods .keys ()) == len (explainability_config ):
1407
- raise ValueError ("Duplicate explainability configs are provided" )
1408
- if (
1409
- "shap" not in explainability_methods
1410
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1411
- ):
1412
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1413
- else :
1414
- if (
1415
- isinstance (explainability_config , PDPConfig )
1416
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1417
- is None
1418
- ):
1419
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1420
- explainability_methods = explainability_config .get_explainability_config ()
1421
- analysis_config ["methods" ] = explainability_methods
1422
- analysis_config ["predictor" ] = predictor_config
1423
- return cls ._common (analysis_config )
1421
+ analysis_config = data_config .analysis_config
1422
+ analysis_config = cls ._add_predictor (
1423
+ analysis_config , model_config , model_predicted_label_config
1424
+ )
1425
+ analysis_config = cls ._add_methods (
1426
+ analysis_config , explainability_config = explainability_config
1427
+ )
1428
+ return analysis_config
1424
1429
1425
1430
@classmethod
1426
1431
def bias_pre_training (
1427
1432
cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [str , List [str ]]
1428
1433
):
1429
1434
"""Generates a config for Bias Pre Training"""
1430
- analysis_config = {
1431
- ** data_config .get_config (),
1432
- ** bias_config .get_config (),
1433
- "methods" : {"pre_training_bias" : {"methods" : methods }},
1434
- }
1435
- return cls ._common (analysis_config )
1435
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1436
+ analysis_config = cls ._add_methods (analysis_config , pre_training_methods = methods )
1437
+ return analysis_config
1436
1438
1437
1439
@classmethod
1438
1440
def bias_post_training (
@@ -1444,21 +1446,12 @@ def bias_post_training(
1444
1446
model_config : ModelConfig ,
1445
1447
):
1446
1448
"""Generates a config for Bias Post Training"""
1447
- analysis_config = {
1448
- ** data_config .get_config (),
1449
- ** bias_config .get_config (),
1450
- "predictor" : {** model_config .get_predictor_config ()},
1451
- "methods" : {"post_training_bias" : {"methods" : methods }},
1452
- }
1453
- if model_predicted_label_config :
1454
- (
1455
- probability_threshold ,
1456
- predictor_config ,
1457
- ) = model_predicted_label_config .get_predictor_config ()
1458
- if predictor_config :
1459
- analysis_config ["predictor" ].update (predictor_config )
1460
- _set (probability_threshold , "probability_threshold" , analysis_config )
1461
- return cls ._common (analysis_config )
1449
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1450
+ analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1451
+ analysis_config = cls ._add_predictor (
1452
+ analysis_config , model_config , model_predicted_label_config
1453
+ )
1454
+ return analysis_config
1462
1455
1463
1456
@classmethod
1464
1457
def bias (
@@ -1471,34 +1464,95 @@ def bias(
1471
1464
post_training_methods : Union [str , List [str ]] = "all" ,
1472
1465
):
1473
1466
"""Generates a config for Bias"""
1474
- analysis_config = {
1475
- ** data_config .get_config (),
1476
- ** bias_config .get_config (),
1477
- "predictor" : model_config .get_predictor_config (),
1478
- "methods" : {
1479
- "pre_training_bias" : {"methods" : pre_training_methods },
1480
- "post_training_bias" : {"methods" : post_training_methods },
1481
- },
1482
- }
1483
- if model_predicted_label_config :
1467
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1468
+ analysis_config = cls ._add_methods (
1469
+ analysis_config ,
1470
+ pre_training_methods = pre_training_methods ,
1471
+ post_training_methods = post_training_methods ,
1472
+ )
1473
+ analysis_config = cls ._add_predictor (
1474
+ analysis_config , model_config , model_predicted_label_config
1475
+ )
1476
+ return analysis_config
1477
+
1478
+ @classmethod
1479
+ def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1480
+ analysis_config = {** analysis_config }
1481
+ analysis_config ["predictor" ] = model_config .get_predictor_config ()
1482
+ if isinstance (model_predicted_label_config , ModelPredictedLabelConfig ):
1484
1483
(
1485
1484
probability_threshold ,
1486
1485
predictor_config ,
1487
1486
) = model_predicted_label_config .get_predictor_config ()
1488
1487
if predictor_config :
1489
1488
analysis_config ["predictor" ].update (predictor_config )
1490
1489
_set (probability_threshold , "probability_threshold" , analysis_config )
1491
- return cls ._common (analysis_config )
1492
-
1493
- @staticmethod
1494
- def _common (analysis_config ):
1495
- """Extends analysis config with common values"""
1496
- analysis_config ["methods" ]["report" ] = {
1497
- "name" : "report" ,
1498
- "title" : "Analysis Report" ,
1499
- }
1490
+ else :
1491
+ _set (model_predicted_label_config , "label" , analysis_config ["predictor" ])
1500
1492
return analysis_config
1501
1493
1494
+ @classmethod
1495
+ def _add_methods (
1496
+ cls ,
1497
+ analysis_config ,
1498
+ pre_training_methods = None ,
1499
+ post_training_methods = None ,
1500
+ explainability_config = None ,
1501
+ report = True ,
1502
+ ):
1503
+ # validate
1504
+ params = [pre_training_methods , post_training_methods , explainability_config ]
1505
+ if all ([1 if p is None else 0 for p in params ]):
1506
+ raise AttributeError (
1507
+ "analysis_config must have at least one working method: "
1508
+ "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
1509
+ )
1510
+
1511
+ # main logic
1512
+ analysis_config = {** analysis_config }
1513
+ if "methods" not in analysis_config :
1514
+ analysis_config ["methods" ] = {}
1515
+
1516
+ if report :
1517
+ analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
1518
+
1519
+ if pre_training_methods :
1520
+ analysis_config ["methods" ]["pre_training_bias" ] = {"methods" : pre_training_methods }
1521
+
1522
+ if post_training_methods :
1523
+ analysis_config ["methods" ]["post_training_bias" ] = {"methods" : post_training_methods }
1524
+
1525
+ if explainability_config is not None :
1526
+ explainability_methods = cls ._merge_explainability_configs (explainability_config )
1527
+ analysis_config ["methods" ] = {** analysis_config ["methods" ], ** explainability_methods }
1528
+ return analysis_config
1529
+
1530
+ @classmethod
1531
+ def _merge_explainability_configs (
1532
+ cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
1533
+ ):
1534
+ if isinstance (explainability_config , list ):
1535
+ explainability_methods = {}
1536
+ if len (explainability_config ) == 0 :
1537
+ raise ValueError ("Please provide at least one explainability config." )
1538
+ for config in explainability_config :
1539
+ explain_config = config .get_explainability_config ()
1540
+ explainability_methods .update (explain_config )
1541
+ if not len (explainability_methods ) == len (explainability_config ):
1542
+ raise ValueError ("Duplicate explainability configs are provided" )
1543
+ if (
1544
+ "shap" not in explainability_methods
1545
+ and "features" not in explainability_methods ["pdp" ]
1546
+ ):
1547
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1548
+ return explainability_methods
1549
+ if (
1550
+ isinstance (explainability_config , PDPConfig )
1551
+ and "features" not in explainability_config .get_explainability_config ()["pdp" ]
1552
+ ):
1553
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1554
+ return explainability_config .get_explainability_config ()
1555
+
1502
1556
1503
1557
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
1504
1558
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
0 commit comments