Skip to content

Commit 88b1f4d

Browse files
aws-byeldosnavinsoni
authored andcommitted
feature: extended analysis config generator methods with common logic
1 parent 32650ee commit 88b1f4d

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

src/sagemaker/clarify.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -983,10 +983,6 @@ def _run(
983983
the Trial Component will be unassociated.
984984
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
985985
"""
986-
analysis_config["methods"]["report"] = {
987-
"name": "report",
988-
"title": "Analysis Report",
989-
}
990986
with tempfile.TemporaryDirectory() as tmpdirname:
991987
analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
992988
with open(analysis_config_file, "w") as f:
@@ -1371,8 +1367,9 @@ def run_explainability(
13711367

13721368

13731369
class _AnalysisConfigGenerator:
1374-
@staticmethod
1370+
@classmethod
13751371
def explainability(
1372+
cls,
13761373
data_config,
13771374
model_config,
13781375
model_scores,
@@ -1414,22 +1411,23 @@ def explainability(
14141411
explainability_methods = explainability_config.get_explainability_config()
14151412
analysis_config["methods"] = explainability_methods
14161413
analysis_config["predictor"] = predictor_config
1417-
return analysis_config
1414+
return cls._common(analysis_config)
14181415

1419-
@staticmethod
1420-
def bias_pre_training(data_config, bias_config, methods):
1416+
@classmethod
1417+
def bias_pre_training(cls, data_config, bias_config, methods):
14211418
analysis_config = data_config.get_config()
14221419
analysis_config.update(bias_config.get_config())
14231420
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
1424-
return analysis_config
1421+
return cls._common(analysis_config)
14251422

1426-
@staticmethod
1423+
@classmethod
14271424
def bias_post_training(
1428-
data_config,
1429-
bias_config,
1430-
model_predicted_label_config,
1431-
methods,
1432-
model_config
1425+
cls,
1426+
data_config,
1427+
bias_config,
1428+
model_predicted_label_config,
1429+
methods,
1430+
model_config
14331431
):
14341432
analysis_config = data_config.get_config()
14351433
analysis_config.update(bias_config.get_config())
@@ -1441,10 +1439,11 @@ def bias_post_training(
14411439
predictor_config.update(model_config.get_predictor_config())
14421440
analysis_config["predictor"] = predictor_config
14431441
_set(probability_threshold, "probability_threshold", analysis_config)
1444-
return analysis_config
1442+
return cls._common(analysis_config)
14451443

1446-
@staticmethod
1444+
@classmethod
14471445
def bias(
1446+
cls,
14481447
data_config,
14491448
bias_config,
14501449
model_config,
@@ -1469,7 +1468,7 @@ def bias(
14691468
"pre_training_bias": {"methods": pre_training_methods},
14701469
"post_training_bias": {"methods": post_training_methods},
14711470
}
1472-
return analysis_config
1471+
return cls._common(analysis_config)
14731472

14741473
@staticmethod
14751474
def _common(analysis_config):

tests/unit/test_clarify.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,10 @@ def test_pre_training_bias(
765765
"label_values_or_threshold": [1],
766766
"facet": [{"name_or_index": "F1"}],
767767
"group_variable": "F2",
768-
"methods": {"pre_training_bias": {"methods": "all"}},
768+
"methods": {
769+
'report': {'name': 'report', 'title': 'Analysis Report'},
770+
"pre_training_bias": {"methods": "all"}
771+
},
769772
}
770773
mock_method.assert_called_with(
771774
data_config,
@@ -828,7 +831,10 @@ def test_post_training_bias(
828831
"joinsource_name_or_index": "F4",
829832
"facet": [{"name_or_index": "F1"}],
830833
"group_variable": "F2",
831-
"methods": {"post_training_bias": {"methods": "all"}},
834+
"methods": {
835+
'report': {'name': 'report', 'title': 'Analysis Report'},
836+
"post_training_bias": {"methods": "all"}
837+
},
832838
"predictor": {
833839
"model_name": "xgboost-model",
834840
"instance_type": "ml.c5.xlarge",
@@ -986,7 +992,10 @@ def _run_test_explain(
986992
"grid_resolution": 20,
987993
"top_k_features": 10,
988994
}
989-
expected_analysis_config["methods"] = expected_explanation_configs
995+
expected_analysis_config["methods"] = {
996+
'report': {'name': 'report', 'title': 'Analysis Report'},
997+
**expected_explanation_configs,
998+
}
990999
mock_method.assert_called_with(
9911000
data_config,
9921001
expected_analysis_config,
@@ -1295,7 +1304,10 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
12951304
'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
12961305
'joinsource_name_or_index': 'F4',
12971306
'label': 'Label',
1298-
'methods': {'shap': {'save_local_shap_values': True, 'use_logit': False}},
1307+
'methods': {
1308+
'report': {'name': 'report', 'title': 'Analysis Report'},
1309+
'shap': {'save_local_shap_values': True, 'use_logit': False}
1310+
},
12991311
'predictor': {'initial_instance_count': 1,
13001312
'instance_type': 'ml.c5.xlarge',
13011313
'label_headers': ['success'],
@@ -1317,7 +1329,10 @@ def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_
13171329
'joinsource_name_or_index': 'F4',
13181330
'label': 'Label',
13191331
'label_values_or_threshold': [1],
1320-
'methods': {'pre_training_bias': {'methods': 'all'}}}
1332+
'methods': {
1333+
'report': {'name': 'report', 'title': 'Analysis Report'},
1334+
'pre_training_bias': {'methods': 'all'}}
1335+
}
13211336
assert actual == expected
13221337

13231338

@@ -1340,7 +1355,10 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
13401355
'joinsource_name_or_index': 'F4',
13411356
'label': 'Label',
13421357
'label_values_or_threshold': [1],
1343-
'methods': {'post_training_bias': {'methods': 'all'}},
1358+
'methods': {
1359+
'report': {'name': 'report', 'title': 'Analysis Report'},
1360+
'post_training_bias': {'methods': 'all'}
1361+
},
13441362
'predictor': {'initial_instance_count': 1,
13451363
'instance_type': 'ml.c5.xlarge',
13461364
'label_headers': ['success'],
@@ -1369,8 +1387,10 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
13691387
'joinsource_name_or_index': 'F4',
13701388
'label': 'Label',
13711389
'label_values_or_threshold': [1],
1372-
'methods': {'post_training_bias': {'methods': 'all'},
1373-
'pre_training_bias': {'methods': 'all'}},
1390+
'methods': {
1391+
'report': {'name': 'report', 'title': 'Analysis Report'},
1392+
'post_training_bias': {'methods': 'all'},
1393+
'pre_training_bias': {'methods': 'all'}},
13741394
'predictor': {'initial_instance_count': 1,
13751395
'instance_type': 'ml.c5.xlarge',
13761396
'label_headers': ['success'],

0 commit comments

Comments
 (0)