
Commit 121ca9d

refactored _AnalysisConfigGenerator

1 parent: 24bd02e

File tree

2 files changed: +210 −80 lines

src/sagemaker/clarify.py

Lines changed: 133 additions & 79 deletions
@@ -1369,70 +1369,72 @@ def run_explainability(
             experiment_config,
         )
 
+    def run_bias_and_explainability(self):
+        """
+        TODO:
+        - add doc string
+        - add logic
+        - add tests
+        """
+        raise NotImplementedError(
+            "Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
+        )
+
 
 class _AnalysisConfigGenerator:
     """
     Creates analysis_config objects for different type of runs.
     """
 
+    @classmethod
+    def bias_and_explainability(
+        cls,
+        data_config: DataConfig,
+        model_config: ModelConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
+        bias_config: BiasConfig,
+        pre_training_methods: Union[str, List[str]] = "all",
+        post_training_methods: Union[str, List[str]] = "all",
+    ):
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+            explainability_config=explainability_config,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
     @classmethod
     def explainability(
         cls,
         data_config: DataConfig,
         model_config: ModelConfig,
-        model_scores: ModelPredictedLabelConfig,
-        explainability_config: ExplainabilityConfig,
+        model_predicted_label_config: ModelPredictedLabelConfig,
+        explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
     ):
         """Generates a config for Explainability"""
-        analysis_config = data_config.get_config()
-        predictor_config = model_config.get_predictor_config()
-        if isinstance(model_scores, ModelPredictedLabelConfig):
-            (
-                probability_threshold,
-                predicted_label_config,
-            ) = model_scores.get_predictor_config()
-            _set(probability_threshold, "probability_threshold", analysis_config)
-            predictor_config.update(predicted_label_config)
-        else:
-            _set(model_scores, "label", predictor_config)
-
-        explainability_methods = {}
-        if isinstance(explainability_config, list):
-            if len(explainability_config) == 0:
-                raise ValueError("Please provide at least one explainability config.")
-            for config in explainability_config:
-                explain_config = config.get_explainability_config()
-                explainability_methods.update(explain_config)
-            if not len(explainability_methods.keys()) == len(explainability_config):
-                raise ValueError("Duplicate explainability configs are provided")
-            if (
-                "shap" not in explainability_methods
-                and explainability_methods["pdp"].get("features", None) is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-        else:
-            if (
-                isinstance(explainability_config, PDPConfig)
-                and explainability_config.get_explainability_config()["pdp"].get("features", None)
-                is None
-            ):
-                raise ValueError("PDP features must be provided when ShapConfig is not provided")
-            explainability_methods = explainability_config.get_explainability_config()
-        analysis_config["methods"] = explainability_methods
-        analysis_config["predictor"] = predictor_config
-        return cls._common(analysis_config)
+        analysis_config = data_config.analysis_config
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        analysis_config = cls._add_methods(
+            analysis_config, explainability_config=explainability_config
+        )
+        return analysis_config
 
     @classmethod
     def bias_pre_training(
         cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
     ):
         """Generates a config for Bias Pre Training"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "methods": {"pre_training_bias": {"methods": methods}},
-        }
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
+        return analysis_config
 
     @classmethod
     def bias_post_training(
@@ -1444,21 +1446,12 @@ def bias_post_training(
         model_config: ModelConfig,
     ):
         """Generates a config for Bias Post Training"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": {**model_config.get_predictor_config()},
-            "methods": {"post_training_bias": {"methods": methods}},
-        }
-        if model_predicted_label_config:
-            (
-                probability_threshold,
-                predictor_config,
-            ) = model_predicted_label_config.get_predictor_config()
-            if predictor_config:
-                analysis_config["predictor"].update(predictor_config)
-            _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
 
     @classmethod
     def bias(
@@ -1471,34 +1464,95 @@ def bias(
         post_training_methods: Union[str, List[str]] = "all",
     ):
         """Generates a config for Bias"""
-        analysis_config = {
-            **data_config.get_config(),
-            **bias_config.get_config(),
-            "predictor": model_config.get_predictor_config(),
-            "methods": {
-                "pre_training_bias": {"methods": pre_training_methods},
-                "post_training_bias": {"methods": post_training_methods},
-            },
-        }
-        if model_predicted_label_config:
+        analysis_config = {**data_config.get_config(), **bias_config.get_config()}
+        analysis_config = cls._add_methods(
+            analysis_config,
+            pre_training_methods=pre_training_methods,
+            post_training_methods=post_training_methods,
+        )
+        analysis_config = cls._add_predictor(
+            analysis_config, model_config, model_predicted_label_config
+        )
+        return analysis_config
+
+    @classmethod
+    def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
+        analysis_config = {**analysis_config}
+        analysis_config["predictor"] = model_config.get_predictor_config()
+        if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
             (
                 probability_threshold,
                 predictor_config,
             ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
-        return cls._common(analysis_config)
-
-    @staticmethod
-    def _common(analysis_config):
-        """Extends analysis config with common values"""
-        analysis_config["methods"]["report"] = {
-            "name": "report",
-            "title": "Analysis Report",
-        }
+        else:
+            _set(model_predicted_label_config, "label", analysis_config["predictor"])
         return analysis_config
 
+    @classmethod
+    def _add_methods(
+        cls,
+        analysis_config,
+        pre_training_methods=None,
+        post_training_methods=None,
+        explainability_config=None,
+        report=True,
+    ):
+        # validate
+        params = [pre_training_methods, post_training_methods, explainability_config]
+        if all([1 if p is None else 0 for p in params]):
+            raise AttributeError(
+                "analysis_config must have at least one working method: "
+                "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
+            )
+
+        # main logic
+        analysis_config = {**analysis_config}
+        if "methods" not in analysis_config:
+            analysis_config["methods"] = {}
+
+        if report:
+            analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
+
+        if pre_training_methods:
+            analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
+
+        if post_training_methods:
+            analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
+
+        if explainability_config is not None:
+            explainability_methods = cls._merge_explainability_configs(explainability_config)
+            analysis_config["methods"] = {**analysis_config["methods"], **explainability_methods}
+        return analysis_config
+
+    @classmethod
+    def _merge_explainability_configs(
+        cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
+    ):
+        if isinstance(explainability_config, list):
+            explainability_methods = {}
+            if len(explainability_config) == 0:
+                raise ValueError("Please provide at least one explainability config.")
+            for config in explainability_config:
+                explain_config = config.get_explainability_config()
+                explainability_methods.update(explain_config)
+            if not len(explainability_methods) == len(explainability_config):
+                raise ValueError("Duplicate explainability configs are provided")
+            if (
+                "shap" not in explainability_methods
+                and "features" not in explainability_methods["pdp"]
+            ):
+                raise ValueError("PDP features must be provided when ShapConfig is not provided")
+            return explainability_methods
+        if (
+            isinstance(explainability_config, PDPConfig)
+            and "features" not in explainability_config.get_explainability_config()["pdp"]
+        ):
+            raise ValueError("PDP features must be provided when ShapConfig is not provided")
+        return explainability_config.get_explainability_config()
+
 
 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
     """Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

tests/unit/test_clarify.py

Lines changed: 77 additions & 1 deletion
@@ -1094,7 +1094,9 @@ def test_explainability_with_invalid_config(
         "initial_instance_count": 1,
     }
     with pytest.raises(
-        AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'"
+        AttributeError,
+        match="analysis_config must have at least one working method: "
+        "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`.",
     ):
         _run_test_explain(
             name_from_base,
@@ -1320,6 +1322,80 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
     assert actual == expected
 
 
+def test_analysis_config_generator_for_explainability_failing(data_config, model_config):
+    model_scores = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    with pytest.raises(
+        ValueError, match="PDP features must be provided when ShapConfig is not provided"
+    ):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            PDPConfig(),
+        )
+
+    with pytest.raises(ValueError, match="Duplicate explainability configs are provided"):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [SHAPConfig(), SHAPConfig()],
+        )
+
+    with pytest.raises(ValueError, match="Please provide at least one explainability config."):
+        _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            [],
+        )
+
+
+def test_analysis_config_generator_for_bias_explainability(
+    data_config, data_bias_config, model_config
+):
+    model_predicted_label_config = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.bias_and_explainability(
+        data_config,
+        model_config,
+        model_predicted_label_config,
+        [SHAPConfig(), PDPConfig()],
+        data_bias_config,
+        pre_training_methods="all",
+        post_training_methods="all",
+    )
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "pdp": {"grid_resolution": 15, "top_k_features": 10},
+            "post_training_bias": {"methods": "all"},
+            "pre_training_bias": {"methods": "all"},
+            "report": {"name": "report", "title": "Analysis Report"},
+            "shap": {"save_local_shap_values": True, "use_logit": False},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
+    assert actual == expected
+
+
 def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
     actual = _AnalysisConfigGenerator.bias_pre_training(
         data_config, data_bias_config, methods="all"
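
The updated match string in test_explainability_with_invalid_config comes from the new validation guard in _add_methods. A minimal sketch of that guard in isolation; base_config here is a hypothetical stand-in for a DataConfig-derived dict:

import pytest
from sagemaker.clarify import _AnalysisConfigGenerator

base_config = {"dataset_type": "text/csv"}  # hypothetical minimal config

# With no pre-training, post-training, or explainability methods supplied,
# _add_methods raises before any config is assembled.
with pytest.raises(AttributeError, match="must have at least one working method"):
    _AnalysisConfigGenerator._add_methods(base_config)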
