Skip to content

Commit 4fa6f98

Browse files
committed
added run_bias_and_explainability method
1 parent 64a69c5 commit 4fa6f98

File tree

2 files changed

+168
-9
lines changed

2 files changed

+168
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 129 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,15 +1369,135 @@ def run_explainability(
13691369
experiment_config,
13701370
)
13711371

1372-
def run_bias_and_explainability(self):
1373-
"""
1374-
TODO:
1375-
- add doc string
1376-
- add logic
1377-
- add tests
1378-
"""
1379-
raise NotImplementedError(
1380-
"Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1372+
def run_bias_and_explainability(
    self,
    data_config: DataConfig,
    model_config: ModelConfig,
    explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
    bias_config: BiasConfig,
    pre_training_methods: Union[str, List[str]] = "all",
    post_training_methods: Union[str, List[str]] = "all",
    model_predicted_label_config: Union[int, str, ModelPredictedLabelConfig] = None,
    wait=True,
    logs=True,
    job_name=None,
    kms_key=None,
    experiment_config=None,
):
    """Runs a :class:`~sagemaker.processing.ProcessingJob` for bias and explainability.

    Builds one combined analysis config that requests bias metrics and
    feature attributions, then submits it through ``self._run``.

    For bias:
    Computes metrics for both the pre-training and the post-training methods.
    To calculate post-training methods, it spins up a model endpoint and runs inference over the
    input examples in 's3_data_input_path' (from the :class:`~sagemaker.clarify.DataConfig`)
    to obtain predicted labels.

    For Explainability:
    Spins up a model endpoint.

    Currently, only SHAP and Partial Dependence Plots (PDP) are supported
    as explainability methods.
    You can request both methods or one at a time with the ``explainability_config`` parameter.

    When SHAP is requested in the ``explainability_config``,
    the SHAP algorithm calculates the feature importance for each input example
    in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
    by creating ``num_samples`` copies of the example with a subset of features
    replaced with values from the ``baseline``.
    It then runs model inference to see how the model's prediction changes with the replaced
    features. If the model output returns multiple scores, importance is computed for each score.
    Across examples, feature importance is aggregated using ``agg_method``.

    When PDP is requested in the ``explainability_config``,
    the PDP algorithm calculates the dependence of the target response
    on the input features and marginalizes over the values of all other input features.
    The Partial Dependence Plots are included in the output
    `report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
    and the corresponding values are included in the analysis output.

    Args:
        data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
        model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
            endpoint to be created.
        explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
            Config of the specific explainability method or a list of
            :class:`~sagemaker.clarify.ExplainabilityConfig` objects.
            Currently, SHAP and PDP are the two methods supported.
            You can request multiple methods at once by passing in a list of
            :class:`~sagemaker.clarify.ExplainabilityConfig`.
        bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
        pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
            ["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
            "`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
            "`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
            "`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
            "`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
            "`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
            "`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
            "`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
            Defaults to str "all" to run all metrics if left unspecified.
        post_training_methods (str or list[str]): Selector of a subset of potential metrics:
            ["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_",
            "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
            "`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
            "`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
            "`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
            "`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
            "`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
            "`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
            "`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_",
            "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
            "`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
            Defaults to str "all" to run all metrics if left unspecified.
        model_predicted_label_config (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
            Index or JSONPath to locate the predicted scores in the model output. This is not
            required if the model output is a single score. Alternatively, it can be an instance
            of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
            to provide more parameters like ``label_headers``.
        wait (bool): Whether the call should wait until the job completes (default: True).
        logs (bool): Whether to show the logs produced by the job.
            Only meaningful when ``wait`` is True (default: True).
        job_name (str): Processing job name. When ``job_name`` is not specified,
            if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
            is specified, the job name will be composed of ``job_name_prefix`` and current
            timestamp; otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
        kms_key (str): The ARN of the KMS key that is used to encrypt the
            user code file (default: None).
        experiment_config (dict[str, str]): Experiment management configuration.
            Optionally, the dict can contain three keys:
            ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.

            The behavior of setting these keys is as follows:

            * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
              automatically created and the job's Trial Component associated with the Trial.
            * If ``'TrialName'`` is supplied and the Trial already exists,
              the job's Trial Component will be associated with the Trial.
            * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
              the Trial Component will be unassociated.
            * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.

    Returns:
        Whatever ``self._run`` returns for the submitted job (passed through unchanged).
    """  # noqa E501  # pylint: disable=c0301
    # Note the argument order expected by the generator: the predicted-label
    # config comes third, ahead of the explainability and bias configs.
    analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
        data_config,
        model_config,
        model_predicted_label_config,
        explainability_config,
        bias_config,
        pre_training_methods,
        post_training_methods,
    )
    # when name is either not provided (is None) or an empty string ("")
    job_name = job_name or utils.name_from_base(
        self.job_name_prefix or "Clarify-Bias-And-Explainability"
    )
    return self._run(
        data_config,
        analysis_config,
        wait,
        logs,
        job_name,
        kms_key,
        experiment_config,
    )
13821502

13831503

tests/integ/test_clarify.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,45 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
704704
check_analysis_config(data_config, sagemaker_session, "shap")
705705

706706

707+
def test_bias_and_explainability(
    clarify_processor, data_config, model_config, shap_config, data_bias_config, sagemaker_session
):
    """Run a combined bias + explainability job and sanity-check its outputs.

    Submits the job with all pre- and post-training bias methods plus SHAP,
    downloads ``analysis.json`` from the configured S3 output path, and checks
    that one global SHAP value and one post-training bias metric value have
    magnitude at most 1. Both methods are also checked against the uploaded
    analysis config via ``check_analysis_config``.
    """
    # NOTE(review): the nesting below (everything inside the timeout block) is
    # reconstructed; the scraped diff carries no indentation -- confirm.
    with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
        clarify_processor.run_bias_and_explainability(
            data_config=data_config,
            model_config=model_config,
            explainability_config=shap_config,
            bias_config=data_bias_config,
            pre_training_methods="all",
            post_training_methods="all",
            model_predicted_label_config="score",
            job_name=utils.unique_name_from_base("clarify-bias-and-explainability"),
            wait=True,
        )

        # Fetch and parse the analysis result produced by the job.
        result_uri = data_config.s3_output_path + "/analysis.json"
        analysis_result = json.loads(s3.S3Downloader.read_file(result_uri, sagemaker_session))

        # Explainability: global SHAP value of feature F2 for the first label.
        shap_section = analysis_result["explanations"]["kernel_shap"]["label0"]
        assert math.fabs(shap_section["global_shap_values"]["F2"]) <= 1
        check_analysis_config(data_config, sagemaker_session, "shap")

        # Bias: first metric reported for the first entry of facet F1.
        first_f1_facet = analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]
        assert math.fabs(first_f1_facet["metrics"][0]["value"]) <= 1.0
        check_analysis_config(data_config, sagemaker_session, "post_training_bias")
744+
745+
707746
def check_analysis_config(data_config, sagemaker_session, method):
708747
analysis_config_json = s3.S3Downloader.read_file(
709748
data_config.s3_output_path + "/analysis_config.json",

0 commit comments

Comments
 (0)