Skip to content

Commit 22ba84e

Browse files
feature: support job_name_prefix for Clarify (#2471)
* feat: add job_name_prefix * feat: remove TODO links * fix: fix link Co-authored-by: icywang86rui <[email protected]>
1 parent 530d21b commit 22ba84e

File tree

2 files changed

+169
-26
lines changed

2 files changed

+169
-26
lines changed

src/sagemaker/clarify.py

+74-20
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(
350350
env=None,
351351
tags=None,
352352
network_config=None,
353+
job_name_prefix=None,
353354
version=None,
354355
):
355356
"""Initializes a ``Processor`` instance, computing bias metrics and model explanations.
@@ -384,9 +385,11 @@ def __init__(
384385
A :class:`~sagemaker.network.NetworkConfig`
385386
object that configures network isolation, encryption of
386387
inter-container traffic, security group IDs, and subnets.
388+
job_name_prefix (str): Processing job name prefix.
387389
version (str): Clarify version to be used.
388390
"""
389391
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
392+
self.job_name_prefix = job_name_prefix
390393
super(SageMakerClarifyProcessor, self).__init__(
391394
role,
392395
container_uri,
@@ -500,13 +503,22 @@ def run_pre_training_bias(
500503
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
501504
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
502505
methods (str or list[str]): Selector of a subset of potential metrics:
503-
["CI", "DPL", "KL", "JS", "LP", "TVD", "KS", "CDDL"]. Defaults to computing all.
504-
# TODO: Provide a pointer to the official documentation of those.
506+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
507+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
508+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
509+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
510+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
511+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
512+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
513+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
514+
Defaults to computing all.
505515
wait (bool): Whether the call should wait until the job completes (default: True).
506516
logs (bool): Whether to show the logs produced by the job.
507517
Only meaningful when ``wait`` is True (default: True).
508-
job_name (str): Processing job name. If not specified, a name is composed of
509-
"Clarify-Pretraining-Bias" and current timestamp.
518+
job_name (str): Processing job name. When ``job_name`` is not specified, if
519+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
520+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
521+
"Clarify-Pretraining-Bias" as prefix.
510522
kms_key (str): The ARN of the KMS key that is used to encrypt the
511523
user code file (default: None).
512524
experiment_config (dict[str, str]): Experiment management configuration.
@@ -517,7 +529,10 @@ def run_pre_training_bias(
517529
analysis_config.update(data_bias_config.get_config())
518530
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
519531
if job_name is None:
520-
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
532+
if self.job_name_prefix:
533+
job_name = utils.name_from_base(self.job_name_prefix)
534+
else:
535+
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
521536
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
522537

523538
def run_post_training_bias(
@@ -548,14 +563,25 @@ def run_post_training_bias(
548563
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
549564
Config of how to extract the predicted label from the model output.
550565
methods (str or list[str]): Selector of a subset of potential metrics:
551-
# TODO: Provide a pointer to the official documentation of those.
552-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
566+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
567+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
568+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
569+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
570+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
571+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
572+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
573+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
574+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
575+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
576+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
553577
Defaults to computing all.
554578
wait (bool): Whether the call should wait until the job completes (default: True).
555579
logs (bool): Whether to show the logs produced by the job.
556580
Only meaningful when ``wait`` is True (default: True).
557-
job_name (str): Processing job name. If not specified, a name is composed of
558-
"Clarify-Posttraining-Bias" and current timestamp.
581+
job_name (str): Processing job name. When ``job_name`` is not specified, if
582+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
583+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
584+
"Clarify-Posttraining-Bias" as prefix.
559585
kms_key (str): The ARN of the KMS key that is used to encrypt the
560586
user code file (default: None).
561587
experiment_config (dict[str, str]): Experiment management configuration.
@@ -573,7 +599,10 @@ def run_post_training_bias(
573599
analysis_config["predictor"] = predictor_config
574600
_set(probability_threshold, "probability_threshold", analysis_config)
575601
if job_name is None:
576-
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
602+
if self.job_name_prefix:
603+
job_name = utils.name_from_base(self.job_name_prefix)
604+
else:
605+
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
577606
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
578607

579608
def run_bias(
@@ -605,18 +634,35 @@ def run_bias(
605634
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
606635
Config of how to extract the predicted label from the model output.
607636
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
608-
# TODO: Provide a pointer to the official documentation of those.
609-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
637+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-bias-metric-class-imbalance.html>`_",
638+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-true-label-imbalance.html>`_",
639+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kl-divergence.html>`_",
640+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-jensen-shannon-divergence.html>`_",
641+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-lp-norm.html>`_",
642+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-total-variation-distance.html>`_",
643+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-kolmogorov-smirnov.html>`_",
644+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-data-bias-metric-cddl.html>`_"].
610645
Defaults to computing all.
611646
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
612-
# TODO: Provide a pointer to the official documentation of those.
613-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
647+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
648+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
649+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
650+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
651+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
652+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
653+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
654+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
655+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
656+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
657+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
614658
Defaults to computing all.
615659
wait (bool): Whether the call should wait until the job completes (default: True).
616660
logs (bool): Whether to show the logs produced by the job.
617661
Only meaningful when ``wait`` is True (default: True).
618-
job_name (str): Processing job name. If not specified, a name is composed of
619-
"Clarify-Bias" and current timestamp.
662+
job_name (str): Processing job name. When ``job_name`` is not specified, if
663+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
664+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
665+
"Clarify-Bias" as prefix.
620666
kms_key (str): The ARN of the KMS key that is used to encrypt the
621667
user code file (default: None).
622668
experiment_config (dict[str, str]): Experiment management configuration.
@@ -641,7 +687,10 @@ def run_bias(
641687
"post_training_bias": {"methods": post_training_methods},
642688
}
643689
if job_name is None:
644-
job_name = utils.name_from_base("Clarify-Bias")
690+
if self.job_name_prefix:
691+
job_name = utils.name_from_base(self.job_name_prefix)
692+
else:
693+
job_name = utils.name_from_base("Clarify-Bias")
645694
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
646695

647696
def run_explainability(
@@ -679,8 +728,10 @@ def run_explainability(
679728
wait (bool): Whether the call should wait until the job completes (default: True).
680729
logs (bool): Whether to show the logs produced by the job.
681730
Only meaningful when ``wait`` is True (default: True).
682-
job_name (str): Processing job name. If not specified, a name is composed of
683-
"Clarify-Explainability" and current timestamp.
731+
job_name (str): Processing job name. When ``job_name`` is not specified, if
732+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
733+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
734+
"Clarify-Explainability" as prefix.
684735
kms_key (str): The ARN of the KMS key that is used to encrypt the
685736
user code file (default: None).
686737
experiment_config (dict[str, str]): Experiment management configuration.
@@ -693,7 +744,10 @@ def run_explainability(
693744
analysis_config["methods"] = explainability_config.get_explainability_config()
694745
analysis_config["predictor"] = predictor_config
695746
if job_name is None:
696-
job_name = utils.name_from_base("Clarify-Explainability")
747+
if self.job_name_prefix:
748+
job_name = utils.name_from_base(self.job_name_prefix)
749+
else:
750+
job_name = utils.name_from_base("Clarify-Explainability")
697751
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
698752

699753

tests/unit/test_clarify.py

+95-6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
)
2727
from sagemaker import image_uris
2828

29+
JOB_NAME_PREFIX = "my-prefix"
30+
TIMESTAMP = "2021-06-17-22-29-54-685"
31+
JOB_NAME = "{}-{}".format(JOB_NAME_PREFIX, TIMESTAMP)
32+
2933

3034
def test_uri():
3135
uri = image_uris.retrieve("clarify", "us-west-2")
@@ -248,6 +252,17 @@ def clarify_processor(sagemaker_session):
248252
)
249253

250254

255+
@pytest.fixture(scope="module")
256+
def clarify_processor_with_job_name_prefix(sagemaker_session):
257+
return SageMakerClarifyProcessor(
258+
role="AmazonSageMaker-ExecutionRole",
259+
instance_count=1,
260+
instance_type="ml.c5.xlarge",
261+
sagemaker_session=sagemaker_session,
262+
job_name_prefix=JOB_NAME_PREFIX,
263+
)
264+
265+
251266
@pytest.fixture(scope="module")
252267
def data_config():
253268
return DataConfig(
@@ -302,7 +317,14 @@ def shap_config():
302317
)
303318

304319

305-
def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
320+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
321+
def test_pre_training_bias(
322+
name_from_base,
323+
clarify_processor,
324+
clarify_processor_with_job_name_prefix,
325+
data_config,
326+
data_bias_config,
327+
):
306328
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
307329
clarify_processor.run_pre_training_bias(
308330
data_config,
@@ -325,7 +347,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
325347
"group_variable": "F2",
326348
"methods": {"pre_training_bias": {"methods": "all"}},
327349
}
328-
mock_method.assert_called_once_with(
350+
mock_method.assert_called_with(
329351
data_config,
330352
expected_analysis_config,
331353
True,
@@ -334,10 +356,33 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
334356
None,
335357
{"ExperimentName": "AnExperiment"},
336358
)
359+
clarify_processor_with_job_name_prefix.run_pre_training_bias(
360+
data_config,
361+
data_bias_config,
362+
wait=True,
363+
experiment_config={"ExperimentName": "AnExperiment"},
364+
)
365+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
366+
mock_method.assert_called_with(
367+
data_config,
368+
expected_analysis_config,
369+
True,
370+
True,
371+
JOB_NAME,
372+
None,
373+
{"ExperimentName": "AnExperiment"},
374+
)
337375

338376

377+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
339378
def test_post_training_bias(
340-
clarify_processor, data_config, data_bias_config, model_config, model_predicted_label_config
379+
name_from_base,
380+
clarify_processor,
381+
clarify_processor_with_job_name_prefix,
382+
data_config,
383+
data_bias_config,
384+
model_config,
385+
model_predicted_label_config,
341386
):
342387
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
343388
clarify_processor.run_post_training_bias(
@@ -368,7 +413,7 @@ def test_post_training_bias(
368413
"initial_instance_count": 1,
369414
},
370415
}
371-
mock_method.assert_called_once_with(
416+
mock_method.assert_called_with(
372417
data_config,
373418
expected_analysis_config,
374419
True,
@@ -377,9 +422,35 @@ def test_post_training_bias(
377422
None,
378423
{"ExperimentName": "AnExperiment"},
379424
)
425+
clarify_processor_with_job_name_prefix.run_post_training_bias(
426+
data_config,
427+
data_bias_config,
428+
model_config,
429+
model_predicted_label_config,
430+
wait=True,
431+
experiment_config={"ExperimentName": "AnExperiment"},
432+
)
433+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
434+
mock_method.assert_called_with(
435+
data_config,
436+
expected_analysis_config,
437+
True,
438+
True,
439+
JOB_NAME,
440+
None,
441+
{"ExperimentName": "AnExperiment"},
442+
)
380443

381444

382-
def test_shap(clarify_processor, data_config, model_config, shap_config):
445+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
446+
def test_shap(
447+
name_from_base,
448+
clarify_processor,
449+
clarify_processor_with_job_name_prefix,
450+
data_config,
451+
model_config,
452+
shap_config,
453+
):
383454
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
384455
clarify_processor.run_explainability(
385456
data_config,
@@ -420,7 +491,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
420491
"initial_instance_count": 1,
421492
},
422493
}
423-
mock_method.assert_called_once_with(
494+
mock_method.assert_called_with(
424495
data_config,
425496
expected_analysis_config,
426497
True,
@@ -429,3 +500,21 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
429500
None,
430501
{"ExperimentName": "AnExperiment"},
431502
)
503+
clarify_processor_with_job_name_prefix.run_explainability(
504+
data_config,
505+
model_config,
506+
shap_config,
507+
model_scores=None,
508+
wait=True,
509+
experiment_config={"ExperimentName": "AnExperiment"},
510+
)
511+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
512+
mock_method.assert_called_with(
513+
data_config,
514+
expected_analysis_config,
515+
True,
516+
True,
517+
JOB_NAME,
518+
None,
519+
{"ExperimentName": "AnExperiment"},
520+
)

0 commit comments

Comments
 (0)