@@ -1368,68 +1368,70 @@ def run_explainability(
1368
1368
experiment_config ,
1369
1369
)
1370
1370
1371
+ def run_bias_and_explainability (self ):
1372
+ """
1373
+ TODO:
1374
+ - add doc string
1375
+ - add logic
1376
+ - add tests
1377
+ """
1378
+ raise NotImplementedError (
1379
+ "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1380
+ )
1381
+
1371
1382
1372
1383
class _AnalysisConfigGenerator :
1373
1384
"""
1374
1385
Creates analysis_config objects for different type of runs.
1375
1386
"""
1376
1387
1377
1388
@classmethod
1378
- def explainability (
1389
+ def bias_and_explainability (
1379
1390
cls ,
1380
1391
data_config : DataConfig ,
1381
1392
model_config : ModelConfig ,
1382
- model_scores : ModelPredictedLabelConfig ,
1383
- explainability_config : ExplainabilityConfig ,
1393
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1394
+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1395
+ bias_config : BiasConfig ,
1396
+ pre_training_methods : Union [str , List [str ]] = "all" ,
1397
+ post_training_methods : Union [str , List [str ]] = "all" ,
1384
1398
):
1385
- analysis_config = data_config .get_config ()
1386
- predictor_config = model_config . get_predictor_config ()
1387
- if isinstance ( model_scores , ModelPredictedLabelConfig ):
1388
- (
1389
- probability_threshold ,
1390
- predicted_label_config ,
1391
- ) = model_scores . get_predictor_config ( )
1392
- _set ( probability_threshold , "probability_threshold" , analysis_config )
1393
- predictor_config . update ( predicted_label_config )
1394
- else :
1395
- _set ( model_scores , "label" , predictor_config )
1399
+ analysis_config = { ** data_config .get_config (), ** bias_config . get_config ()}
1400
+ analysis_config = cls . _add_methods (
1401
+ analysis_config ,
1402
+ pre_training_methods = pre_training_methods ,
1403
+ post_training_methods = post_training_methods ,
1404
+ explainability_config = explainability_config ,
1405
+ )
1406
+ analysis_config = cls . _add_predictor (
1407
+ analysis_config , model_config , model_predicted_label_config
1408
+ )
1409
+ return analysis_config
1396
1410
1397
- explainability_methods = {}
1398
- if isinstance (explainability_config , list ):
1399
- if len (explainability_config ) == 0 :
1400
- raise ValueError ("Please provide at least one explainability config." )
1401
- for config in explainability_config :
1402
- explain_config = config .get_explainability_config ()
1403
- explainability_methods .update (explain_config )
1404
- if not len (explainability_methods .keys ()) == len (explainability_config ):
1405
- raise ValueError ("Duplicate explainability configs are provided" )
1406
- if (
1407
- "shap" not in explainability_methods
1408
- and explainability_methods ["pdp" ].get ("features" , None ) is None
1409
- ):
1410
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1411
- else :
1412
- if (
1413
- isinstance (explainability_config , PDPConfig )
1414
- and explainability_config .get_explainability_config ()["pdp" ].get ("features" , None )
1415
- is None
1416
- ):
1417
- raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1418
- explainability_methods = explainability_config .get_explainability_config ()
1419
- analysis_config ["methods" ] = explainability_methods
1420
- analysis_config ["predictor" ] = predictor_config
1421
- return cls ._common (analysis_config )
1411
+ @classmethod
1412
+ def explainability (
1413
+ cls ,
1414
+ data_config : DataConfig ,
1415
+ model_config : ModelConfig ,
1416
+ model_predicted_label_config : ModelPredictedLabelConfig ,
1417
+ explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]],
1418
+ ):
1419
+ analysis_config = data_config .analysis_config
1420
+ analysis_config = cls ._add_predictor (
1421
+ analysis_config , model_config , model_predicted_label_config
1422
+ )
1423
+ analysis_config = cls ._add_methods (
1424
+ analysis_config , explainability_config = explainability_config
1425
+ )
1426
+ return analysis_config
1422
1427
1423
1428
@classmethod
1424
1429
def bias_pre_training (
1425
1430
cls , data_config : DataConfig , bias_config : BiasConfig , methods : Union [str , List [str ]]
1426
1431
):
1427
- analysis_config = {
1428
- ** data_config .get_config (),
1429
- ** bias_config .get_config (),
1430
- "methods" : {"pre_training_bias" : {"methods" : methods }},
1431
- }
1432
- return cls ._common (analysis_config )
1432
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1433
+ analysis_config = cls ._add_methods (analysis_config , pre_training_methods = methods )
1434
+ return analysis_config
1433
1435
1434
1436
@classmethod
1435
1437
def bias_post_training (
@@ -1440,21 +1442,12 @@ def bias_post_training(
1440
1442
methods : Union [str , List [str ]],
1441
1443
model_config : ModelConfig ,
1442
1444
):
1443
- analysis_config = {
1444
- ** data_config .get_config (),
1445
- ** bias_config .get_config (),
1446
- "predictor" : {** model_config .get_predictor_config ()},
1447
- "methods" : {"post_training_bias" : {"methods" : methods }},
1448
- }
1449
- if model_predicted_label_config :
1450
- (
1451
- probability_threshold ,
1452
- predictor_config ,
1453
- ) = model_predicted_label_config .get_predictor_config ()
1454
- if predictor_config :
1455
- analysis_config ["predictor" ].update (predictor_config )
1456
- _set (probability_threshold , "probability_threshold" , analysis_config )
1457
- return cls ._common (analysis_config )
1445
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1446
+ analysis_config = cls ._add_methods (analysis_config , post_training_methods = methods )
1447
+ analysis_config = cls ._add_predictor (
1448
+ analysis_config , model_config , model_predicted_label_config
1449
+ )
1450
+ return analysis_config
1458
1451
1459
1452
@classmethod
1460
1453
def bias (
@@ -1466,33 +1459,95 @@ def bias(
1466
1459
pre_training_methods : Union [str , List [str ]] = "all" ,
1467
1460
post_training_methods : Union [str , List [str ]] = "all" ,
1468
1461
):
1469
- analysis_config = {
1470
- ** data_config .get_config (),
1471
- ** bias_config .get_config (),
1472
- "predictor" : model_config .get_predictor_config (),
1473
- "methods" : {
1474
- "pre_training_bias" : {"methods" : pre_training_methods },
1475
- "post_training_bias" : {"methods" : post_training_methods },
1476
- },
1477
- }
1478
- if model_predicted_label_config :
1462
+ analysis_config = {** data_config .get_config (), ** bias_config .get_config ()}
1463
+ analysis_config = cls ._add_methods (
1464
+ analysis_config ,
1465
+ pre_training_methods = pre_training_methods ,
1466
+ post_training_methods = post_training_methods ,
1467
+ )
1468
+ analysis_config = cls ._add_predictor (
1469
+ analysis_config , model_config , model_predicted_label_config
1470
+ )
1471
+ return analysis_config
1472
+
1473
+ @classmethod
1474
+ def _add_predictor (cls , analysis_config , model_config , model_predicted_label_config ):
1475
+ analysis_config = {** analysis_config }
1476
+ analysis_config ["predictor" ] = model_config .get_predictor_config ()
1477
+ if isinstance (model_predicted_label_config , ModelPredictedLabelConfig ):
1479
1478
(
1480
1479
probability_threshold ,
1481
1480
predictor_config ,
1482
1481
) = model_predicted_label_config .get_predictor_config ()
1483
1482
if predictor_config :
1484
1483
analysis_config ["predictor" ].update (predictor_config )
1485
1484
_set (probability_threshold , "probability_threshold" , analysis_config )
1486
- return cls ._common (analysis_config )
1485
+ else :
1486
+ _set (model_predicted_label_config , "label" , analysis_config ["predictor" ])
1487
+ return analysis_config
1487
1488
1488
- @staticmethod
1489
- def _common (analysis_config ):
1490
- analysis_config ["methods" ]["report" ] = {
1491
- "name" : "report" ,
1492
- "title" : "Analysis Report" ,
1493
- }
1489
+ @classmethod
1490
+ def _add_methods (
1491
+ cls ,
1492
+ analysis_config ,
1493
+ pre_training_methods = None ,
1494
+ post_training_methods = None ,
1495
+ explainability_config = None ,
1496
+ report = True ,
1497
+ ):
1498
+ # validate
1499
+ params = [pre_training_methods , post_training_methods , explainability_config ]
1500
+ if all ([1 if p is None else 0 for p in params ]):
1501
+ raise AttributeError (
1502
+ "analysis_config must have at least one working method: "
1503
+ "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
1504
+ )
1505
+
1506
+ # main logic
1507
+ analysis_config = {** analysis_config }
1508
+ if "methods" not in analysis_config :
1509
+ analysis_config ["methods" ] = {}
1510
+
1511
+ if report :
1512
+ analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
1513
+
1514
+ if pre_training_methods :
1515
+ analysis_config ["methods" ]["pre_training_bias" ] = {"methods" : pre_training_methods }
1516
+
1517
+ if post_training_methods :
1518
+ analysis_config ["methods" ]["post_training_bias" ] = {"methods" : post_training_methods }
1519
+
1520
+ if explainability_config is not None :
1521
+ explainability_methods = cls ._merge_explainability_configs (explainability_config )
1522
+ analysis_config ["methods" ] = {** analysis_config ["methods" ], ** explainability_methods }
1494
1523
return analysis_config
1495
1524
1525
+ @classmethod
1526
+ def _merge_explainability_configs (
1527
+ cls , explainability_config : Union [ExplainabilityConfig , List [ExplainabilityConfig ]]
1528
+ ):
1529
+ if isinstance (explainability_config , list ):
1530
+ explainability_methods = {}
1531
+ if len (explainability_config ) == 0 :
1532
+ raise ValueError ("Please provide at least one explainability config." )
1533
+ for config in explainability_config :
1534
+ explain_config = config .get_explainability_config ()
1535
+ explainability_methods .update (explain_config )
1536
+ if not len (explainability_methods ) == len (explainability_config ):
1537
+ raise ValueError ("Duplicate explainability configs are provided" )
1538
+ if (
1539
+ "shap" not in explainability_methods
1540
+ and "features" not in explainability_methods ["pdp" ]
1541
+ ):
1542
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1543
+ return explainability_methods
1544
+ if (
1545
+ isinstance (explainability_config , PDPConfig )
1546
+ and "features" not in explainability_config .get_explainability_config ()["pdp" ]
1547
+ ):
1548
+ raise ValueError ("PDP features must be provided when ShapConfig is not provided" )
1549
+ return explainability_config .get_explainability_config ()
1550
+
1496
1551
1497
1552
def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
1498
1553
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
0 commit comments