Skip to content

Commit 8fa7650

Browse files
dosatosaws-byeldos
andauthored
feature: added _AnalysisConfigGenerator for clarify (#3271)
* feature: extracted analysis config generation for explainability * feature: extracted analysis config generation for bias pre_training * feature: extracted analysis config generation for bias post_training * feature: extracted analysis config generation for bias * feature: simplified job_name creation * feature: extended analysis config generator methods with common logic * feature: refactored _AnalysisConfigGenerator methods * feature: added _last_analysis_config in SageMakerClarifyProcessor * added data types in _AnalysisConfigGenerator methods * applied style formatting to fix build issues Co-authored-by: Yeldos Balgabekov <[email protected]>
1 parent 8d7dd32 commit 8fa7650

File tree

2 files changed

+277
-66
lines changed

2 files changed

+277
-66
lines changed

src/sagemaker/clarify.py

Lines changed: 139 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28+
from typing import List, Union
29+
2830
from sagemaker import image_uris, s3, utils
2931
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
3032

@@ -922,6 +924,7 @@ def __init__(
922924
version (str): Clarify version to use.
923925
""" # noqa E501 # pylint: disable=c0301
924926
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
927+
self._last_analysis_config = None
925928
self.job_name_prefix = job_name_prefix
926929
super(SageMakerClarifyProcessor, self).__init__(
927930
role,
@@ -983,10 +986,10 @@ def _run(
983986
the Trial Component will be unassociated.
984987
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
985988
"""
986-
analysis_config["methods"]["report"] = {
987-
"name": "report",
988-
"title": "Analysis Report",
989-
}
989+
# for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
990+
self._last_analysis_config = analysis_config
991+
logger.info("Analysis Config: %s", analysis_config)
992+
990993
with tempfile.TemporaryDirectory() as tmpdirname:
991994
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
992995
with open(analysis_config_file, "w") as f:
@@ -1083,14 +1086,13 @@ def run_pre_training_bias(
10831086
the Trial Component will be unassociated.
10841087
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
10851088
""" # noqa E501 # pylint: disable=c0301
1086-
analysis_config = data_config.get_config()
1087-
analysis_config.update(data_bias_config.get_config())
1088-
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1089-
if job_name is None:
1090-
if self.job_name_prefix:
1091-
job_name = utils.name_from_base(self.job_name_prefix)
1092-
else:
1093-
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
1089+
analysis_config = _AnalysisConfigGenerator.bias_pre_training(
1090+
data_config, data_bias_config, methods
1091+
)
1092+
# when name is either not provided (is None) or an empty string ("")
1093+
job_name = job_name or utils.name_from_base(
1094+
self.job_name_prefix or "Clarify-Pretraining-Bias"
1095+
)
10941096
return self._run(
10951097
data_config,
10961098
analysis_config,
@@ -1165,21 +1167,13 @@ def run_post_training_bias(
11651167
the Trial Component will be unassociated.
11661168
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11671169
""" # noqa E501 # pylint: disable=c0301
1168-
analysis_config = data_config.get_config()
1169-
analysis_config.update(data_bias_config.get_config())
1170-
(
1171-
probability_threshold,
1172-
predictor_config,
1173-
) = model_predicted_label_config.get_predictor_config()
1174-
predictor_config.update(model_config.get_predictor_config())
1175-
analysis_config["methods"] = {"post_training_bias": {"methods": methods}}
1176-
analysis_config["predictor"] = predictor_config
1177-
_set(probability_threshold, "probability_threshold", analysis_config)
1178-
if job_name is None:
1179-
if self.job_name_prefix:
1180-
job_name = utils.name_from_base(self.job_name_prefix)
1181-
else:
1182-
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
1170+
analysis_config = _AnalysisConfigGenerator.bias_post_training(
1171+
data_config, data_bias_config, model_predicted_label_config, methods, model_config
1172+
)
1173+
# when name is either not provided (is None) or an empty string ("")
1174+
job_name = job_name or utils.name_from_base(
1175+
self.job_name_prefix or "Clarify-Posttraining-Bias"
1176+
)
11831177
return self._run(
11841178
data_config,
11851179
analysis_config,
@@ -1264,28 +1258,16 @@ def run_bias(
12641258
the Trial Component will be unassociated.
12651259
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
12661260
""" # noqa E501 # pylint: disable=c0301
1267-
analysis_config = data_config.get_config()
1268-
analysis_config.update(bias_config.get_config())
1269-
analysis_config["predictor"] = model_config.get_predictor_config()
1270-
if model_predicted_label_config:
1271-
(
1272-
probability_threshold,
1273-
predictor_config,
1274-
) = model_predicted_label_config.get_predictor_config()
1275-
if predictor_config:
1276-
analysis_config["predictor"].update(predictor_config)
1277-
if probability_threshold is not None:
1278-
analysis_config["probability_threshold"] = probability_threshold
1279-
1280-
analysis_config["methods"] = {
1281-
"pre_training_bias": {"methods": pre_training_methods},
1282-
"post_training_bias": {"methods": post_training_methods},
1283-
}
1284-
if job_name is None:
1285-
if self.job_name_prefix:
1286-
job_name = utils.name_from_base(self.job_name_prefix)
1287-
else:
1288-
job_name = utils.name_from_base("Clarify-Bias")
1261+
analysis_config = _AnalysisConfigGenerator.bias(
1262+
data_config,
1263+
bias_config,
1264+
model_config,
1265+
model_predicted_label_config,
1266+
pre_training_methods,
1267+
post_training_methods,
1268+
)
1269+
# when name is either not provided (is None) or an empty string ("")
1270+
job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias")
12891271
return self._run(
12901272
data_config,
12911273
analysis_config,
@@ -1370,6 +1352,36 @@ def run_explainability(
13701352
the Trial Component will be unassociated.
13711353
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
13721354
""" # noqa E501 # pylint: disable=c0301
1355+
analysis_config = _AnalysisConfigGenerator.explainability(
1356+
data_config, model_config, model_scores, explainability_config
1357+
)
1358+
# when name is either not provided (is None) or an empty string ("")
1359+
job_name = job_name or utils.name_from_base(
1360+
self.job_name_prefix or "Clarify-Explainability"
1361+
)
1362+
return self._run(
1363+
data_config,
1364+
analysis_config,
1365+
wait,
1366+
logs,
1367+
job_name,
1368+
kms_key,
1369+
experiment_config,
1370+
)
1371+
1372+
1373+
class _AnalysisConfigGenerator:
1374+
"""Creates analysis_config objects for different type of runs."""
1375+
1376+
@classmethod
1377+
def explainability(
1378+
cls,
1379+
data_config: DataConfig,
1380+
model_config: ModelConfig,
1381+
model_scores: ModelPredictedLabelConfig,
1382+
explainability_config: ExplainabilityConfig,
1383+
):
1384+
"""Generates a config for Explainability"""
13731385
analysis_config = data_config.get_config()
13741386
predictor_config = model_config.get_predictor_config()
13751387
if isinstance(model_scores, ModelPredictedLabelConfig):
@@ -1406,20 +1418,84 @@ def run_explainability(
14061418
explainability_methods = explainability_config.get_explainability_config()
14071419
analysis_config["methods"] = explainability_methods
14081420
analysis_config["predictor"] = predictor_config
1409-
if job_name is None:
1410-
if self.job_name_prefix:
1411-
job_name = utils.name_from_base(self.job_name_prefix)
1412-
else:
1413-
job_name = utils.name_from_base("Clarify-Explainability")
1414-
return self._run(
1415-
data_config,
1416-
analysis_config,
1417-
wait,
1418-
logs,
1419-
job_name,
1420-
kms_key,
1421-
experiment_config,
1422-
)
1421+
return cls._common(analysis_config)
1422+
1423+
@classmethod
1424+
def bias_pre_training(
1425+
cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
1426+
):
1427+
"""Generates a config for Bias Pre Training"""
1428+
analysis_config = {
1429+
**data_config.get_config(),
1430+
**bias_config.get_config(),
1431+
"methods": {"pre_training_bias": {"methods": methods}},
1432+
}
1433+
return cls._common(analysis_config)
1434+
1435+
@classmethod
1436+
def bias_post_training(
1437+
cls,
1438+
data_config: DataConfig,
1439+
bias_config: BiasConfig,
1440+
model_predicted_label_config: ModelPredictedLabelConfig,
1441+
methods: Union[str, List[str]],
1442+
model_config: ModelConfig,
1443+
):
1444+
"""Generates a config for Bias Post Training"""
1445+
analysis_config = {
1446+
**data_config.get_config(),
1447+
**bias_config.get_config(),
1448+
"predictor": {**model_config.get_predictor_config()},
1449+
"methods": {"post_training_bias": {"methods": methods}},
1450+
}
1451+
if model_predicted_label_config:
1452+
(
1453+
probability_threshold,
1454+
predictor_config,
1455+
) = model_predicted_label_config.get_predictor_config()
1456+
if predictor_config:
1457+
analysis_config["predictor"].update(predictor_config)
1458+
_set(probability_threshold, "probability_threshold", analysis_config)
1459+
return cls._common(analysis_config)
1460+
1461+
@classmethod
1462+
def bias(
1463+
cls,
1464+
data_config: DataConfig,
1465+
bias_config: BiasConfig,
1466+
model_config: ModelConfig,
1467+
model_predicted_label_config: ModelPredictedLabelConfig,
1468+
pre_training_methods: Union[str, List[str]] = "all",
1469+
post_training_methods: Union[str, List[str]] = "all",
1470+
):
1471+
"""Generates a config for Bias"""
1472+
analysis_config = {
1473+
**data_config.get_config(),
1474+
**bias_config.get_config(),
1475+
"predictor": model_config.get_predictor_config(),
1476+
"methods": {
1477+
"pre_training_bias": {"methods": pre_training_methods},
1478+
"post_training_bias": {"methods": post_training_methods},
1479+
},
1480+
}
1481+
if model_predicted_label_config:
1482+
(
1483+
probability_threshold,
1484+
predictor_config,
1485+
) = model_predicted_label_config.get_predictor_config()
1486+
if predictor_config:
1487+
analysis_config["predictor"].update(predictor_config)
1488+
_set(probability_threshold, "probability_threshold", analysis_config)
1489+
return cls._common(analysis_config)
1490+
1491+
@staticmethod
1492+
def _common(analysis_config):
1493+
"""Extends analysis config with common values"""
1494+
analysis_config["methods"]["report"] = {
1495+
"name": "report",
1496+
"title": "Analysis Report",
1497+
}
1498+
return analysis_config
14231499

14241500

14251501
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):

0 commit comments

Comments
 (0)