
Commit 2bb714f

refactored _AnalysisConfigGenerator to re-use it for generating a config for both bias and explainability at once
1 parent: fb68b4a

File tree: 2 files changed (+209, -78 lines)

src/sagemaker/clarify.py

Lines changed: 132 additions & 77 deletions
@@ -1368,68 +1368,70 @@ def run_explainability(
             experiment_config,
         )
 
+    def run_bias_and_explainability(self):
+        """
+        TODO:
+        - add doc string
+        - add logic
+        - add tests
+        """
+        raise NotImplementedError(
+            "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
+        )
+
 
 class _AnalysisConfigGenerator:
     """
     Creates analysis_config objects for different type of runs.
     """
 
     @classmethod
-    def explainability(
+    def bias_and_explainability(
         cls,
         data_config: DataConfig,
         model_config: ModelConfig,
-        model_scores: ModelPredictedLabelConfig,
-        explainability_config: ExplainabilityConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
     ):
-        analysis_config = data_config.get_config()
-        predictor_config = model_config.get_predictor_config()
-        if isinstance(model_scores, ModelPredictedLabelConfig):
-            (
-                probability_threshold,
-                predicted_label_config,
-            ) = model_scores.get_predictor_config()
-            _set(probability_threshold, "probability_threshold", analysis_config)
-            predictor_config.update(predicted_label_config)
-        else:
-            _set(model_scores, "label", predictor_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+            explainability_config=explainability_config,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
 
-        explainability_methods = {}
-        if isinstance(explainability_config, list):
-            if len(explainability_config) == 0:
-                raise ValueError("Please provide at least one explainability config.")
-            for config in explainability_config:
-                explain_config = config.get_explainability_config()
-                explainability_methods.update(explain_config)
-            if not len(explainability_methods.keys()) == len(explainability_config):
-                raise ValueError("Duplicate explainability configs are provided")
-            if (
-                "shap" not in explainability_methods
-                and explainability_methods["pdp"].get("features", None) is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-        else:
-            if (
-                isinstance(explainability_config, PDPConfig)
-                and explainability_config.get_explainability_config()["pdp"].get("features", None)
-                is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-            explainability_methods = explainability_config.get_explainability_config()
-        analysis_config["methods"] = explainability_methods
-        analysis_config["predictor"] = predictor_config
-        return cls._common(analysis_config)
+    @classmethod
+    def explainability(
+        cls,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+    ):
+        analysis_config = data_config.analysis_config
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        analysis_config = cls._add_methods(
+            analysis_config, explainability_config=explainability_config
+        )
+        return analysis_config
 
     @classmethod
     def bias_pre_training(
         cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
     ):
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "methods": {"pre_training_bias": {"methods": methods}},
-        }
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
+        return analysis_config
 
     @classmethod
     def bias_post_training(
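
The refactor replaces the old monolithic dict literals with small builders: each one takes an analysis_config dict and returns an extended copy, so the run methods can chain them in any order. A minimal, self-contained sketch of that builder pattern (illustrative names only, not the SDK's API):

from typing import Any, Dict

def add_methods(config: Dict[str, Any], **methods: Any) -> Dict[str, Any]:
    # Merge new method entries into a copy, leaving the input dict untouched.
    merged = {**config.get("methods", {}), **{k: v for k, v in methods.items() if v is not None}}
    return {**config, "methods": merged}

def add_predictor(config: Dict[str, Any], predictor: Dict[str, Any]) -> Dict[str, Any]:
    # Attach the model's predictor section, again without mutating the input.
    return {**config, "predictor": predictor}

# Chained the way bias_and_explainability chains _add_methods and _add_predictor:
config = {"label": "Label"}  # stand-in for the DataConfig/BiasConfig merge
config = add_methods(config, pre_training_bias={"methods": "all"})
config = add_predictor(config, {"model_name": "my-model"})
# -> {'label': 'Label',
#     'methods': {'pre_training_bias': {'methods': 'all'}},
#     'predictor': {'model_name': 'my-model'}}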
@@ -1440,21 +1442,12 @@ def bias_post_training(
         methods: Union[str, List[str]],
         model_config: ModelConfig,
     ):
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": {**model_config.get_predictor_config()},
-            "methods": {"post_training_bias": {"methods": methods}},
-        }
-        if model_predicted_label_config:
-            (
-                probability_threshold,
-                predictor_config,
-            ) = model_predicted_label_config.get_predictor_config()
-            if predictor_config:
-                analysis_config["predictor"].update(predictor_config)
-            _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
 
     @classmethod
     def bias(
@@ -1466,33 +1459,95 @@ def bias(
         pre_training_methods: Union[str, List[str]] = "all",
         post_training_methods: Union[str, List[str]] = "all",
     ):
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": model_config.get_predictor_config(),
-            "methods": {
-                "pre_training_bias": {"methods": pre_training_methods},
-                "post_training_bias": {"methods": post_training_methods},
-            },
-        }
-        if model_predicted_label_config:
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
+        analysis_config = {**analysis_config}
+        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
             ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
+        else:
+            _set(model_predicted_label_config, "label", analysis_config["predictor"])
+        return analysis_config
 
-    @staticmethod
-    def _common(analysis_config):
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
+    @classmethod
+    def _add_methods(
+        cls,
+        analysis_config,
+        pre_training_methods=None,
+        post_training_methods=None,
+        explainability_config=None,
+        report=True,
+    ):
+        # validate
+        params = [pre_training_methods, post_training_methods, explainability_config]
+        if all([1 if p is None else 0 for p in params]):
+            raise AttributeError(
+                "analysis_config must have at least one working method: "
+                "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
+            )
+
+        # main logic
+        analysis_config = {**analysis_config}
+        if "methods" not in analysis_config:
+            analysis_config["methods"] = {}
+
+        if report:
+            analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
+
+        if pre_training_methods:
+            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
+
+        if post_training_methods:
+            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
+
+        if explainability_config is not None:
+            explainability_methods = cls._merge_explainability_configs(explainability_config)
+            analysis_config["methods"] = {**analysis_config["methods"], **explainability_methods}
         return analysis_config
 
+    @classmethod
+    def _merge_explainability_configs(
+        cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
+    ):
+        if isinstance(explainability_config, list):
+            explainability_methods = {}
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if not len(explainability_methods) == len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and "features" not in explainability_methods["pdp"]
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            return explainability_methods
+        if (
+            isinstance(explainability_config, PDPConfig)
+            and "features" not in explainability_config.get_explainability_config()["pdp"]
+        ):
+            raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        return explainability_config.get_explainability_config()
+
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
     """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

tests/unit/test_clarify.py

Lines changed: 77 additions & 1 deletion
@@ -1094,7 +1094,9 @@ def test_explainability_with_invalid_config(
         "initial_instance_count": 1,
     }
     with pytest.raises(
-        AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'"
+        AttributeError,
+        match="analysis_config must have at least one working method: "
+        "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`.",
     ):
         _run_test_explain(
             name_from_base,
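
The updated match string mirrors the guard at the top of _add_methods, so running explainability without any config now fails early with a deliberate message instead of a NoneType attribute error. Reduced to a plain-Python restatement (a simplification, not the SDK source verbatim):

# All three method sources left unset, as in the failing call above:
pre_training_methods = post_training_methods = explainability_config = None

params = [pre_training_methods, post_training_methods, explainability_config]
if all(p is None for p in params):  # same truth table as all([1 if p is None else 0 for p in params])
    raise AttributeError(
        "analysis_config must have at least one working method: "
        "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
    )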
@@ -1320,6 +1322,80 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
     assert actual == expected
 
 
+def test_analysis_config_generator_for_explainability_failing(data_config, model_config):
+    model_scores = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    with pytest.raises(
+        ValueError, match="PDP features must be provided when ShapConfig is not provided"
+    ):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            PDPConfig(),
+        )
+
+    with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [SHAPConfig(), SHAPConfig()],
+        )
+
+    with pytest.raises(ValueError, match="Please provide at least one explainability config."):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [],
+        )
+
+
+def test_analysis_config_generator_for_bias_explainability(
+    data_config, data_bias_config, model_config
+):
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias_and_explainability(
+        data_config,
+        model_config,
+        model_predicted_label_config,
+        [SHAPConfig(), PDPConfig()],
+        data_bias_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "pdp": {"grid_resolution": 15, "top_k_features": 10},
+            "post_training_bias": {"methods": "all"},
+            "pre_training_bias": {"methods": "all"},
+            "report": {"name": "report", "title": "Analysis Report"},
+            "shap": {"save_local_shap_values": True, "use_logit": False},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
+    assert actual == expected
+
+
 def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
     actual = _AnalysisConfigGenerator.bias_pre_training(
         data_config, data_bias_config, methods="all"