diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 680a426912..e22a82b431 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -350,6 +350,7 @@ def __init__( env=None, tags=None, network_config=None, + job_name_prefix=None, version=None, ): """Initializes a ``Processor`` instance, computing bias metrics and model explanations. @@ -384,9 +385,11 @@ def __init__( A :class:`~sagemaker.network.NetworkConfig` object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. + job_name_prefix (str): Processing job name prefix. version (str): Clarify version want to be used. """ container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version) + self.job_name_prefix = job_name_prefix super(SageMakerClarifyProcessor, self).__init__( role, container_uri, @@ -500,13 +503,22 @@ def run_pre_training_bias( data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups. methods (str or list[str]): Selector of a subset of potential metrics: - ["CI", "DPL", "KL", "JS", "LP", "TVD", "KS", "CDDL"]. Defaults to computing all. - # TODO: Provide a pointer to the official documentation of those. + ["`CI `_", + "`DPL `_", + "`KL `_", + "`JS `_", + "`LP `_", + "`TVD `_", + "`KS `_", + "`CDDL `_"]. + Defaults to computing all. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. If not specified, a name is composed of - "Clarify-Pretraining-Bias" and current timestamp. + job_name (str): Processing job name. When ``job_name`` is not specified, if + ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name + will be composed of ``job_name_prefix`` and current timestamp; otherwise use + "Clarify-Pretraining-Bias" as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. @@ -517,7 +529,10 @@ def run_pre_training_bias( analysis_config.update(data_bias_config.get_config()) analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} if job_name is None: - job_name = utils.name_from_base("Clarify-Pretraining-Bias") + if self.job_name_prefix: + job_name = utils.name_from_base(self.job_name_prefix) + else: + job_name = utils.name_from_base("Clarify-Pretraining-Bias") self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config) def run_post_training_bias( @@ -548,14 +563,25 @@ def run_post_training_bias( model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`): Config of how to extract the predicted label from the model output. methods (str or list[str]): Selector of a subset of potential metrics: - # TODO: Provide a pointer to the official documentation of those. - ["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"]. + ["`DPPL `_" + , "`DI `_", + "`DCA `_", + "`DCR `_", + "`RD `_", + "`DAR `_", + "`DRR `_", + "`AD `_", + "`CDDPL `_ + ", "`TE `_", + "`FT `_"]. Defaults to computing all. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. If not specified, a name is composed of - "Clarify-Posttraining-Bias" and current timestamp. + job_name (str): Processing job name. When ``job_name`` is not specified, if + ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name + will be composed of ``job_name_prefix`` and current timestamp; otherwise use + "Clarify-Posttraining-Bias" as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. @@ -573,7 +599,10 @@ def run_post_training_bias( analysis_config["predictor"] = predictor_config _set(probability_threshold, "probability_threshold", analysis_config) if job_name is None: - job_name = utils.name_from_base("Clarify-Posttraining-Bias") + if self.job_name_prefix: + job_name = utils.name_from_base(self.job_name_prefix) + else: + job_name = utils.name_from_base("Clarify-Posttraining-Bias") self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config) def run_bias( @@ -605,18 +634,35 @@ def run_bias( model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`): Config of how to extract the predicted label from the model output. pre_training_methods (str or list[str]): Selector of a subset of potential metrics: - # TODO: Provide a pointer to the official documentation of those. - ["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"]. + ["`CI `_", + "`DPL `_", + "`KL `_", + "`JS `_", + "`LP `_", + "`TVD `_", + "`KS `_", + "`CDDL `_"]. Defaults to computing all. post_training_methods (str or list[str]): Selector of a subset of potential metrics: - # TODO: Provide a pointer to the official documentation of those. - ["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"]. + ["`DPPL `_" + , "`DI `_", + "`DCA `_", + "`DCR `_", + "`RD `_", + "`DAR `_", + "`DRR `_", + "`AD `_", + "`CDDPL `_ + ", "`TE `_", + "`FT `_"]. Defaults to computing all. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. If not specified, a name is composed of - "Clarify-Bias" and current timestamp. + job_name (str): Processing job name. When ``job_name`` is not specified, if + ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name + will be composed of ``job_name_prefix`` and current timestamp; otherwise use + "Clarify-Bias" as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. @@ -641,7 +687,10 @@ def run_bias( "post_training_bias": {"methods": post_training_methods}, } if job_name is None: - job_name = utils.name_from_base("Clarify-Bias") + if self.job_name_prefix: + job_name = utils.name_from_base(self.job_name_prefix) + else: + job_name = utils.name_from_base("Clarify-Bias") self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config) def run_explainability( @@ -679,8 +728,10 @@ def run_explainability( wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. If not specified, a name is composed of - "Clarify-Explainability" and current timestamp. + job_name (str): Processing job name. When ``job_name`` is not specified, if + ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name + will be composed of ``job_name_prefix`` and current timestamp; otherwise use + "Clarify-Explainability" as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. @@ -693,7 +744,10 @@ def run_explainability( analysis_config["methods"] = explainability_config.get_explainability_config() analysis_config["predictor"] = predictor_config if job_name is None: - job_name = utils.name_from_base("Clarify-Explainability") + if self.job_name_prefix: + job_name = utils.name_from_base(self.job_name_prefix) + else: + job_name = utils.name_from_base("Clarify-Explainability") self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config) diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 03004f2faa..9cf17bcf09 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -26,6 +26,10 @@ ) from sagemaker import image_uris +JOB_NAME_PREFIX = "my-prefix" +TIMESTAMP = "2021-06-17-22-29-54-685" +JOB_NAME = "{}-{}".format(JOB_NAME_PREFIX, TIMESTAMP) + def test_uri(): uri = image_uris.retrieve("clarify", "us-west-2") @@ -248,6 +252,17 @@ def clarify_processor(sagemaker_session): ) +@pytest.fixture(scope="module") +def clarify_processor_with_job_name_prefix(sagemaker_session): + return SageMakerClarifyProcessor( + role="AmazonSageMaker-ExecutionRole", + instance_count=1, + instance_type="ml.c5.xlarge", + sagemaker_session=sagemaker_session, + job_name_prefix=JOB_NAME_PREFIX, + ) + + @pytest.fixture(scope="module") def data_config(): return DataConfig( @@ -302,7 +317,14 @@ def shap_config(): ) -def test_pre_training_bias(clarify_processor, data_config, data_bias_config): +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) +def test_pre_training_bias( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + data_bias_config, +): with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method: clarify_processor.run_pre_training_bias( data_config, @@ -325,7 +347,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config): "group_variable": "F2", "methods": {"pre_training_bias": {"methods": "all"}}, } - mock_method.assert_called_once_with( + mock_method.assert_called_with( data_config, expected_analysis_config, True, @@ -334,10 +356,33 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config): None, {"ExperimentName": "AnExperiment"}, ) + clarify_processor_with_job_name_prefix.run_pre_training_bias( + data_config, + data_bias_config, + wait=True, + experiment_config={"ExperimentName": "AnExperiment"}, + ) + name_from_base.assert_called_with(JOB_NAME_PREFIX) + mock_method.assert_called_with( + data_config, + expected_analysis_config, + True, + True, + JOB_NAME, + None, + {"ExperimentName": "AnExperiment"}, + ) +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) def test_post_training_bias( - clarify_processor, data_config, data_bias_config, model_config, model_predicted_label_config + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + data_bias_config, + model_config, + model_predicted_label_config, ): with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method: clarify_processor.run_post_training_bias( @@ -368,7 +413,7 @@ def test_post_training_bias( "initial_instance_count": 1, }, } - mock_method.assert_called_once_with( + mock_method.assert_called_with( data_config, expected_analysis_config, True, @@ -377,9 +422,35 @@ def test_post_training_bias( None, {"ExperimentName": "AnExperiment"}, ) + clarify_processor_with_job_name_prefix.run_post_training_bias( + data_config, + data_bias_config, + model_config, + model_predicted_label_config, + wait=True, + experiment_config={"ExperimentName": "AnExperiment"}, + ) + name_from_base.assert_called_with(JOB_NAME_PREFIX) + mock_method.assert_called_with( + data_config, + expected_analysis_config, + True, + True, + JOB_NAME, + None, + {"ExperimentName": "AnExperiment"}, + ) -def test_shap(clarify_processor, data_config, model_config, shap_config): +@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) +def test_shap( + name_from_base, + clarify_processor, + clarify_processor_with_job_name_prefix, + data_config, + model_config, + shap_config, +): with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method: clarify_processor.run_explainability( data_config, @@ -420,7 +491,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config): "initial_instance_count": 1, }, } - mock_method.assert_called_once_with( + mock_method.assert_called_with( data_config, expected_analysis_config, True, @@ -429,3 +500,21 @@ def test_shap(clarify_processor, data_config, model_config, shap_config): None, {"ExperimentName": "AnExperiment"}, ) + clarify_processor_with_job_name_prefix.run_explainability( + data_config, + model_config, + shap_config, + model_scores=None, + wait=True, + experiment_config={"ExperimentName": "AnExperiment"}, + ) + name_from_base.assert_called_with(JOB_NAME_PREFIX) + mock_method.assert_called_with( + data_config, + expected_analysis_config, + True, + True, + JOB_NAME, + None, + {"ExperimentName": "AnExperiment"}, + )