Skip to content

Commit 48131a5

Browse files
committed
Added data types and Formatted
1 parent 67708bf commit 48131a5

File tree

3 files changed

+240
-93
lines changed

3 files changed

+240
-93
lines changed

src/sagemaker/clarify.py

Lines changed: 87 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import tempfile
2727
from abc import ABC, abstractmethod
28-
from typing import List, Union
28+
from typing import List, Union, Dict
2929

3030
from sagemaker import image_uris, s3, utils
3131
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
@@ -173,7 +173,11 @@ def __init__(
173173
_set(joinsource, "joinsource_name_or_index", self.analysis_config)
174174
_set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config)
175175
_set(facet_headers, "facet_headers", self.analysis_config)
176-
_set(predicted_label_dataset_uri, "predicted_label_dataset_uri", self.analysis_config)
176+
_set(
177+
predicted_label_dataset_uri,
178+
"predicted_label_dataset_uri",
179+
self.analysis_config,
180+
)
177181
_set(predicted_label_headers, "predicted_label_headers", self.analysis_config)
178182
_set(predicted_label, "predicted_label", self.analysis_config)
179183
_set(excluded_columns, "excluded_columns", self.analysis_config)
@@ -239,7 +243,8 @@ def __init__(
239243
assert len(facet_name) > 0, "Please provide at least one facet"
240244
if facet_values_or_threshold is None:
241245
facet_list = [
242-
{"name_or_index": single_facet_name} for single_facet_name in facet_name
246+
{"name_or_index": single_facet_name}
247+
for single_facet_name in facet_name
243248
]
244249
elif len(facet_values_or_threshold) == len(facet_name):
245250
facet_list = []
@@ -492,7 +497,10 @@ def __init__(self, features=None, grid_resolution=15, top_k_features=10):
492497
top_k_features (int): Sets the number of top SHAP attributes used to compute
493498
partial dependence plots.
494499
""" # noqa E501
495-
self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features}
500+
self.pdp_config = {
501+
"grid_resolution": grid_resolution,
502+
"top_k_features": top_k_features,
503+
}
496504
if features is not None:
497505
self.pdp_config["features"] = features
498506

@@ -825,9 +833,14 @@ def __init__(
825833
image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image
826834
features. Default is None.
827835
""" # noqa E501 # pylint: disable=c0301
828-
if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]:
836+
if agg_method is not None and agg_method not in [
837+
"mean_abs",
838+
"median",
839+
"mean_sq",
840+
]:
829841
raise ValueError(
830-
f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq."
842+
f"Invalid agg_method {agg_method}."
843+
f" Please choose mean_abs, median, or mean_sq."
831844
)
832845
if num_clusters is not None and baseline is not None:
833846
raise ValueError(
@@ -923,7 +936,9 @@ def __init__(
923936
job_name_prefix (str): Processing job name prefix.
924937
version (str): Clarify version to use.
925938
""" # noqa E501 # pylint: disable=c0301
926-
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
939+
container_uri = image_uris.retrieve(
940+
"clarify", sagemaker_session.boto_region_name, version
941+
)
927942
self._last_analysis_config = None
928943
self.job_name_prefix = job_name_prefix
929944
super(SageMakerClarifyProcessor, self).__init__(
@@ -996,7 +1011,8 @@ def _run(
9961011
json.dump(analysis_config, f)
9971012
s3_analysis_config_file = _upload_analysis_config(
9981013
analysis_config_file,
999-
data_config.s3_analysis_config_output_path or data_config.s3_output_path,
1014+
data_config.s3_analysis_config_output_path
1015+
or data_config.s3_output_path,
10001016
self.sagemaker_session,
10011017
kms_key,
10021018
)
@@ -1168,7 +1184,11 @@ def run_post_training_bias(
11681184
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
11691185
""" # noqa E501 # pylint: disable=c0301
11701186
analysis_config = _AnalysisConfigGenerator.bias_post_training(
1171-
data_config, data_bias_config, model_predicted_label_config, methods, model_config
1187+
data_config,
1188+
data_bias_config,
1189+
model_predicted_label_config,
1190+
methods,
1191+
model_config,
11721192
)
11731193
# when name is either not provided (is None) or an empty string ("")
11741194
job_name = job_name or utils.name_from_base(
@@ -1267,7 +1287,9 @@ def run_bias(
12671287
post_training_methods,
12681288
)
12691289
# when name is either not provided (is None) or an empty string ("")
1270-
job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias")
1290+
job_name = job_name or utils.name_from_base(
1291+
self.job_name_prefix or "Clarify-Bias"
1292+
)
12711293
return self._run(
12721294
data_config,
12731295
analysis_config,
@@ -1450,8 +1472,8 @@ def run_bias_and_explainability(
14501472
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
14511473
Defaults to str "all" to run all metrics if left unspecified.
14521474
model_predicted_label_config (
1453-
int or
1454-
str or
1475+
int or
1476+
str or
14551477
:class:`~sagemaker.clarify.ModelPredictedLabelConfig`
14561478
):
14571479
Index or JSONPath to locate the predicted scores in the model output. This is not
@@ -1552,11 +1574,16 @@ def explainability(
15521574

15531575
@classmethod
15541576
def bias_pre_training(
1555-
cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]]
1577+
cls,
1578+
data_config: DataConfig,
1579+
bias_config: BiasConfig,
1580+
methods: Union[str, List[str]],
15561581
):
15571582
"""Generates a config for Bias Pre Training"""
15581583
analysis_config = {**data_config.get_config(), **bias_config.get_config()}
1559-
analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods)
1584+
analysis_config = cls._add_methods(
1585+
analysis_config, pre_training_methods=methods
1586+
)
15601587
return analysis_config
15611588

15621589
@classmethod
@@ -1570,7 +1597,9 @@ def bias_post_training(
15701597
):
15711598
"""Generates a config for Bias Post Training"""
15721599
analysis_config = {**data_config.get_config(), **bias_config.get_config()}
1573-
analysis_config = cls._add_methods(analysis_config, post_training_methods=methods)
1600+
analysis_config = cls._add_methods(
1601+
analysis_config, post_training_methods=methods
1602+
)
15741603
analysis_config = cls._add_predictor(
15751604
analysis_config, model_config, model_predicted_label_config
15761605
)
@@ -1599,7 +1628,12 @@ def bias(
15991628
return analysis_config
16001629

16011630
@classmethod
1602-
def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
1631+
def _add_predictor(
1632+
cls,
1633+
analysis_config: Dict,
1634+
model_config: ModelConfig,
1635+
model_predicted_label_config: ModelPredictedLabelConfig,
1636+
):
16031637
"""Extends analysis config with predictor."""
16041638
analysis_config = {**analysis_config}
16051639
analysis_config["predictor"] = model_config.get_predictor_config()
@@ -1618,16 +1652,18 @@ def _add_predictor(cls, analysis_config, model_config, model_predicted_label_con
16181652
@classmethod
16191653
def _add_methods(
16201654
cls,
1621-
analysis_config,
1622-
pre_training_methods=None,
1623-
post_training_methods=None,
1624-
explainability_config=None,
1655+
analysis_config: Dict,
1656+
pre_training_methods: Union[str, List[str]] = "all",
1657+
post_training_methods: Union[str, List[str]] = "all",
1658+
explainability_config: Union[
1659+
ExplainabilityConfig, List[ExplainabilityConfig]
1660+
] = None,
16251661
report=True,
16261662
):
16271663
"""Extends analysis config with methods."""
16281664
# validate
16291665
params = [pre_training_methods, post_training_methods, explainability_config]
1630-
if all([1 if p is None else 0 for p in params]):
1666+
if all(1 if p is None else 0 for p in params):
16311667
raise AttributeError(
16321668
"analysis_config must have at least one working method: "
16331669
"One of the "
@@ -1640,22 +1676,35 @@ def _add_methods(
16401676
analysis_config["methods"] = {}
16411677

16421678
if report:
1643-
analysis_config["methods"]["report"] = {"name": "report", "title": "Analysis Report"}
1679+
analysis_config["methods"]["report"] = {
1680+
"name": "report",
1681+
"title": "Analysis Report",
1682+
}
16441683

16451684
if pre_training_methods:
1646-
analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods}
1685+
analysis_config["methods"]["pre_training_bias"] = {
1686+
"methods": pre_training_methods
1687+
}
16471688

16481689
if post_training_methods:
1649-
analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods}
1690+
analysis_config["methods"]["post_training_bias"] = {
1691+
"methods": post_training_methods
1692+
}
16501693

16511694
if explainability_config is not None:
1652-
explainability_methods = cls._merge_explainability_configs(explainability_config)
1653-
analysis_config["methods"] = {**analysis_config["methods"], **explainability_methods}
1695+
explainability_methods = cls._merge_explainability_configs(
1696+
explainability_config
1697+
)
1698+
analysis_config["methods"] = {
1699+
**analysis_config["methods"],
1700+
**explainability_methods,
1701+
}
16541702
return analysis_config
16551703

16561704
@classmethod
16571705
def _merge_explainability_configs(
1658-
cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
1706+
cls,
1707+
explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
16591708
):
16601709
"""Merges explainability configs, when more than one."""
16611710
if isinstance(explainability_config, list):
@@ -1671,17 +1720,24 @@ def _merge_explainability_configs(
16711720
"shap" not in explainability_methods
16721721
and "features" not in explainability_methods["pdp"]
16731722
):
1674-
raise ValueError("PDP features must be provided when ShapConfig is not provided")
1723+
raise ValueError(
1724+
"PDP features must be provided when ShapConfig is not provided"
1725+
)
16751726
return explainability_methods
16761727
if (
16771728
isinstance(explainability_config, PDPConfig)
1678-
and "features" not in explainability_config.get_explainability_config()["pdp"]
1729+
and "features"
1730+
not in explainability_config.get_explainability_config()["pdp"]
16791731
):
1680-
raise ValueError("PDP features must be provided when ShapConfig is not provided")
1732+
raise ValueError(
1733+
"PDP features must be provided when ShapConfig is not provided"
1734+
)
16811735
return explainability_config.get_explainability_config()
16821736

16831737

1684-
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
1738+
def _upload_analysis_config(
1739+
analysis_config_file, s3_output_path, sagemaker_session, kms_key
1740+
):
16851741
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.
16861742
16871743
Args:

0 commit comments

Comments
 (0)