Commit 64a69c5

refactored _AnalysisConfigGenerator
1 parent b50a44b commit 64a69c5
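
Summary: this commit refactors _AnalysisConfigGenerator in src/sagemaker/clarify.py. The config-assembly logic that was duplicated across explainability, bias_pre_training, bias_post_training, and bias is extracted into shared helpers (_add_predictor, _add_methods, and _merge_explainability_configs, replacing the old _common), a combined bias_and_explainability generator is added, and a run_bias_and_explainability stub (currently raising NotImplementedError) is added to the processor. The unit tests are extended to cover the new generator and its validation errors.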

File tree

2 files changed (+210 lines, -80 lines)


src/sagemaker/clarify.py

Lines changed: 133 additions & 79 deletions
@@ -1369,68 +1369,70 @@ def run_explainability(
             experiment_config,
         )
 
+    def run_bias_and_explainability(self):
+        """
+        TODO:
+        - add doc string
+        - add logic
+        - add tests
+        """
+        raise NotImplementedError(
+            "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
+        )
+
 
 class _AnalysisConfigGenerator:
     """Creates analysis_config objects for different type of runs."""
 
+    @classmethod
+    def bias_and_explainability(
+        cls,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+    ):
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+            explainability_config=explainability_config,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
     @classmethod
     def explainability(
         cls,
         data_config: DataConfig,
         model_config: ModelConfig,
-        model_scores: ModelPredictedLabelConfig,
-        explainability_config: ExplainabilityConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
     ):
         """Generates a config for Explainability"""
-        analysis_config = data_config.get_config()
-        predictor_config = model_config.get_predictor_config()
-        if isinstance(model_scores, ModelPredictedLabelConfig):
-            (
-                probability_threshold,
-                predicted_label_config,
-            ) = model_scores.get_predictor_config()
-            _set(probability_threshold, "probability_threshold", analysis_config)
-            predictor_config.update(predicted_label_config)
-        else:
-            _set(model_scores, "label", predictor_config)
-
-        explainability_methods = {}
-        if isinstance(explainability_config, list):
-            if len(explainability_config) == 0:
-                raise ValueError("Please provide at least one explainability config.")
-            for config in explainability_config:
-                explain_config = config.get_explainability_config()
-                explainability_methods.update(explain_config)
-            if not len(explainability_methods.keys()) == len(explainability_config):
-                raise ValueError("Duplicate explainability configs are provided")
-            if (
-                "shap" not in explainability_methods
-                and explainability_methods["pdp"].get("features", None) is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-        else:
-            if (
-                isinstance(explainability_config, PDPConfig)
-                and explainability_config.get_explainability_config()["pdp"].get("features", None)
-                is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-            explainability_methods = explainability_config.get_explainability_config()
-        analysis_config["methods"] = explainability_methods
-        analysis_config["predictor"] = predictor_config
-        return cls._common(analysis_config)
+        analysis_config = data_config.analysis_config
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        analysis_config = cls._add_methods(
+            analysis_config, explainability_config=explainability_config
+        )
+        return analysis_config
 
     @classmethod
     def bias_pre_training(
         cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
     ):
         """Generates a config for Bias Pre Training"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "methods": {"pre_training_bias": {"methods": methods}},
-        }
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
+        return analysis_config
 
     @classmethod
     def bias_post_training(
@@ -1442,21 +1444,12 @@ def bias_post_training(
         model_config: ModelConfig,
     ):
         """Generates a config for Bias Post Training"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": {**model_config.get_predictor_config()},
-            "methods": {"post_training_bias": {"methods": methods}},
-        }
-        if model_predicted_label_config:
-            (
-                probability_threshold,
-                predictor_config,
-            ) = model_predicted_label_config.get_predictor_config()
-            if predictor_config:
-                analysis_config["predictor"].update(predictor_config)
-            _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
 
     @classmethod
     def bias(
@@ -1469,34 +1462,95 @@ def bias(
         post_training_methods: Union[str, List[str]] = "all",
     ):
         """Generates a config for Bias"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": model_config.get_predictor_config(),
-            "methods": {
-                "pre_training_bias": {"methods": pre_training_methods},
-                "post_training_bias": {"methods": post_training_methods},
-            },
-        }
-        if model_predicted_label_config:
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
+        analysis_config = {**analysis_config}
+        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
             ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
-
-    @staticmethod
-    def _common(analysis_config):
-        """Extends analysis config with common values"""
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
+        else:
+            _set(model_predicted_label_config, "label", analysis_config["predictor"])
         return analysis_config
 
+    @classmethod
+    def _add_methods(
+        cls,
+        analysis_config,
+        pre_training_methods=None,
+        post_training_methods=None,
+        explainability_config=None,
+        report=True,
+    ):
+        # validate
+        params = [pre_training_methods, post_training_methods, explainability_config]
+        if all([1 if p is None else 0 for p in params]):
+            raise AttributeError(
+                "analysis_config must have at least one working method: "
+                "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
+            )
+
+        # main logic
+        analysis_config = {**analysis_config}
+        if "methods" not in analysis_config:
+            analysis_config["methods"] = {}
+
+        if report:
+            analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
+
+        if pre_training_methods:
+            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
+
+        if post_training_methods:
+            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
+
+        if explainability_config is not None:
+            explainability_methods = cls._merge_explainability_configs(explainability_config)
+            analysis_config["methods"] = {**analysis_config["methods"], **explainability_methods}
+        return analysis_config
+
+    @classmethod
+    def _merge_explainability_configs(
+        cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
+    ):
+        if isinstance(explainability_config, list):
+            explainability_methods = {}
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if not len(explainability_methods) == len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and "features" not in explainability_methods["pdp"]
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            return explainability_methods
+        if (
+            isinstance(explainability_config, PDPConfig)
+            and "features" not in explainability_config.get_explainability_config()["pdp"]
+        ):
+            raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        return explainability_config.get_explainability_config()
+
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
     """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

tests/unit/test_clarify.py

Lines changed: 77 additions & 1 deletion
@@ -1094,7 +1094,9 @@ def test_explainability_with_invalid_config(
         "initial_instance_count": 1,
     }
     with pytest.raises(
-        AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'"
+        AttributeError,
+        match="analysis_config must have at least one working method: "
+        "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`.",
     ):
         _run_test_explain(
             name_from_base,
@@ -1320,6 +1322,80 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
     assert actual == expected
 
 
+def test_analysis_config_generator_for_explainability_failing(data_config, model_config):
+    model_scores = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    with pytest.raises(
+        ValueError, match="PDP features must be provided when ShapConfig is not provided"
+    ):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            PDPConfig(),
+        )
+
+    with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [SHAPConfig(), SHAPConfig()],
+        )
+
+    with pytest.raises(ValueError, match="Please provide at least one explainability config."):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [],
+        )
+
+
+def test_analysis_config_generator_for_bias_explainability(
+    data_config, data_bias_config, model_config
+):
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias_and_explainability(
+        data_config,
+        model_config,
+        model_predicted_label_config,
+        [SHAPConfig(), PDPConfig()],
+        data_bias_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "pdp": {"grid_resolution": 15, "top_k_features": 10},
+            "post_training_bias": {"methods": "all"},
+            "pre_training_bias": {"methods": "all"},
+            "report": {"name": "report", "title": "Analysis Report"},
+            "shap": {"save_local_shap_values": True, "use_logit": False},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
+    assert actual == expected
+
+
 def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
     actual = _AnalysisConfigGenerator.bias_pre_training(
         data_config, data_bias_config, methods="all"
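
One detail worth calling out in test_analysis_config_generator_for_explainability_failing: the duplicate check in _merge_explainability_configs works because each config type contributes one uniquely keyed entry ("shap", "pdp", ...) when merged, so ending up with fewer keys than configs means two configs collided on the same key. A standalone illustration with plain dicts (the payloads are made up):

# Two SHAP configs both merge under the "shap" key, so the key count (1)
# no longer matches the config count (2) -- the mismatch that triggers
# "Duplicate explainability configs are provided".
configs = [
    {"shap": {"num_samples": 100}},  # illustrative payloads only
    {"shap": {"num_samples": 200}},
]
merged = {}
for config in configs:
    merged.update(config)
print(len(merged), len(configs))  # 1 2 -> mismatch signals a duplicate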
