Skip to content

Commit 7f00235

Browse files
committed
feature: refactored _AnalysisConfigGenerator methods
1 parent a0b1f83 commit 7f00235

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

src/sagemaker/clarify.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,9 @@ def run_explainability(
13041304

13051305

13061306
class _AnalysisConfigGenerator:
1307+
"""
1308+
Creates analysis_config objects for different type of runs.
1309+
"""
13071310
@classmethod
13081311
def explainability(
13091312
cls,
@@ -1334,15 +1337,15 @@ def explainability(
13341337
if not len(explainability_methods.keys()) == len(explainability_config):
13351338
raise ValueError("Duplicate explainability configs are provided")
13361339
if (
1337-
"shap" not in explainability_methods
1338-
and explainability_methods["pdp"].get("features", None) is None
1340+
"shap" not in explainability_methods
1341+
and explainability_methods["pdp"].get("features", None) is None
13391342
):
13401343
raise ValueError("PDP features must be provided when ShapConfig is not provided")
13411344
else:
13421345
if (
1343-
isinstance(explainability_config, PDPConfig)
1344-
and explainability_config.get_explainability_config()["pdp"].get("features", None)
1345-
is None
1346+
isinstance(explainability_config, PDPConfig)
1347+
and explainability_config.get_explainability_config()["pdp"].get("features", None)
1348+
is None
13461349
):
13471350
raise ValueError("PDP features must be provided when ShapConfig is not provided")
13481351
explainability_methods = explainability_config.get_explainability_config()
@@ -1352,9 +1355,11 @@ def explainability(
13521355

13531356
@classmethod
13541357
def bias_pre_training(cls, data_config, bias_config, methods):
1355-
analysis_config = data_config.get_config()
1356-
analysis_config.update(bias_config.get_config())
1357-
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1358+
analysis_config = {
1359+
**data_config.get_config(),
1360+
**bias_config.get_config(),
1361+
"methods": {"pre_training_bias": {"methods": methods}}
1362+
}
13581363
return cls._common(analysis_config)
13591364

13601365
@classmethod
@@ -1366,16 +1371,17 @@ def bias_post_training(
13661371
methods,
13671372
model_config
13681373
):
1369-
analysis_config = data_config.get_config()
1370-
analysis_config.update(bias_config.get_config())
1371-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1372-
(
1373-
probability_threshold,
1374-
predictor_config,
1375-
) = model_predicted_label_config.get_predictor_config()
1376-
predictor_config.update(model_config.get_predictor_config())
1377-
analysis_config["predictor"] = predictor_config
1378-
_set(probability_threshold, "probability_threshold", analysis_config)
1374+
analysis_config = {
1375+
**data_config.get_config(),
1376+
**bias_config.get_config(),
1377+
"predictor": {**model_config.get_predictor_config()},
1378+
"methods": {"post_training_bias": {"methods": methods}},
1379+
}
1380+
if model_predicted_label_config:
1381+
probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
1382+
if predictor_config:
1383+
analysis_config["predictor"].update(predictor_config)
1384+
_set(probability_threshold, "probability_threshold", analysis_config)
13791385
return cls._common(analysis_config)
13801386

13811387
@classmethod
@@ -1388,23 +1394,20 @@ def bias(
13881394
pre_training_methods="all",
13891395
post_training_methods="all",
13901396
):
1391-
analysis_config = data_config.get_config()
1392-
analysis_config.update(bias_config.get_config())
1393-
analysis_config["predictor"] = model_config.get_predictor_config()
1397+
analysis_config = {
1398+
**data_config.get_config(),
1399+
**bias_config.get_config(),
1400+
"predictor": model_config.get_predictor_config(),
1401+
"methods": {
1402+
"pre_training_bias": {"methods": pre_training_methods},
1403+
"post_training_bias": {"methods": post_training_methods},
1404+
}
1405+
}
13941406
if model_predicted_label_config:
1395-
(
1396-
probability_threshold,
1397-
predictor_config,
1398-
) = model_predicted_label_config.get_predictor_config()
1407+
probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
13991408
if predictor_config:
14001409
analysis_config["predictor"].update(predictor_config)
1401-
if probability_threshold is not None:
1402-
analysis_config["probability_threshold"] = probability_threshold
1403-
1404-
analysis_config["methods"] = {
1405-
"pre_training_bias": {"methods": pre_training_methods},
1406-
"post_training_bias": {"methods": post_training_methods},
1407-
}
1410+
_set(probability_threshold, "probability_threshold", analysis_config)
14081411
return cls._common(analysis_config)
14091412

14101413
@staticmethod

0 commit comments

Comments
 (0)