@@ -1369,68 +1369,70 @@ def run_explainability(
1369
1369
experiment_config ,
1370
1370
)
1371
1371
1372
def run_bias_and_explainability(self):
    """Placeholder for a combined bias and explainability run.

    The method is not implemented yet and unconditionally raises.

    TODO:
        - write the full doc string
        - add the logic
        - add tests

    Raises:
        NotImplementedError: Always. The message points callers to the
            dedicated ``run_pre_training_bias``, ``run_post_training_bias``
            and ``run_explainability`` methods.
    """
    message = (
        "Please choose a method of run_pre_training_bias, "
        "run_post_training_bias or run_explainability."
    )
    raise NotImplementedError(message)
1382
+
1372
1383
1373
1384
class _AnalysisConfigGenerator :
1374
1385
"""Creates analysis_config objects for different type of runs."""
1375
1386
1387
@classmethod
def bias_and_explainability(
    cls,
    data_config: DataConfig,
    model_config: ModelConfig,
    model_predicted_label_config: ModelPredictedLabelConfig,
    explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
    bias_config: BiasConfig,
    pre_training_methods: Union[str, List[str]] = "all",
    post_training_methods: Union[str, List[str]] = "all",
):
    """Generates a config for Bias and Explainability

    Args:
        data_config: Supplies the base entries via ``get_config()``.
        model_config: Passed to ``_add_predictor`` to build the "predictor" section.
        model_predicted_label_config: Passed to ``_add_predictor`` together with
            ``model_config``.
        explainability_config: A single explainability config or a list of them,
            forwarded to ``_add_methods``.
        bias_config: Supplies entries via ``get_config()``; on a key clash these
            override the ones coming from ``data_config``.
        pre_training_methods: Forwarded to ``_add_methods`` (default ``"all"``).
        post_training_methods: Forwarded to ``_add_methods`` (default ``"all"``).

    Returns:
        dict: The analysis config extended with "methods" and "predictor" sections.
    """
    # Bias entries are applied second so they win over data entries on a key clash.
    base_config = {}
    base_config.update(data_config.get_config())
    base_config.update(bias_config.get_config())

    with_methods = cls._add_methods(
        base_config,
        pre_training_methods=pre_training_methods,
        post_training_methods=post_training_methods,
        explainability_config=explainability_config,
    )
    return cls._add_predictor(with_methods, model_config, model_predicted_label_config)
1409
+
1376
1410
@classmethod
def explainability(
    cls,
    data_config: DataConfig,
    model_config: ModelConfig,
    model_predicted_label_config: ModelPredictedLabelConfig,
    explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
):
    """Generates a config for Explainability

    Args:
        data_config: Supplies the base entries via ``get_config()``.
        model_config: Passed to ``_add_predictor`` to build the "predictor" section.
        model_predicted_label_config: Passed to ``_add_predictor`` together with
            ``model_config``.
        explainability_config: A single explainability config or a list of them,
            forwarded to ``_add_methods``.

    Returns:
        dict: The analysis config extended with "predictor" and "methods" sections.
    """
    # Go through the same accessor as the sibling generators rather than reading
    # the raw ``analysis_config`` attribute.
    # NOTE(review): presumably ``get_config()`` returns a detached copy, which also
    # keeps the result from aliasing the DataConfig's internal state; confirm.
    analysis_config = data_config.get_config()
    analysis_config = cls._add_predictor(
        analysis_config, model_config, model_predicted_label_config
    )
    analysis_config = cls._add_methods(
        analysis_config, explainability_config=explainability_config
    )
    return analysis_config
1422
1427
1423
1428
@classmethod
def bias_pre_training(
    cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
):
    """Generates a config for Bias Pre Training

    Args:
        data_config: Supplies the base entries via ``get_config()``.
        bias_config: Supplies entries via ``get_config()``; on a key clash these
            override the ones coming from ``data_config``.
        methods: Pre-training bias methods, forwarded to ``_add_methods``.

    Returns:
        dict: The analysis config extended with a "methods" section.
    """
    merged_config = {**data_config.get_config(), **bias_config.get_config()}
    return cls._add_methods(merged_config, pre_training_methods=methods)
1434
1436
1435
1437
@classmethod
1436
1438
def bias_post_training (
@@ -1442,21 +1444,12 @@ def bias_post_training(
1442
1444
model_config : ModelConfig ,
1443
1445
):
1444
1446
"""Generates a config for Bias Post Training"""
1445
- analysis_config = {
1446
- ** data_config .get_config (),
1447
- ** bias_config .get_config (),
1448
- "predictor" : {** model_config .get_predictor_config ()},
1449
- "methods" : {"post_training_bias" : {"methods" : methods }},
1450
- }
1451
- if model_predicted_label_config :
1452
- (
1453
- probability_threshold ,
1454
- predictor_config ,
1455
- ) = model_predicted_label_config .get_predictor_config ()
1456
- if predictor_config :
1457
- analysis_config ["predictor" ].update (predictor_config )
1458
- _set (probability_threshold , "probability_threshold" , analysis_config )
1459
- return cls ._common (analysis_config )
1447
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1448
+ analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1449
+ analysis_config = cls ._add_predictor (
1450
+ analysis_config , model_config , model_predicted_label_config
1451
+ )
1452
+ return analysis_config
1460
1453
1461
1454
@classmethod
1462
1455
def bias (
@@ -1469,34 +1462,95 @@ def bias(
1469
1462
post_training_methods : Union [str , List [str ]] = "all" ,
1470
1463
):
1471
1464
"""Generates a config for Bias"""
1472
- analysis_config = {
1473
- ** data_config .get_config (),
1474
- ** bias_config .get_config (),
1475
- "predictor" : model_config .get_predictor_config (),
1476
- "methods" : {
1477
- "pre_training_bias" : {"methods" : pre_training_methods },
1478
- "post_training_bias" : {"methods" : post_training_methods },
1479
- },
1480
- }
1481
- if model_predicted_label_config :
1465
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1466
+ analysis_config = cls ._add_methods (
1467
+ analysis_config ,
1468
+ pre_training_methods = pre_training_methods ,
1469
+ post_training_methods = post_training_methods ,
1470
+ )
1471
+ analysis_config = cls ._add_predictor (
1472
+ analysis_config , model_config , model_predicted_label_config
1473
+ )
1474
+ return analysis_config
1475
+
1476
@classmethod
def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
    """Extends analysis config with the "predictor" section.

    Args:
        analysis_config (dict): Config to extend. Its top level is copied, so the
            key added here does not appear on the caller's dict.
        model_config: Supplies the predictor section via ``get_predictor_config()``.
        model_predicted_label_config: Either a ``ModelPredictedLabelConfig`` or any
            other value, which is handed to ``_set`` as the predictor "label".

    Returns:
        dict: A new dict carrying the "predictor" section.
    """
    result = dict(analysis_config)
    predictor = model_config.get_predictor_config()
    result["predictor"] = predictor

    if not isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
        # Not a config object: treat the value itself as the label.
        # NOTE(review): ``_set`` presumably skips the assignment for None; confirm.
        _set(model_predicted_label_config, "label", predictor)
        return result

    probability_threshold, label_settings = model_predicted_label_config.get_predictor_config()
    if label_settings:
        predictor.update(label_settings)
    # The threshold belongs to the top level of the config, not to "predictor".
    _set(probability_threshold, "probability_threshold", result)
    return result
1499
1491
1492
@classmethod
def _add_methods(
    cls,
    analysis_config,
    pre_training_methods=None,
    post_training_methods=None,
    explainability_config=None,
    report=True,
):
    """Extends analysis config with the "methods" section.

    Args:
        analysis_config (dict): Config to extend. Neither this dict nor a
            "methods" dict it may already hold is modified.
        pre_training_methods: Value stored under
            ``methods.pre_training_bias.methods``; skipped when falsy.
        post_training_methods: Value stored under
            ``methods.post_training_bias.methods``; skipped when falsy.
        explainability_config: A single explainability config or a list of them;
            merged via ``_merge_explainability_configs`` when not None.
        report (bool): Whether to add the default "report" entry.

    Returns:
        dict: A new dict whose "methods" section holds the requested entries.

    Raises:
        AttributeError: If all of ``pre_training_methods``,
            ``post_training_methods`` and ``explainability_config`` are None.
    """
    # validate
    requested = (pre_training_methods, post_training_methods, explainability_config)
    if all(param is None for param in requested):
        # Exception type kept as-is so existing callers that catch it still work.
        raise AttributeError(
            "analysis_config must have at least one working method: "
            "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
        )

    # main logic
    # Copy the nested "methods" dict as well: a shallow copy of the outer dict
    # alone would let the writes below leak into the caller's input.
    methods = {**analysis_config.get("methods", {})}

    if report:
        methods["report"] = {"name": "report", "title": "Analysis Report"}

    # NOTE: truthiness (not ``is None``) is checked here, so an empty value
    # passes the validation above but adds no entry.
    if pre_training_methods:
        methods["pre_training_bias"] = {"methods": pre_training_methods}

    if post_training_methods:
        methods["post_training_bias"] = {"methods": post_training_methods}

    if explainability_config is not None:
        methods.update(cls._merge_explainability_configs(explainability_config))

    return {**analysis_config, "methods": methods}
1527
+
1528
@classmethod
def _merge_explainability_configs(
    cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
):
    """Merges one or more explainability configs into a single methods dict.

    Args:
        explainability_config: A single explainability config or a list of them.
            Each must provide ``get_explainability_config()`` returning a dict.

    Returns:
        dict: The explainability methods of the single config, or the union of
        the methods of all configs in the list.

    Raises:
        ValueError: If the list is empty, if two configs in the list produce the
            same method key, or if a PDP config without "features" is given
            and no "shap" method is present.
    """
    missing_pdp_features = "PDP features must be provided when ShapConfig is not provided"

    if isinstance(explainability_config, list):
        if not explainability_config:
            raise ValueError("Please provide at least one explainability config.")
        explainability_methods = {}
        for config in explainability_config:
            explainability_methods.update(config.get_explainability_config())
        # Fewer merged keys than configs means two configs shared a method key.
        if len(explainability_methods) != len(explainability_config):
            raise ValueError("Duplicate explainability configs are provided")
        # Only inspect "pdp" when it is present, so a list containing neither
        # "shap" nor "pdp" does not fail with an opaque KeyError.
        if (
            "shap" not in explainability_methods
            and "pdp" in explainability_methods
            and "features" not in explainability_methods["pdp"]
        ):
            raise ValueError(missing_pdp_features)
        return explainability_methods

    # Single config: build its dict once and reuse it for the check and the result.
    explainability_methods = explainability_config.get_explainability_config()
    if (
        isinstance(explainability_config, PDPConfig)
        and "features" not in explainability_methods["pdp"]
    ):
        raise ValueError(missing_pdp_features)
    return explainability_methods
1553
+
1500
1554
1501
1555
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
1502
1556
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
0 commit comments