Skip to content

feature: support job_name_prefix for Clarify #2471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions src/sagemaker/clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def __init__(
env=None,
tags=None,
network_config=None,
job_name_prefix=None,
version=None,
):
"""Initializes a ``Processor`` instance, computing bias metrics and model explanations.
Expand Down Expand Up @@ -384,9 +385,11 @@ def __init__(
A :class:`~sagemaker.network.NetworkConfig`
object that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets.
job_name_prefix (str): Processing job name prefix.
version (str): Clarify version to be used.
"""
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
self.job_name_prefix = job_name_prefix
super(SageMakerClarifyProcessor, self).__init__(
role,
container_uri,
Expand Down Expand Up @@ -500,13 +503,22 @@ def run_pre_training_bias(
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
methods (str or list[str]): Selector of a subset of potential metrics:
["CI", "DPL", "KL", "JS", "LP", "TVD", "KS", "CDDL"]. Defaults to computing all.
# TODO: Provide a pointer to the official documentation of those.
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
Defaults to computing all.
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when ``wait`` is True (default: True).
job_name (str): Processing job name. If not specified, a name is composed of
"Clarify-Pretraining-Bias" and current timestamp.
job_name (str): Processing job name. When ``job_name`` is not specified, if
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job
name will be composed of ``job_name_prefix`` and current timestamp; otherwise
"Clarify-Pretraining-Bias" is used as the prefix.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
experiment_config (dict[str, str]): Experiment management configuration.
Expand All @@ -517,7 +529,10 @@ def run_pre_training_bias(
analysis_config.update(data_bias_config.get_config())
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
if job_name is None:
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)

def run_post_training_bias(
Expand Down Expand Up @@ -548,14 +563,25 @@ def run_post_training_bias(
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
Config of how to extract the predicted label from the model output.
methods (str or list[str]): Selector of a subset of potential metrics:
# TODO: Provide a pointer to the official documentation of those.
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
Defaults to computing all.
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when ``wait`` is True (default: True).
job_name (str): Processing job name. If not specified, a name is composed of
"Clarify-Posttraining-Bias" and current timestamp.
job_name (str): Processing job name. When ``job_name`` is not specified, if
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job
name will be composed of ``job_name_prefix`` and current timestamp; otherwise
"Clarify-Posttraining-Bias" is used as the prefix.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
experiment_config (dict[str, str]): Experiment management configuration.
Expand All @@ -573,7 +599,10 @@ def run_post_training_bias(
analysis_config["predictor"] = predictor_config
_set(probability_threshold, "probability_threshold", analysis_config)
if job_name is None:
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)

def run_bias(
Expand Down Expand Up @@ -605,18 +634,35 @@ def run_bias(
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
Config of how to extract the predicted label from the model output.
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
# TODO: Provide a pointer to the official documentation of those.
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
Defaults to computing all.
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
# TODO: Provide a pointer to the official documentation of those.
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
Defaults to computing all.
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when ``wait`` is True (default: True).
job_name (str): Processing job name. If not specified, a name is composed of
"Clarify-Bias" and current timestamp.
job_name (str): Processing job name. When ``job_name`` is not specified, if
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job
name will be composed of ``job_name_prefix`` and current timestamp; otherwise
"Clarify-Bias" is used as the prefix.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
experiment_config (dict[str, str]): Experiment management configuration.
Expand All @@ -641,7 +687,10 @@ def run_bias(
"post_training_bias": {"methods": post_training_methods},
}
if job_name is None:
job_name = utils.name_from_base("Clarify-Bias")
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Bias")
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)

def run_explainability(
Expand Down Expand Up @@ -679,8 +728,10 @@ def run_explainability(
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when ``wait`` is True (default: True).
job_name (str): Processing job name. If not specified, a name is composed of
"Clarify-Explainability" and current timestamp.
job_name (str): Processing job name. When ``job_name`` is not specified, if
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job
name will be composed of ``job_name_prefix`` and current timestamp; otherwise
"Clarify-Explainability" is used as the prefix.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
experiment_config (dict[str, str]): Experiment management configuration.
Expand All @@ -693,7 +744,10 @@ def run_explainability(
analysis_config["methods"] = explainability_config.get_explainability_config()
analysis_config["predictor"] = predictor_config
if job_name is None:
job_name = utils.name_from_base("Clarify-Explainability")
if self.job_name_prefix:
job_name = utils.name_from_base(self.job_name_prefix)
else:
job_name = utils.name_from_base("Clarify-Explainability")
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)


Expand Down
101 changes: 95 additions & 6 deletions tests/unit/test_clarify.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
)
from sagemaker import image_uris

# Prefix handed to SageMakerClarifyProcessor(job_name_prefix=...) in the tests.
JOB_NAME_PREFIX = "my-prefix"
# Fixed timestamp suffix so the generated job name is deterministic.
TIMESTAMP = "2021-06-17-22-29-54-685"
# Value the patched ``name_from_base`` returns: "<prefix>-<timestamp>".
JOB_NAME = "-".join((JOB_NAME_PREFIX, TIMESTAMP))


def test_uri():
uri = image_uris.retrieve("clarify", "us-west-2")
Expand Down Expand Up @@ -248,6 +252,17 @@ def clarify_processor(sagemaker_session):
)


@pytest.fixture(scope="module")
def clarify_processor_with_job_name_prefix(sagemaker_session):
    """Return a module-scoped Clarify processor built with ``JOB_NAME_PREFIX``.

    Args:
        sagemaker_session: Session object forwarded to the processor.

    Returns:
        SageMakerClarifyProcessor: Processor constructed with
        ``job_name_prefix=JOB_NAME_PREFIX``.
    """
    processor_kwargs = {
        "role": "AmazonSageMaker-ExecutionRole",
        "instance_count": 1,
        "instance_type": "ml.c5.xlarge",
        "sagemaker_session": sagemaker_session,
        "job_name_prefix": JOB_NAME_PREFIX,
    }
    return SageMakerClarifyProcessor(**processor_kwargs)


@pytest.fixture(scope="module")
def data_config():
return DataConfig(
Expand Down Expand Up @@ -302,7 +317,14 @@ def shap_config():
)


def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_pre_training_bias(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
data_bias_config,
):
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
clarify_processor.run_pre_training_bias(
data_config,
Expand All @@ -325,7 +347,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
"group_variable": "F2",
"methods": {"pre_training_bias": {"methods": "all"}},
}
mock_method.assert_called_once_with(
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
Expand All @@ -334,10 +356,33 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
None,
{"ExperimentName": "AnExperiment"},
)
clarify_processor_with_job_name_prefix.run_pre_training_bias(
data_config,
data_bias_config,
wait=True,
experiment_config={"ExperimentName": "AnExperiment"},
)
name_from_base.assert_called_with(JOB_NAME_PREFIX)
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
True,
JOB_NAME,
None,
{"ExperimentName": "AnExperiment"},
)


@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_post_training_bias(
clarify_processor, data_config, data_bias_config, model_config, model_predicted_label_config
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
data_bias_config,
model_config,
model_predicted_label_config,
):
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
clarify_processor.run_post_training_bias(
Expand Down Expand Up @@ -368,7 +413,7 @@ def test_post_training_bias(
"initial_instance_count": 1,
},
}
mock_method.assert_called_once_with(
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
Expand All @@ -377,9 +422,35 @@ def test_post_training_bias(
None,
{"ExperimentName": "AnExperiment"},
)
clarify_processor_with_job_name_prefix.run_post_training_bias(
data_config,
data_bias_config,
model_config,
model_predicted_label_config,
wait=True,
experiment_config={"ExperimentName": "AnExperiment"},
)
name_from_base.assert_called_with(JOB_NAME_PREFIX)
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
True,
JOB_NAME,
None,
{"ExperimentName": "AnExperiment"},
)


def test_shap(clarify_processor, data_config, model_config, shap_config):
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
def test_shap(
name_from_base,
clarify_processor,
clarify_processor_with_job_name_prefix,
data_config,
model_config,
shap_config,
):
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
clarify_processor.run_explainability(
data_config,
Expand Down Expand Up @@ -420,7 +491,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
"initial_instance_count": 1,
},
}
mock_method.assert_called_once_with(
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
Expand All @@ -429,3 +500,21 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
None,
{"ExperimentName": "AnExperiment"},
)
clarify_processor_with_job_name_prefix.run_explainability(
data_config,
model_config,
shap_config,
model_scores=None,
wait=True,
experiment_config={"ExperimentName": "AnExperiment"},
)
name_from_base.assert_called_with(JOB_NAME_PREFIX)
mock_method.assert_called_with(
data_config,
expected_analysis_config,
True,
True,
JOB_NAME,
None,
{"ExperimentName": "AnExperiment"},
)