Skip to content

Commit a0b1f83

Browse files
committed
feature: extended analysis config generator methods with common logic
1 parent 1316eba commit a0b1f83

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

src/sagemaker/clarify.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -920,10 +920,6 @@ def _run(
920920
the Trial Component will be unassociated.
921921
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
922922
"""
923-
analysis_config["methods"]["report"] = {
924-
"name": "report",
925-
"title": "Analysis Report",
926-
}
927923
with tempfile.TemporaryDirectory() as tmpdirname:
928924
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
929925
with open(analysis_config_file, "w") as f:
@@ -1308,8 +1304,9 @@ def run_explainability(
13081304

13091305

13101306
class _AnalysisConfigGenerator:
1311-
@staticmethod
1307+
@classmethod
13121308
def explainability(
1309+
cls,
13131310
data_config,
13141311
model_config,
13151312
model_scores,
@@ -1351,22 +1348,23 @@ def explainability(
13511348
explainability_methods = explainability_config.get_explainability_config()
13521349
analysis_config["methods"] = explainability_methods
13531350
analysis_config["predictor"] = predictor_config
1354-
return analysis_config
1351+
return cls._common(analysis_config)
13551352

1356-
@staticmethod
1357-
def bias_pre_training(data_config, bias_config, methods):
1353+
@classmethod
1354+
def bias_pre_training(cls, data_config, bias_config, methods):
13581355
analysis_config = data_config.get_config()
13591356
analysis_config.update(bias_config.get_config())
13601357
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1361-
return analysis_config
1358+
return cls._common(analysis_config)
13621359

1363-
@staticmethod
1360+
@classmethod
13641361
def bias_post_training(
1365-
data_config,
1366-
bias_config,
1367-
model_predicted_label_config,
1368-
methods,
1369-
model_config
1362+
cls,
1363+
data_config,
1364+
bias_config,
1365+
model_predicted_label_config,
1366+
methods,
1367+
model_config
13701368
):
13711369
analysis_config = data_config.get_config()
13721370
analysis_config.update(bias_config.get_config())
@@ -1378,10 +1376,11 @@ def bias_post_training(
13781376
predictor_config.update(model_config.get_predictor_config())
13791377
analysis_config["predictor"] = predictor_config
13801378
_set(probability_threshold, "probability_threshold", analysis_config)
1381-
return analysis_config
1379+
return cls._common(analysis_config)
13821380

1383-
@staticmethod
1381+
@classmethod
13841382
def bias(
1383+
cls,
13851384
data_config,
13861385
bias_config,
13871386
model_config,
@@ -1406,7 +1405,7 @@ def bias(
14061405
"pre_training_bias": {"methods": pre_training_methods},
14071406
"post_training_bias": {"methods": post_training_methods},
14081407
}
1409-
return analysis_config
1408+
return cls._common(analysis_config)
14101409

14111410
@staticmethod
14121411
def _common(analysis_config):

tests/unit/test_clarify.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,10 @@ def test_pre_training_bias(
765765
"label_values_or_threshold": [1],
766766
"facet": [{"name_or_index": "F1"}],
767767
"group_variable": "F2",
768-
"methods": {"pre_training_bias": {"methods": "all"}},
768+
"methods": {
769+
'report': {'name': 'report', 'title': 'Analysis Report'},
770+
"pre_training_bias": {"methods": "all"}
771+
},
769772
}
770773
mock_method.assert_called_with(
771774
data_config,
@@ -828,7 +831,10 @@ def test_post_training_bias(
828831
"joinsource_name_or_index": "F4",
829832
"facet": [{"name_or_index": "F1"}],
830833
"group_variable": "F2",
831-
"methods": {"post_training_bias": {"methods": "all"}},
834+
"methods": {
835+
'report': {'name': 'report', 'title': 'Analysis Report'},
836+
"post_training_bias": {"methods": "all"}
837+
},
832838
"predictor": {
833839
"model_name": "xgboost-model",
834840
"instance_type": "ml.c5.xlarge",
@@ -986,7 +992,10 @@ def _run_test_explain(
986992
"grid_resolution": 20,
987993
"top_k_features": 10,
988994
}
989-
expected_analysis_config["methods"] = expected_explanation_configs
995+
expected_analysis_config["methods"] = {
996+
'report': {'name': 'report', 'title': 'Analysis Report'},
997+
**expected_explanation_configs,
998+
}
990999
mock_method.assert_called_with(
9911000
data_config,
9921001
expected_analysis_config,
@@ -1295,7 +1304,10 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
12951304
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
12961305
'joinsource_name_or_index': 'F4',
12971306
'label': 'Label',
1298-
'methods': {'shap': {'save_local_shap_values': True, 'use_logit': False}},
1307+
'methods': {
1308+
'report': {'name': 'report', 'title': 'Analysis Report'},
1309+
'shap': {'save_local_shap_values': True, 'use_logit': False}
1310+
},
12991311
'predictor': {'initial_instance_count': 1,
13001312
'instance_type': 'ml.c5.xlarge',
13011313
'label_headers': ['success'],
@@ -1317,7 +1329,10 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
13171329
'joinsource_name_or_index': 'F4',
13181330
'label': 'Label',
13191331
'label_values_or_threshold': [1],
1320-
'methods': {'pre_training_bias': {'methods': 'all'}}}
1332+
'methods': {
1333+
'report': {'name': 'report', 'title': 'Analysis Report'},
1334+
'pre_training_bias': {'methods': 'all'}}
1335+
}
13211336
assert actual == expected
13221337

13231338

@@ -1340,7 +1355,10 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
13401355
'joinsource_name_or_index': 'F4',
13411356
'label': 'Label',
13421357
'label_values_or_threshold': [1],
1343-
'methods': {'post_training_bias': {'methods': 'all'}},
1358+
'methods': {
1359+
'report': {'name': 'report', 'title': 'Analysis Report'},
1360+
'post_training_bias': {'methods': 'all'}
1361+
},
13441362
'predictor': {'initial_instance_count': 1,
13451363
'instance_type': 'ml.c5.xlarge',
13461364
'label_headers': ['success'],
@@ -1369,8 +1387,10 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
13691387
'joinsource_name_or_index': 'F4',
13701388
'label': 'Label',
13711389
'label_values_or_threshold': [1],
1372-
'methods': {'post_training_bias': {'methods': 'all'},
1373-
'pre_training_bias': {'methods': 'all'}},
1390+
'methods': {
1391+
'report': {'name': 'report', 'title': 'Analysis Report'},
1392+
'post_training_bias': {'methods': 'all'},
1393+
'pre_training_bias': {'methods': 'all'}},
13741394
'predictor': {'initial_instance_count': 1,
13751395
'instance_type': 'ml.c5.xlarge',
13761396
'label_headers': ['success'],

0 commit comments

Comments
 (0)