Commit 1ecdbd4

applied style formatting to fix build issues
1 parent ea9ea51 commit 1ecdbd4

2 files changed: +113 -86 lines changed

src/sagemaker/clarify.py

Lines changed: 33 additions & 21 deletions
@@ -924,6 +924,7 @@ def __init__(
             version (str): Clarify version to use.
         """ # noqa E501 # pylint: disable=c0301
         container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
+        self._last_analysis_config = None
         self.job_name_prefix = job_name_prefix
         super(SageMakerClarifyProcessor, self).__init__(
             role,
@@ -987,7 +988,7 @@ def _run(
         """
         # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket
         self._last_analysis_config = analysis_config
-        logger.info("Analysis Config: ", analysis_config)
+        logger.info("Analysis Config: %s", analysis_config)

         with tempfile.TemporaryDirectory() as tmpdirname:
             analysis_config_file = os.path.join(tmpdirname, "analysis_config.json")
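Note on the hunk above: the old call passed `analysis_config` as a `%`-style formatting argument to a message with no placeholder, so the standard `logging` module silently drops the dict from the output; with `%s` the config is interpolated lazily when the record is emitted. A minimal standalone sketch of the difference (plain `logging`, not SageMaker code):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

analysis_config = {"dataset_type": "text/csv", "label": "Label"}

# Old form: the dict becomes a %-formatting argument, but the message has no
# placeholder, so the output is just "Analysis Config: " and the dict is dropped.
logger.info("Analysis Config: ", analysis_config)

# New form: %s renders the dict into the message, and only if the record is
# actually emitted at the current log level.
logger.info("Analysis Config: %s", analysis_config)
```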
@@ -1086,12 +1087,12 @@ def run_pre_training_bias(
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """ # noqa E501 # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_pre_training(
-            data_config,
-            data_bias_config,
-            methods
+            data_config, data_bias_config, methods
         )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Pretraining-Bias")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Pretraining-Bias"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1167,14 +1168,12 @@ def run_post_training_bias(
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """ # noqa E501 # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.bias_post_training(
-            data_config,
-            data_bias_config,
-            model_predicted_label_config,
-            methods,
-            model_config
+            data_config, data_bias_config, model_predicted_label_config, methods, model_config
         )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Posttraining-Bias")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Posttraining-Bias"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1354,13 +1353,12 @@ def run_explainability(
                 * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
         """ # noqa E501 # pylint: disable=c0301
         analysis_config = _AnalysisConfigGenerator.explainability(
-            data_config,
-            model_config,
-            model_scores,
-            explainability_config
+            data_config, model_config, model_scores, explainability_config
        )
         # when name is either not provided (is None) or an empty string ("")
-        job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Explainability")
+        job_name = job_name or utils.name_from_base(
+            self.job_name_prefix or "Clarify-Explainability"
+        )
         return self._run(
             data_config,
             analysis_config,
@@ -1376,6 +1374,7 @@ class _AnalysisConfigGenerator:
     """
     Creates analysis_config objects for different type of runs.
     """
+
     @classmethod
     def explainability(
         cls,
@@ -1384,6 +1383,7 @@ def explainability(
         model_scores: ModelPredictedLabelConfig,
         explainability_config: ExplainabilityConfig,
     ):
+        """ Generates a config for Explainability """
         analysis_config = data_config.get_config()
         predictor_config = model_config.get_predictor_config()
         if isinstance(model_scores, ModelPredictedLabelConfig):
@@ -1423,11 +1423,14 @@ def explainability(
         return cls._common(analysis_config)

     @classmethod
-    def bias_pre_training(cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]):
+    def bias_pre_training(
+        cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
+    ):
+        """ Generates a config for Bias Pre Training"""
         analysis_config = {
             **data_config.get_config(),
             **bias_config.get_config(),
-            "methods": {"pre_training_bias": {"methods": methods}}
+            "methods": {"pre_training_bias": {"methods": methods}},
         }
         return cls._common(analysis_config)

@@ -1440,14 +1443,18 @@ def bias_post_training(
         methods: Union[str, List[str]],
         model_config: ModelConfig,
     ):
+        """ Generates a config for Bias Post Training """
         analysis_config = {
             **data_config.get_config(),
             **bias_config.get_config(),
             "predictor": {**model_config.get_predictor_config()},
             "methods": {"post_training_bias": {"methods": methods}},
         }
         if model_predicted_label_config:
-            probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
+            (
+                probability_threshold,
+                predictor_config,
+            ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
@@ -1463,24 +1470,29 @@ def bias(
         pre_training_methods: Union[str, List[str]] = "all",
         post_training_methods: Union[str, List[str]] = "all",
     ):
+        """ Generates a config for Bias """
         analysis_config = {
             **data_config.get_config(),
             **bias_config.get_config(),
             "predictor": model_config.get_predictor_config(),
             "methods": {
                 "pre_training_bias": {"methods": pre_training_methods},
                 "post_training_bias": {"methods": post_training_methods},
-            }
+            },
         }
         if model_predicted_label_config:
-            probability_threshold, predictor_config = model_predicted_label_config.get_predictor_config()
+            (
+                probability_threshold,
+                predictor_config,
+            ) = model_predicted_label_config.get_predictor_config()
             if predictor_config:
                 analysis_config["predictor"].update(predictor_config)
             _set(probability_threshold, "probability_threshold", analysis_config)
         return cls._common(analysis_config)

     @staticmethod
     def _common(analysis_config):
+        """ Extends analysis config with common values """
         analysis_config["methods"]["report"] = {
             "name": "report",
             "title": "Analysis Report",

tests/unit/test_clarify.py

Lines changed: 80 additions & 65 deletions
@@ -766,8 +766,8 @@ def test_pre_training_bias(
         "facet": [{"name_or_index": "F1"}],
         "group_variable": "F2",
         "methods": {
-            'report': {'name': 'report', 'title': 'Analysis Report'},
-            "pre_training_bias": {"methods": "all"}
+            "report": {"name": "report", "title": "Analysis Report"},
+            "pre_training_bias": {"methods": "all"},
         },
     }
     mock_method.assert_called_with(
@@ -832,8 +832,8 @@ def test_post_training_bias(
         "facet": [{"name_or_index": "F1"}],
         "group_variable": "F2",
         "methods": {
-            'report': {'name': 'report', 'title': 'Analysis Report'},
-            "post_training_bias": {"methods": "all"}
+            "report": {"name": "report", "title": "Analysis Report"},
+            "post_training_bias": {"methods": "all"},
         },
         "predictor": {
             "model_name": "xgboost-model",
@@ -993,7 +993,7 @@ def _run_test_explain(
         "top_k_features": 10,
     }
     expected_analysis_config["methods"] = {
-        'report': {'name': 'report', 'title': 'Analysis Report'},
+        "report": {"name": "report", "title": "Analysis Report"},
         **expected_explanation_configs,
     }
     mock_method.assert_called_with(
@@ -1300,43 +1300,49 @@ def test_analysis_config_generator_for_explainability(data_config, model_config)
         model_scores,
         SHAPConfig(),
     )
-    expected = {'dataset_type': 'text/csv',
-                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                'joinsource_name_or_index': 'F4',
-                'label': 'Label',
-                'methods': {
-                    'report': {'name': 'report', 'title': 'Analysis Report'},
-                    'shap': {'save_local_shap_values': True, 'use_logit': False}
-                },
-                'predictor': {'initial_instance_count': 1,
-                              'instance_type': 'ml.c5.xlarge',
-                              'label_headers': ['success'],
-                              'model_name': 'xgboost-model',
-                              'probability': 'pr'}}
+    expected = {
+        "dataset_type": "text/csv",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "shap": {"save_local_shap_values": True, "use_logit": False},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
     assert actual == expected


 def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config):
     actual = _AnalysisConfigGenerator.bias_pre_training(
-        data_config,
-        data_bias_config,
-        methods="all"
+        data_config, data_bias_config, methods="all"
     )
-    expected = {'dataset_type': 'text/csv',
-                'facet': [{'name_or_index': 'F1'}],
-                'group_variable': 'F2',
-                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                'joinsource_name_or_index': 'F4',
-                'label': 'Label',
-                'label_values_or_threshold': [1],
-                'methods': {
-                    'report': {'name': 'report', 'title': 'Analysis Report'},
-                    'pre_training_bias': {'methods': 'all'}}
-                }
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "pre_training_bias": {"methods": "all"},
+        },
+    }
     assert actual == expected


-def test_analysis_config_generator_for_bias_post_training(data_config, data_bias_config, model_config):
+def test_analysis_config_generator_for_bias_post_training(
+    data_config, data_bias_config, model_config
+):
     model_predicted_label_config = ModelPredictedLabelConfig(
         probability="pr",
         label_headers=["success"],
@@ -1348,22 +1354,26 @@ def test_analysis_config_generator_for_bias_post_training(data_config, data_bias
         methods="all",
         model_config=model_config,
     )
-    expected = {'dataset_type': 'text/csv',
-                'facet': [{'name_or_index': 'F1'}],
-                'group_variable': 'F2',
-                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                'joinsource_name_or_index': 'F4',
-                'label': 'Label',
-                'label_values_or_threshold': [1],
-                'methods': {
-                    'report': {'name': 'report', 'title': 'Analysis Report'},
-                    'post_training_bias': {'methods': 'all'}
-                },
-                'predictor': {'initial_instance_count': 1,
-                              'instance_type': 'ml.c5.xlarge',
-                              'label_headers': ['success'],
-                              'model_name': 'xgboost-model',
-                              'probability': 'pr'}}
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "post_training_bias": {"methods": "all"},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
     assert actual == expected

@@ -1380,20 +1390,25 @@ def test_analysis_config_generator_for_bias(data_config, data_bias_config, model
         pre_training_methods="all",
         post_training_methods="all",
     )
-    expected = {'dataset_type': 'text/csv',
-                'facet': [{'name_or_index': 'F1'}],
-                'group_variable': 'F2',
-                'headers': ['Label', 'F1', 'F2', 'F3', 'F4'],
-                'joinsource_name_or_index': 'F4',
-                'label': 'Label',
-                'label_values_or_threshold': [1],
-                'methods': {
-                    'report': {'name': 'report', 'title': 'Analysis Report'},
-                    'post_training_bias': {'methods': 'all'},
-                    'pre_training_bias': {'methods': 'all'}},
-                'predictor': {'initial_instance_count': 1,
-                              'instance_type': 'ml.c5.xlarge',
-                              'label_headers': ['success'],
-                              'model_name': 'xgboost-model',
-                              'probability': 'pr'}}
+    expected = {
+        "dataset_type": "text/csv",
+        "facet": [{"name_or_index": "F1"}],
+        "group_variable": "F2",
+        "headers": ["Label", "F1", "F2", "F3", "F4"],
+        "joinsource_name_or_index": "F4",
+        "label": "Label",
+        "label_values_or_threshold": [1],
+        "methods": {
+            "report": {"name": "report", "title": "Analysis Report"},
+            "post_training_bias": {"methods": "all"},
+            "pre_training_bias": {"methods": "all"},
+        },
+        "predictor": {
+            "initial_instance_count": 1,
+            "instance_type": "ml.c5.xlarge",
+            "label_headers": ["success"],
+            "model_name": "xgboost-model",
+            "probability": "pr",
+        },
+    }
     assert actual == expected
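The test-file changes are purely mechanical: single quotes become double quotes, hanging-indent dict literals are reflowed, and trailing commas are added, which matches what the black formatter produces. Assuming the failing build step was a formatting check (the commit message only says "style formatting", so this is an inference), a small way to verify a snippet locally through black's Python API; the snippet and default settings here are illustrative assumptions:

```python
import black

snippet = 'expected = {"report": {"name": "report", "title": "Analysis Report"}}\n'

# black.format_str returns the reformatted source; if it matches the input,
# the snippet would also pass a `black --check` run with the same settings.
formatted = black.format_str(snippet, mode=black.FileMode())
print("already formatted" if formatted == snippet else formatted)
```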
