Skip to content

Commit 71989f9

Browse files
aws-byeldosnavinsoni
authored andcommitted
feature: refactored _AnalysisConfigGenerator methods
1 parent 88b1f4d commit 71989f9

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

src/sagemaker/clarify.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,9 @@ def run_explainability(
13671367

13681368

13691369
class _AnalysisConfigGenerator:
1370+
"""
1371+
Creates analysis_config objects for different type of runs.
1372+
"""
13701373
@classmethod
13711374
def explainability(
13721375
cls,
@@ -1397,15 +1400,15 @@ def explainability(
13971400
if not len(explainability_methods.keys()) == len(explainability_config):
13981401
raise ValueError("Duplicate explainability configs are provided")
13991402
if (
1400-
"shap" not in explainability_methods
1401-
and explainability_methods["pdp"].get("features", None) is None
1403+
"shap" not in explainability_methods
1404+
and explainability_methods["pdp"].get("features", None) is None
14021405
):
14031406
raise ValueError("PDP features must be provided when ShapConfig is not provided")
14041407
else:
14051408
if (
1406-
isinstance(explainability_config, PDPConfig)
1407-
and explainability_config.get_explainability_config()["pdp"].get("features", None)
1408-
is None
1409+
isinstance(explainability_config, PDPConfig)
1410+
and explainability_config.get_explainability_config()["pdp"].get("features", None)
1411+
is None
14091412
):
14101413
raise ValueError("PDP features must be provided when ShapConfig is not provided")
14111414
explainability_methods = explainability_config.get_explainability_config()
@@ -1415,9 +1418,11 @@ def explainability(
14151418

14161419
@classmethod
14171420
def bias_pre_training(cls, data_config, bias_config, methods):
1418-
analysis_config = data_config.get_config()
1419-
analysis_config.update(bias_config.get_config())
1420-
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1421+
analysis_config = {
1422+
**data_config.get_config(),
1423+
**bias_config.get_config(),
1424+
"methods": {"pre_training_bias": {"methods": methods}}
1425+
}
14211426
return cls._common(analysis_config)
14221427

14231428
@classmethod
@@ -1429,16 +1434,17 @@ def bias_post_training(
14291434
methods,
14301435
model_config
14311436
):
1432-
analysis_config = data_config.get_config()
1433-
analysis_config.update(bias_config.get_config())
1434-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1435-
(
1436-
probability_threshold,
1437-
predictor_config,
1438-
) = model_predicted_label_config.get_predictor_config()
1439-
predictor_config.update(model_config.get_predictor_config())
1440-
analysis_config["predictor"] = predictor_config
1441-
_set(probability_threshold, "probability_threshold", analysis_config)
1437+
analysis_config = {
1438+
**data_config.get_config(),
1439+
**bias_config.get_config(),
1440+
"predictor": {**model_config.get_predictor_config()},
1441+
"methods": {"post_training_bias": {"methods": methods}},
1442+
}
1443+
if model_predicted_label_config:
1444+
probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
1445+
if predictor_config:
1446+
analysis_config["predictor"].update(predictor_config)
1447+
_set(probability_threshold, "probability_threshold", analysis_config)
14421448
return cls._common(analysis_config)
14431449

14441450
@classmethod
@@ -1451,23 +1457,20 @@ def bias(
14511457
pre_training_methods="all",
14521458
post_training_methods="all",
14531459
):
1454-
analysis_config = data_config.get_config()
1455-
analysis_config.update(bias_config.get_config())
1456-
analysis_config["predictor"] = model_config.get_predictor_config()
1460+
analysis_config = {
1461+
**data_config.get_config(),
1462+
**bias_config.get_config(),
1463+
"predictor": model_config.get_predictor_config(),
1464+
"methods": {
1465+
"pre_training_bias": {"methods": pre_training_methods},
1466+
"post_training_bias": {"methods": post_training_methods},
1467+
}
1468+
}
14571469
if model_predicted_label_config:
1458-
(
1459-
probability_threshold,
1460-
predictor_config,
1461-
) = model_predicted_label_config.get_predictor_config()
1470+
probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
14621471
if predictor_config:
14631472
analysis_config["predictor"].update(predictor_config)
1464-
if probability_threshold is not None:
1465-
analysis_config["probability_threshold"] = probability_threshold
1466-
1467-
analysis_config["methods"] = {
1468-
"pre_training_bias": {"methods": pre_training_methods},
1469-
"post_training_bias": {"methods": post_training_methods},
1470-
}
1473+
_set(probability_threshold, "probability_threshold", analysis_config)
14711474
return cls._common(analysis_config)
14721475

14731476
@staticmethod

0 commit comments

Comments
 (0)