Skip to content

Commit 67708bf

Browse files
committed
added run_bias_and_explainability method
1 parent 64a69c5 commit 67708bf

File tree

2 files changed

+178
-10
lines changed

2 files changed

+178
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 139 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,15 +1369,139 @@ def run_explainability(
13691369
experiment_config,
13701370
)
13711371

1372-
def run_bias_and_explainability(self):
1373-
"""
1374-
TODO:
1375-
- add doc string
1376-
- add logic
1377-
- add tests
1378-
"""
1379-
raise NotImplementedError(
1380-
"Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1372+
def run_bias_and_explainability(
1373+
self,
1374+
data_config: DataConfig,
1375+
model_config: ModelConfig,
1376+
explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
1377+
bias_config: BiasConfig,
1378+
pre_training_methods: Union[str, List[str]] = "all",
1379+
post_training_methods: Union[str, List[str]] = "all",
1380+
model_predicted_label_config: ModelPredictedLabelConfig = None,
1381+
wait=True,
1382+
logs=True,
1383+
job_name=None,
1384+
kms_key=None,
1385+
experiment_config=None,
1386+
):
1387+
"""Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
1388+
1389+
For bias:
1390+
Computes metrics for both the pre-training and the post-training methods.
1391+
To calculate post-training methods, it spins up a model endpoint and runs inference over the
1392+
input examples in ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`)
1393+
to obtain predicted labels.
1394+
1395+
For Explainability:
1396+
Spins up a model endpoint.
1397+
1398+
Currently, only SHAP and Partial Dependence Plots (PDP) are supported
1399+
as explainability methods.
1400+
You can request both methods or one at a time with the ``explainability_config`` parameter.
1401+
1402+
When SHAP is requested in the ``explainability_config``,
1403+
the SHAP algorithm calculates the feature importance for each input example
1404+
in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
1405+
by creating ``num_samples`` copies of the example with a subset of features
1406+
replaced with values from the ``baseline``.
1407+
It then runs model inference to see how the model's prediction changes with the replaced
1408+
features. If the model output returns multiple scores, importance is computed for each score.
1409+
Across examples, feature importance is aggregated using ``agg_method``.
1410+
1411+
When PDP is requested in the ``explainability_config``,
1412+
the PDP algorithm calculates the dependence of the target response
1413+
on the input features and marginalizes over the values of all other input features.
1414+
The Partial Dependence Plots are included in the output
1415+
`report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
1416+
and the corresponding values are included in the analysis output.
1417+
1418+
Args:
1419+
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1420+
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1421+
endpoint to be created.
1422+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
1423+
Config of the specific explainability method or a list of
1424+
:class:`~sagemaker.clarify.ExplainabilityConfig` objects.
1425+
Currently, SHAP and PDP are the two methods supported.
1426+
You can request multiple methods at once by passing in a list of
1427+
:class:`~sagemaker.clarify.ExplainabilityConfig` objects.
1428+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1429+
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
1430+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
1431+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
1432+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
1433+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
1434+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
1435+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
1436+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
1437+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
1438+
Defaults to str "all" to run all metrics if left unspecified.
1439+
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
1440+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
1441+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
1442+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
1443+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
1444+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
1445+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
1446+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
1447+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
1448+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
1449+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
1450+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
1451+
Defaults to str "all" to run all metrics if left unspecified.
1452+
model_predicted_label_config (
1453+
int or
1454+
str or
1455+
:class:`~sagemaker.clarify.ModelPredictedLabelConfig`
1456+
):
1457+
Index or JSONPath to locate the predicted scores in the model output. This is not
1458+
required if the model output is a single score. Alternatively, it can be an instance
1459+
of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
1460+
to provide more parameters like ``label_headers``.
1461+
wait (bool): Whether the call should wait until the job completes (default: True).
1462+
logs (bool): Whether to show the logs produced by the job.
1463+
Only meaningful when ``wait`` is True (default: True).
1464+
job_name (str): Processing job name. When ``job_name`` is not specified,
1465+
if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
1466+
is specified, the job name will be composed of ``job_name_prefix`` and current
1467+
timestamp; otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
1468+
kms_key (str): The ARN of the KMS key that is used to encrypt the
1469+
user code file (default: None).
1470+
experiment_config (dict[str, str]): Experiment management configuration.
1471+
Optionally, the dict can contain three keys:
1472+
``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
1473+
1474+
The behavior of setting these keys is as follows:
1475+
1476+
* If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
1477+
automatically created and the job's Trial Component associated with the Trial.
1478+
* If ``'TrialName'`` is supplied and the Trial already exists,
1479+
the job's Trial Component will be associated with the Trial.
1480+
* If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
1481+
the Trial Component will be unassociated.
1482+
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1483+
""" # noqa E501 # pylint: disable=c0301
1484+
analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
1485+
data_config,
1486+
model_config,
1487+
model_predicted_label_config,
1488+
explainability_config,
1489+
bias_config,
1490+
pre_training_methods,
1491+
post_training_methods,
1492+
)
1493+
# when name is either not provided (is None) or an empty string ("")
1494+
job_name = job_name or utils.name_from_base(
1495+
self.job_name_prefix or "Clarify-Bias-And-Explainability"
1496+
)
1497+
return self._run(
1498+
data_config,
1499+
analysis_config,
1500+
wait,
1501+
logs,
1502+
job_name,
1503+
kms_key,
1504+
experiment_config,
13811505
)
13821506

13831507

@@ -1395,6 +1519,7 @@ def bias_and_explainability(
13951519
pre_training_methods: Union[str, List[str]] = "all",
13961520
post_training_methods: Union[str, List[str]] = "all",
13971521
):
1522+
"""Generates a config for Bias and Explainability"""
13981523
analysis_config = {**data_config.get_config(), **bias_config.get_config()}
13991524
analysis_config = cls._add_methods(
14001525
analysis_config,
@@ -1475,6 +1600,7 @@ def bias(
14751600

14761601
@classmethod
14771602
def _add_predictor(cls, analysis_config, model_config, model_predicted_label_config):
1603+
"""Extends analysis config with predictor."""
14781604
analysis_config = {**analysis_config}
14791605
analysis_config["predictor"] = model_config.get_predictor_config()
14801606
if isinstance(model_predicted_label_config, ModelPredictedLabelConfig):
@@ -1498,12 +1624,14 @@ def _add_methods(
14981624
explainability_config=None,
14991625
report=True,
15001626
):
1627+
"""Extends analysis config with methods."""
15011628
# validate
15021629
params = [pre_training_methods, post_training_methods, explainability_config]
15031630
if all([1 if p is None else 0 for p in params]):
15041631
raise AttributeError(
15051632
"analysis_config must have at least one working method: "
1506-
"One of the `pre_training_methods`, `post_training_methods`, `explainability_config`."
1633+
"One of the "
1634+
"`pre_training_methods`, `post_training_methods`, `explainability_config`."
15071635
)
15081636

15091637
# main logic
@@ -1529,6 +1657,7 @@ def _add_methods(
15291657
def _merge_explainability_configs(
15301658
cls, explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]]
15311659
):
1660+
"""Merges explainability configs, when more than one."""
15321661
if isinstance(explainability_config, list):
15331662
explainability_methods = {}
15341663
if len(explainability_config) == 0:

tests/integ/test_clarify.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,45 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
704704
check_analysis_config(data_config, sagemaker_session, "shap")
705705

706706

707+
def test_bias_and_explainability(
708+
clarify_processor, data_config, model_config, shap_config, data_bias_config, sagemaker_session
709+
):
710+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
711+
clarify_processor.run_bias_and_explainability(
712+
data_config,
713+
model_config,
714+
shap_config,
715+
data_bias_config,
716+
pre_training_methods="all",
717+
post_training_methods="all",
718+
model_predicted_label_config="score",
719+
job_name=utils.unique_name_from_base("clarify-bias-and-explainability"),
720+
wait=True,
721+
)
722+
analysis_result_json = s3.S3Downloader.read_file(
723+
data_config.s3_output_path + "/analysis.json",
724+
sagemaker_session,
725+
)
726+
analysis_result = json.loads(analysis_result_json)
727+
assert (
728+
math.fabs(
729+
analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"]
730+
)
731+
<= 1
732+
)
733+
check_analysis_config(data_config, sagemaker_session, "shap")
734+
735+
assert (
736+
math.fabs(
737+
analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
738+
"value"
739+
]
740+
)
741+
<= 1.0
742+
)
743+
check_analysis_config(data_config, sagemaker_session, "post_training_bias")
744+
745+
707746
def check_analysis_config(data_config, sagemaker_session, method):
708747
analysis_config_json = s3.S3Downloader.read_file(
709748
data_config.s3_output_path + "/analysis_config.json",

0 commit comments

Comments
 (0)