
Commit 76b6ae7

feature: extracted analysis config generation for explainability
1 parent 79b3e9d commit 76b6ae7

File tree

src/sagemaker/clarify.py
tests/unit/test_clarify.py

2 files changed: +62 -19 lines changed


src/sagemaker/clarify.py

Lines changed: 36 additions & 19 deletions
@@ -1307,6 +1307,36 @@ def run_explainability(
                 the Trial Component will be unassociated.
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """  # noqa E501  # pylint: disable=c0301
+        analysis_config = _AnalysisConfigGenerator.explainability(
+            data_config,
+            model_config,
+            model_scores,
+            explainability_config
+        )
+        if job_name is None:
+            if self.job_name_prefix:
+                job_name = utils.name_from_base(self.job_name_prefix)
+            else:
+                job_name = utils.name_from_base("Clarify-Explainability")
+        return self._run(
+            data_config,
+            analysis_config,
+            wait,
+            logs,
+            job_name,
+            kms_key,
+            experiment_config,
+        )
+
+
+class _AnalysisConfigGenerator:
+    @staticmethod
+    def explainability(
+        data_config,
+        model_config,
+        model_scores,
+        explainability_config
+    ):
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
         if isinstance(model_scores, ModelPredictedLabelConfig):
@@ -1329,34 +1359,21 @@ def run_explainability(
             if not len(explainability_methods.keys()) == len(explainability_config):
                 raise ValueError("Duplicate explainability configs are provided")
             if (
-                    "shap" not in explainability_methods
-                    and explainability_methods["pdp"].get("features", None) is None
+                "shap" not in explainability_methods
+                and explainability_methods["pdp"].get("features", None) is None
             ):
                 raise ValueError("PDP features must be provided when ShapConfig is not provided")
         else:
             if (
-                    isinstance(explainability_config, PDPConfig)
-                    and explainability_config.get_explainability_config()["pdp"].get("features", None)
-                    is None
+                isinstance(explainability_config, PDPConfig)
+                and explainability_config.get_explainability_config()["pdp"].get("features", None)
+                is None
             ):
                 raise ValueError("PDP features must be provided when ShapConfig is not provided")
             explainability_methods = explainability_config.get_explainability_config()
         analysis_config["methods"] = explainability_methods
         analysis_config["predictor"] = predictor_config
-        if job_name is None:
-            if self.job_name_prefix:
-                job_name = utils.name_from_base(self.job_name_prefix)
-            else:
-                job_name = utils.name_from_base("Clarify-Explainability")
-        return self._run(
-            data_config,
-            analysis_config,
-            wait,
-            logs,
-            job_name,
-            kms_key,
-            experiment_config,
-        )
+        return analysis_config


 def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
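
The refactor above moves the config-building logic out of run_explainability into the static _AnalysisConfigGenerator.explainability helper, so the analysis config can be produced and inspected on its own, without creating a processor or starting a processing job. A minimal sketch of that usage (not part of this commit; the S3 paths, headers, and model name are illustrative placeholders):

# Minimal sketch, not part of this commit: build the Clarify explainability
# analysis_config as a plain dict using the extracted helper.
from sagemaker.clarify import (
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    SHAPConfig,
    _AnalysisConfigGenerator,
)

data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/clarify/input",  # placeholder path
    s3_output_path="s3://my-bucket/clarify/output",     # placeholder path
    label="Label",
    headers=["Label", "F1", "F2", "F3"],
    dataset_type="text/csv",
)
model_config = ModelConfig(
    model_name="my-model",  # placeholder model name
    instance_count=1,
    instance_type="ml.c5.xlarge",
)
analysis_config = _AnalysisConfigGenerator.explainability(
    data_config,
    model_config,
    ModelPredictedLabelConfig(probability="pr"),
    SHAPConfig(),
)
# The result is an ordinary dict, e.g. analysis_config["methods"]["shap"].
print(analysis_config["methods"])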

tests/unit/test_clarify.py

Lines changed: 26 additions & 0 deletions
@@ -29,6 +29,7 @@
     SHAPConfig,
     TextConfig,
     ImageConfig,
+    _AnalysisConfigGenerator,
 )

 JOB_NAME_PREFIX = "my-prefix"
@@ -1277,3 +1278,28 @@ def test_shap_with_image_config(
         expected_predictor_config,
         expected_image_config=expected_image_config,
     )
+
+
+def test_analysis_config_generator_for_explainability(data_config, model_config):
+    model_scores = ModelPredictedLabelConfig(
+        probability="pr",
+        label_headers=["success"],
+    )
+    actual = _AnalysisConfigGenerator.explainability(
+        data_config,
+        model_config,
+        model_scores,
+        SHAPConfig(),
+    )
+    expected = {'dataset_type': 'text/csv',
+                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
+                'joinsource_name_or_index': 'F4',
+                'label': 'Label',
+                'methods': {'shap': {'save_local_shap_values': True, 'use_logit': False}},
+                'predictor': {'initial_instance_count': 1,
+                              'instance_type': 'ml.c5.xlarge',
+                              'label_headers': ['success'],
+                              'model_name': 'xgboost-model',
+                              'probability': 'pr'}}
+    assert actual == expected
+
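
The added test exercises the SHAP happy path of the extracted generator. A hypothetical companion test (not part of this commit) could cover the validation branch the generator keeps from run_explainability, where a PDPConfig without features and no SHAPConfig must raise; it assumes the same data_config and model_config fixtures used above and that PDPConfig is importable from sagemaker.clarify:

import pytest

def test_analysis_config_generator_rejects_pdp_without_features(data_config, model_config):
    # Hypothetical test, not in this commit: a PDPConfig with no features and
    # no SHAPConfig should be rejected by the extracted generator.
    with pytest.raises(ValueError, match="PDP features must be provided"):
        _AnalysisConfigGenerator.explainability(
            data_config,
            model_config,
            ModelPredictedLabelConfig(probability="pr"),
            PDPConfig(),
        )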
