Skip to content

Commit eb13b1f

Browse files
committed
feat: add job_name_prefix
1 parent c2fbe75 commit eb13b1f

File tree

2 files changed

+114
-10
lines changed

2 files changed

+114
-10
lines changed

src/sagemaker/clarify.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(
350350
env=None,
351351
tags=None,
352352
network_config=None,
353+
job_name_prefix=None,
353354
version=None,
354355
):
355356
"""Initializes a ``Processor`` instance, computing bias metrics and model explanations.
@@ -384,9 +385,11 @@ def __init__(
384385
A :class:`~sagemaker.network.NetworkConfig`
385386
object that configures network isolation, encryption of
386387
inter-container traffic, security group IDs, and subnets.
388+
job_name_prefix (str): Processing job name prefix.
387389
version (str): Clarify version want to be used.
388390
"""
389391
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
392+
self.job_name_prefix = job_name_prefix
390393
super(SageMakerClarifyProcessor, self).__init__(
391394
role,
392395
container_uri,
@@ -517,7 +520,10 @@ def run_pre_training_bias(
517520
analysis_config.update(data_bias_config.get_config())
518521
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
519522
if job_name is None:
520-
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
523+
if self.job_name_prefix:
524+
job_name = utils.name_from_base(self.job_name_prefix)
525+
else:
526+
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
521527
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
522528

523529
def run_post_training_bias(
@@ -573,7 +579,10 @@ def run_post_training_bias(
573579
analysis_config["predictor"] = predictor_config
574580
_set(probability_threshold, "probability_threshold", analysis_config)
575581
if job_name is None:
576-
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
582+
if self.job_name_prefix:
583+
job_name = utils.name_from_base(self.job_name_prefix)
584+
else:
585+
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
577586
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
578587

579588
def run_bias(
@@ -641,7 +650,10 @@ def run_bias(
641650
"post_training_bias": {"methods": post_training_methods},
642651
}
643652
if job_name is None:
644-
job_name = utils.name_from_base("Clarify-Bias")
653+
if self.job_name_prefix:
654+
job_name = utils.name_from_base(self.job_name_prefix)
655+
else:
656+
job_name = utils.name_from_base("Clarify-Bias")
645657
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
646658

647659
def run_explainability(
@@ -693,7 +705,10 @@ def run_explainability(
693705
analysis_config["methods"] = explainability_config.get_explainability_config()
694706
analysis_config["predictor"] = predictor_config
695707
if job_name is None:
696-
job_name = utils.name_from_base("Clarify-Explainability")
708+
if self.job_name_prefix:
709+
job_name = utils.name_from_base(self.job_name_prefix)
710+
else:
711+
job_name = utils.name_from_base("Clarify-Explainability")
697712
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
698713

699714

tests/unit/test_clarify.py

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
)
2727
from sagemaker import image_uris
2828

29+
JOB_NAME_PREFIX = "my-prefix"
30+
TIMESTAMP = "2021-06-17-22-29-54-685"
31+
JOB_NAME = "{}-{}".format(JOB_NAME_PREFIX, TIMESTAMP)
32+
2933

3034
def test_uri():
3135
uri = image_uris.retrieve("clarify", "us-west-2")
@@ -248,6 +252,17 @@ def clarify_processor(sagemaker_session):
248252
)
249253

250254

255+
@pytest.fixture(scope="module")
256+
def clarify_processor_with_job_name_prefix(sagemaker_session):
257+
return SageMakerClarifyProcessor(
258+
role="AmazonSageMaker-ExecutionRole",
259+
instance_count=1,
260+
instance_type="ml.c5.xlarge",
261+
sagemaker_session=sagemaker_session,
262+
job_name_prefix=JOB_NAME_PREFIX,
263+
)
264+
265+
251266
@pytest.fixture(scope="module")
252267
def data_config():
253268
return DataConfig(
@@ -302,7 +317,14 @@ def shap_config():
302317
)
303318

304319

305-
def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
320+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
321+
def test_pre_training_bias(
322+
name_from_base,
323+
clarify_processor,
324+
clarify_processor_with_job_name_prefix,
325+
data_config,
326+
data_bias_config,
327+
):
306328
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
307329
clarify_processor.run_pre_training_bias(
308330
data_config,
@@ -325,7 +347,7 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
325347
"group_variable": "F2",
326348
"methods": {"pre_training_bias": {"methods": "all"}},
327349
}
328-
mock_method.assert_called_once_with(
350+
mock_method.assert_called_with(
329351
data_config,
330352
expected_analysis_config,
331353
True,
@@ -334,10 +356,33 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config):
334356
None,
335357
{"ExperimentName": "AnExperiment"},
336358
)
359+
clarify_processor_with_job_name_prefix.run_pre_training_bias(
360+
data_config,
361+
data_bias_config,
362+
wait=True,
363+
experiment_config={"ExperimentName": "AnExperiment"},
364+
)
365+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
366+
mock_method.assert_called_with(
367+
data_config,
368+
expected_analysis_config,
369+
True,
370+
True,
371+
JOB_NAME,
372+
None,
373+
{"ExperimentName": "AnExperiment"},
374+
)
337375

338376

377+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
339378
def test_post_training_bias(
340-
clarify_processor, data_config, data_bias_config, model_config, model_predicted_label_config
379+
name_from_base,
380+
clarify_processor,
381+
clarify_processor_with_job_name_prefix,
382+
data_config,
383+
data_bias_config,
384+
model_config,
385+
model_predicted_label_config,
341386
):
342387
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
343388
clarify_processor.run_post_training_bias(
@@ -368,7 +413,7 @@ def test_post_training_bias(
368413
"initial_instance_count": 1,
369414
},
370415
}
371-
mock_method.assert_called_once_with(
416+
mock_method.assert_called_with(
372417
data_config,
373418
expected_analysis_config,
374419
True,
@@ -377,9 +422,35 @@ def test_post_training_bias(
377422
None,
378423
{"ExperimentName": "AnExperiment"},
379424
)
425+
clarify_processor_with_job_name_prefix.run_post_training_bias(
426+
data_config,
427+
data_bias_config,
428+
model_config,
429+
model_predicted_label_config,
430+
wait=True,
431+
experiment_config={"ExperimentName": "AnExperiment"},
432+
)
433+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
434+
mock_method.assert_called_with(
435+
data_config,
436+
expected_analysis_config,
437+
True,
438+
True,
439+
JOB_NAME,
440+
None,
441+
{"ExperimentName": "AnExperiment"},
442+
)
380443

381444

382-
def test_shap(clarify_processor, data_config, model_config, shap_config):
445+
@patch("sagemaker.utils.name_from_base", return_value=JOB_NAME)
446+
def test_shap(
447+
name_from_base,
448+
clarify_processor,
449+
clarify_processor_with_job_name_prefix,
450+
data_config,
451+
model_config,
452+
shap_config,
453+
):
383454
with patch.object(SageMakerClarifyProcessor, "_run", return_value=None) as mock_method:
384455
clarify_processor.run_explainability(
385456
data_config,
@@ -420,7 +491,7 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
420491
"initial_instance_count": 1,
421492
},
422493
}
423-
mock_method.assert_called_once_with(
494+
mock_method.assert_called_with(
424495
data_config,
425496
expected_analysis_config,
426497
True,
@@ -429,3 +500,21 @@ def test_shap(clarify_processor, data_config, model_config, shap_config):
429500
None,
430501
{"ExperimentName": "AnExperiment"},
431502
)
503+
clarify_processor_with_job_name_prefix.run_explainability(
504+
data_config,
505+
model_config,
506+
shap_config,
507+
model_scores=None,
508+
wait=True,
509+
experiment_config={"ExperimentName": "AnExperiment"},
510+
)
511+
name_from_base.assert_called_with(JOB_NAME_PREFIX)
512+
mock_method.assert_called_with(
513+
data_config,
514+
expected_analysis_config,
515+
True,
516+
True,
517+
JOB_NAME,
518+
None,
519+
{"ExperimentName": "AnExperiment"},
520+
)

0 commit comments

Comments
 (0)