Skip to content

Commit 8eb4b01

Browse files
committed
added run_bias_and_explainability method
1 parent 2bb714f commit 8eb4b01

File tree

2 files changed

+173
-9
lines changed

2 files changed

+173
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 129 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,15 +1368,135 @@ def run_explainability(
13681368
experiment_config,
13691369
)
13701370

1371-
def run_bias_and_explainability(self):
1372-
"""
1373-
TODO:
1374-
- add doc string
1375-
- add logic
1376-
- add tests
1377-
"""
1378-
raise NotImplementedError(
1379-
"Please choose a method of run_pre_training_bias, run_post_training_bias or run_explainability."
1371+
def run_bias_and_explainability(
1372+
self,
1373+
data_config: DataConfig,
1374+
model_config: ModelConfig,
1375+
explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
1376+
bias_config: BiasConfig,
1377+
pre_training_methods: Union[str, List[str]] = "all",
1378+
post_training_methods: Union[str, List[str]] = "all",
1379+
model_predicted_label_config: ModelPredictedLabelConfig = None,
1380+
wait=True,
1381+
logs=True,
1382+
job_name=None,
1383+
kms_key=None,
1384+
experiment_config=None,
1385+
):
1386+
"""Runs a :class:`~sagemaker.processing.ProcessingJob` computing bias metrics and feature attributions.
1387+
1388+
For bias:
1389+
Computes metrics for both the pre-training and the post-training methods.
1390+
To calculate post-training methods, it spins up a model endpoint and runs inference over the
1391+
input examples in 's3_data_input_path' (from the :class:`~sagemaker.clarify.DataConfig`)
1392+
to obtain predicted labels.
1393+
1394+
For Explainability:
1395+
Spins up a model endpoint.
1396+
1397+
Currently, only SHAP and Partial Dependence Plots (PDP) are supported
1398+
as explainability methods.
1399+
You can request both methods or one at a time with the ``explainability_config`` parameter.
1400+
1401+
When SHAP is requested in the ``explainability_config``,
1402+
the SHAP algorithm calculates the feature importance for each input example
1403+
in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`,
1404+
by creating ``num_samples`` copies of the example with a subset of features
1405+
replaced with values from the ``baseline``.
1406+
It then runs model inference to see how the model's prediction changes with the replaced
1407+
features. If the model output returns multiple scores, importance is computed for each score.
1408+
Across examples, feature importance is aggregated using ``agg_method``.
1409+
1410+
When PDP is requested in the ``explainability_config``,
1411+
the PDP algorithm calculates the dependence of the target response
1412+
on the input features and marginalizes over the values of all other input features.
1413+
The Partial Dependence Plots are included in the output
1414+
`report <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-feature-attribute-baselines-reports.html>`__
1415+
and the corresponding values are included in the analysis output.
1416+
1417+
Args:
1418+
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
1419+
model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
1420+
endpoint to be created.
1421+
explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list):
1422+
Config of the specific explainability method or a list of
1423+
:class:`~sagemaker.clarify.ExplainabilityConfig` objects.
1424+
Currently, SHAP and PDP are the two methods supported.
1425+
You can request multiple methods at once by passing in a list of
1426+
:class:`~sagemaker.clarify.ExplainabilityConfig` objects.
1427+
bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
1428+
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
1429+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
1430+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
1431+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
1432+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
1433+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
1434+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
1435+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
1436+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
1437+
Defaults to str "all" to run all metrics if left unspecified.
1438+
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
1439+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
1440+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
1441+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
1442+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
1443+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
1444+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
1445+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
1446+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
1447+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
1448+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
1449+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
1450+
Defaults to str "all" to run all metrics if left unspecified.
1451+
model_predicted_label_config (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
1452+
Index or JSONPath to locate the predicted scores in the model output. This is not
1453+
required if the model output is a single score. Alternatively, it can be an instance
1454+
of :class:`~sagemaker.clarify.ModelPredictedLabelConfig`
1455+
to provide more parameters like ``label_headers``.
1456+
wait (bool): Whether the call should wait until the job completes (default: True).
1457+
logs (bool): Whether to show the logs produced by the job.
1458+
Only meaningful when ``wait`` is True (default: True).
1459+
job_name (str): Processing job name. When ``job_name`` is not specified,
1460+
if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor`
1461+
is specified, the job name will be composed of ``job_name_prefix`` and current
1462+
timestamp; otherwise use ``"Clarify-Bias-And-Explainability"`` as prefix.
1463+
kms_key (str): The ARN of the KMS key that is used to encrypt the
1464+
user code file (default: None).
1465+
experiment_config (dict[str, str]): Experiment management configuration.
1466+
Optionally, the dict can contain three keys:
1467+
``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``.
1468+
1469+
The behavior of setting these keys is as follows:
1470+
1471+
* If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be
1472+
automatically created and the job's Trial Component associated with the Trial.
1473+
* If ``'TrialName'`` is supplied and the Trial already exists,
1474+
the job's Trial Component will be associated with the Trial.
1475+
* If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied,
1476+
the Trial Component will be unassociated.
1477+
* ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio.
1478+
""" # noqa E501 # pylint: disable=c0301
1479+
analysis_config = _AnalysisConfigGenerator.bias_and_explainability(
1480+
data_config,
1481+
model_config,
1482+
model_predicted_label_config,
1483+
explainability_config,
1484+
bias_config,
1485+
pre_training_methods,
1486+
post_training_methods,
1487+
)
1488+
# when name is either not provided (is None) or an empty string ("")
1489+
job_name = job_name or utils.name_from_base(
1490+
self.job_name_prefix or "Clarify-Bias-And-Explainability"
1491+
)
1492+
return self._run(
1493+
data_config,
1494+
analysis_config,
1495+
wait,
1496+
logs,
1497+
job_name,
1498+
kms_key,
1499+
experiment_config,
13801500
)
13811501

13821502

tests/integ/test_clarify.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,50 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak
704704
check_analysis_config(data_config, sagemaker_session, "shap")
705705

706706

707+
def test_bias_and_explainability(
708+
clarify_processor,
709+
data_config,
710+
model_config,
711+
shap_config,
712+
data_bias_config,
713+
sagemaker_session
714+
):
715+
with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES):
716+
clarify_processor.run_bias_and_explainability(
717+
data_config,
718+
model_config,
719+
shap_config,
720+
data_bias_config,
721+
pre_training_methods="all",
722+
post_training_methods="all",
723+
model_predicted_label_config="score",
724+
job_name=utils.unique_name_from_base("clarify-bias-and-explainability"),
725+
wait=True,
726+
)
727+
analysis_result_json = s3.S3Downloader.read_file(
728+
data_config.s3_output_path + "/analysis.json",
729+
sagemaker_session,
730+
)
731+
analysis_result = json.loads(analysis_result_json)
732+
assert (
733+
math.fabs(
734+
analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"]
735+
)
736+
<= 1
737+
)
738+
check_analysis_config(data_config, sagemaker_session, "shap")
739+
740+
assert (
741+
math.fabs(
742+
analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][
743+
"value"
744+
]
745+
)
746+
<= 1.0
747+
)
748+
check_analysis_config(data_config, sagemaker_session, "post_training_bias")
749+
750+
707751
def check_analysis_config(data_config, sagemaker_session, method):
708752
analysis_config_json = s3.S3Downloader.read_file(
709753
data_config.s3_output_path + "/analysis_config.json",

0 commit comments

Comments
 (0)