Skip to content

Commit 5e27a41

Browse files
Merge branch 'master' into fix-issue-2426
2 parents 5635282 + 22ba84e commit 5e27a41

File tree

9 files changed

+271
-47
lines changed

9 files changed

+271
-47
lines changed

src/sagemaker/clarify.py

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def __init__(
350350
env=None,
351351
tags=None,
352352
network_config=None,
353+
job_name_prefix=None,
353354
version=None,
354355
):
355356
"""Initializes a ``Processor`` instance, computing bias metrics and model explanations.
@@ -384,9 +385,11 @@ def __init__(
384385
A :class:`~sagemaker.network.NetworkConfig`
385386
object that configures network isolation, encryption of
386387
inter-container traffic, security group IDs, and subnets.
388+
job_name_prefix (str): Processing job name prefix.
387389
version (str): The Clarify version to use.
388390
"""
389391
container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version)
392+
self.job_name_prefix = job_name_prefix
390393
super(SageMakerClarifyProcessor, self).__init__(
391394
role,
392395
container_uri,
@@ -500,13 +503,22 @@ def run_pre_training_bias(
500503
data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data.
501504
data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups.
502505
methods (str or list[str]): Selector of a subset of potential metrics:
503-
["CI", "DPL", "KL", "JS", "LP", "TVD", "KS", "CDDL"]. Defaults to computing all.
504-
# TODO: Provide a pointer to the official documentation of those.
506+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
507+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
508+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
509+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
510+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
511+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
512+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
513+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
514+
Defaults to computing all.
505515
wait (bool): Whether the call should wait until the job completes (default: True).
506516
logs (bool): Whether to show the logs produced by the job.
507517
Only meaningful when ``wait`` is True (default: True).
508-
job_name (str): Processing job name. If not specified, a name is composed of
509-
"Clarify-Pretraining-Bias" and current timestamp.
518+
job_name (str): Processing job name. When ``job_name`` is not specified, if
519+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
520+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
521+
"Clarify-Pretraining-Bias" as prefix.
510522
kms_key (str): The ARN of the KMS key that is used to encrypt the
511523
user code file (default: None).
512524
experiment_config (dict[str, str]): Experiment management configuration.
@@ -517,7 +529,10 @@ def run_pre_training_bias(
517529
analysis_config.update(data_bias_config.get_config())
518530
analysis_config["methods"] = {"pre_training_bias": {"methods": methods}}
519531
if job_name is None:
520-
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
532+
if self.job_name_prefix:
533+
job_name = utils.name_from_base(self.job_name_prefix)
534+
else:
535+
job_name = utils.name_from_base("Clarify-Pretraining-Bias")
521536
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
522537

523538
def run_post_training_bias(
@@ -548,14 +563,25 @@ def run_post_training_bias(
548563
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
549564
Config of how to extract the predicted label from the model output.
550565
methods (str or list[str]): Selector of a subset of potential metrics:
551-
# TODO: Provide a pointer to the official documentation of those.
552-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
566+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
567+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
568+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
569+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
570+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
571+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
572+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
573+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
574+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
575+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
576+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
553577
Defaults to computing all.
554578
wait (bool): Whether the call should wait until the job completes (default: True).
555579
logs (bool): Whether to show the logs produced by the job.
556580
Only meaningful when ``wait`` is True (default: True).
557-
job_name (str): Processing job name. If not specified, a name is composed of
558-
"Clarify-Posttraining-Bias" and current timestamp.
581+
job_name (str): Processing job name. When ``job_name`` is not specified, if
582+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
583+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
584+
"Clarify-Posttraining-Bias" as prefix.
559585
kms_key (str): The ARN of the KMS key that is used to encrypt the
560586
user code file (default: None).
561587
experiment_config (dict[str, str]): Experiment management configuration.
@@ -573,7 +599,10 @@ def run_post_training_bias(
573599
analysis_config["predictor"] = predictor_config
574600
_set(probability_threshold, "probability_threshold", analysis_config)
575601
if job_name is None:
576-
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
602+
if self.job_name_prefix:
603+
job_name = utils.name_from_base(self.job_name_prefix)
604+
else:
605+
job_name = utils.name_from_base("Clarify-Posttraining-Bias")
577606
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
578607

579608
def run_bias(
@@ -605,18 +634,35 @@ def run_bias(
605634
model_predicted_label_config (:class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
606635
Config of how to extract the predicted label from the model output.
607636
pre_training_methods (str or list[str]): Selector of a subset of potential metrics:
608-
# TODO: Provide a pointer to the official documentation of those.
609-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
637+
["`CI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ci.html>`_",
638+
"`DPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dpl.html>`_",
639+
"`KL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-kl.html>`_",
640+
"`JS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-js.html>`_",
641+
"`LP <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-lp.html>`_",
642+
"`TVD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-tvd.html>`_",
643+
"`KS <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ks.html>`_",
644+
"`CDDL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cdd.html>`_"].
610645
Defaults to computing all.
611646
post_training_methods (str or list[str]): Selector of a subset of potential metrics:
612-
# TODO: Provide a pointer to the official documentation of those.
613-
["DPPL", "DI", "DCA", "DCR", "RD", "DAR", "DRR", "AD", "CDDPL", "TE", "FT"].
647+
["`DPPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dppl.html>`_"
648+
, "`DI <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-di.html>`_",
649+
"`DCA <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dca.html>`_",
650+
"`DCR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dcr.html>`_",
651+
"`RD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-rd.html>`_",
652+
"`DAR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-dar.html>`_",
653+
"`DRR <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-drr.html>`_",
654+
"`AD <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ad.html>`_",
655+
"`CDDPL <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-cddpl.html>`_
656+
", "`TE <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-te.html>`_",
657+
"`FT <https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-post-training-bias-metric-ft.html>`_"].
614658
Defaults to computing all.
615659
wait (bool): Whether the call should wait until the job completes (default: True).
616660
logs (bool): Whether to show the logs produced by the job.
617661
Only meaningful when ``wait`` is True (default: True).
618-
job_name (str): Processing job name. If not specified, a name is composed of
619-
"Clarify-Bias" and current timestamp.
662+
job_name (str): Processing job name. When ``job_name`` is not specified, if
663+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
664+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
665+
"Clarify-Bias" as prefix.
620666
kms_key (str): The ARN of the KMS key that is used to encrypt the
621667
user code file (default: None).
622668
experiment_config (dict[str, str]): Experiment management configuration.
@@ -641,7 +687,10 @@ def run_bias(
641687
"post_training_bias": {"methods": post_training_methods},
642688
}
643689
if job_name is None:
644-
job_name = utils.name_from_base("Clarify-Bias")
690+
if self.job_name_prefix:
691+
job_name = utils.name_from_base(self.job_name_prefix)
692+
else:
693+
job_name = utils.name_from_base("Clarify-Bias")
645694
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
646695

647696
def run_explainability(
@@ -679,8 +728,10 @@ def run_explainability(
679728
wait (bool): Whether the call should wait until the job completes (default: True).
680729
logs (bool): Whether to show the logs produced by the job.
681730
Only meaningful when ``wait`` is True (default: True).
682-
job_name (str): Processing job name. If not specified, a name is composed of
683-
"Clarify-Explainability" and current timestamp.
731+
job_name (str): Processing job name. When ``job_name`` is not specified, if
732+
``job_name_prefix`` in :class:`SageMakerClarifyProcessor` is specified, the job name
733+
will be composed of ``job_name_prefix`` and current timestamp; otherwise use
734+
"Clarify-Explainability" as prefix.
684735
kms_key (str): The ARN of the KMS key that is used to encrypt the
685736
user code file (default: None).
686737
experiment_config (dict[str, str]): Experiment management configuration.
@@ -693,7 +744,10 @@ def run_explainability(
693744
analysis_config["methods"] = explainability_config.get_explainability_config()
694745
analysis_config["predictor"] = predictor_config
695746
if job_name is None:
696-
job_name = utils.name_from_base("Clarify-Explainability")
747+
if self.job_name_prefix:
748+
job_name = utils.name_from_base(self.job_name_prefix)
749+
else:
750+
job_name = utils.name_from_base("Clarify-Explainability")
697751
self._run(data_config, analysis_config, wait, logs, job_name, kms_key, experiment_config)
698752

699753

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,8 @@ def register(
10071007
if compile_model_family is not None:
10081008
model = self._compiled_models[compile_model_family]
10091009
else:
1010+
if "model_kms_key" not in kwargs:
1011+
kwargs["model_kms_key"] = self.output_kms_key
10101012
model = self.create_model(image_uri=image_uri, **kwargs)
10111013
model.name = model_name
10121014
return model.register(

src/sagemaker/workflow/_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
source_dir: str = None,
6161
dependencies: List = None,
6262
depends_on: List[str] = None,
63+
**kwargs,
6364
):
6465
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6566
@@ -98,6 +99,7 @@ def __init__(
9899
"inference_script": self._entry_point_basename,
99100
"model_archive": self._model_archive,
100101
},
102+
**kwargs,
101103
)
102104
repacker.disable_profiler = True
103105
inputs = TrainingInput(self._model_prefix)

src/sagemaker/workflow/callback_step.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,20 @@ def to_request(self) -> RequestType:
5858
"OutputType": self.output_type.value,
5959
}
6060

61-
@property
62-
def expr(self) -> Dict[str, str]:
63-
"""The 'Get' expression dict for a `Parameter`."""
64-
return CallbackOutput._expr(self.output_name)
61+
def expr(self, step_name) -> Dict[str, str]:
62+
"""The 'Get' expression dict for a `CallbackOutput`."""
63+
return CallbackOutput._expr(self.output_name, step_name)
6564

6665
@classmethod
67-
def _expr(cls, name):
66+
def _expr(cls, name, step_name):
6867
"""An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.
6968
7069
Args:
7170
name (str): The name of the callback output.
71+
step_name (str): The name of the callback step that this
72+
output belongs to.
7273
"""
73-
return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
74+
return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"}
7475

7576

7677
class CallbackStep(Step):

src/sagemaker/workflow/pipeline.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27-
from sagemaker.workflow.callback_step import CallbackOutput
27+
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
2828
from sagemaker.workflow.entities import (
2929
Entity,
3030
Expression,
@@ -240,9 +240,12 @@ def definition(self) -> str:
240240
"""Converts a request structure to string representation for workflow service calls."""
241241
request_dict = self.to_request()
242242
request_dict["PipelineExperimentConfig"] = interpolate(
243-
request_dict["PipelineExperimentConfig"]
243+
request_dict["PipelineExperimentConfig"], {}
244+
)
245+
callback_output_to_step_map = _map_callback_outputs(self.steps)
246+
request_dict["Steps"] = interpolate(
247+
request_dict["Steps"], callback_output_to_step_map=callback_output_to_step_map
244248
)
245-
request_dict["Steps"] = interpolate(request_dict["Steps"])
246249

247250
return json.dumps(request_dict)
248251

@@ -263,38 +266,62 @@ def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
263266
return [{"Name": name, "Value": str(value)} for name, value in parameters.items()]
264267

265268

266-
def interpolate(request_obj: RequestType) -> RequestType:
269+
def interpolate(
270+
request_obj: RequestType, callback_output_to_step_map: Dict[str, str]
271+
) -> RequestType:
267272
"""Replaces Parameter values in a list of nested Dict[str, Any] with their workflow expression.
268273
269274
Args:
270275
request_obj (RequestType): The request dict.
276+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
271277
272278
Returns:
273279
RequestType: The request dict with Parameter values replaced by their expression.
274280
"""
275281
request_obj_copy = deepcopy(request_obj)
276-
return _interpolate(request_obj_copy)
282+
return _interpolate(request_obj_copy, callback_output_to_step_map=callback_output_to_step_map)
277283

278284

279-
def _interpolate(obj: Union[RequestType, Any]):
285+
def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict[str, str]):
280286
"""Walks the nested request dict, replacing Parameter type values with workflow expressions.
281287
282288
Args:
283289
obj (Union[RequestType, Any]): The request dict.
290+
callback_output_to_step_map (Dict[str, str]): A dict of output name -> step name.
284291
"""
285-
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
292+
if isinstance(obj, (Expression, Parameter, Properties)):
286293
return obj.expr
294+
if isinstance(obj, CallbackOutput):
295+
step_name = callback_output_to_step_map[obj.output_name]
296+
return obj.expr(step_name)
287297
if isinstance(obj, dict):
288298
new = obj.__class__()
289299
for key, value in obj.items():
290-
new[key] = interpolate(value)
300+
new[key] = interpolate(value, callback_output_to_step_map)
291301
elif isinstance(obj, (list, set, tuple)):
292-
new = obj.__class__(interpolate(value) for value in obj)
302+
new = obj.__class__(interpolate(value, callback_output_to_step_map) for value in obj)
293303
else:
294304
return obj
295305
return new
296306

297307

308+
def _map_callback_outputs(steps: List[Step]):
309+
"""Iterate over the provided steps, building a map of callback output parameters to step names.
310+
311+
Args:
312+
steps (List[Step]): The steps list.
313+
"""
314+
315+
callback_output_map = {}
316+
for step in steps:
317+
if isinstance(step, CallbackStep):
318+
if step.outputs:
319+
for output in step.outputs:
320+
callback_output_map[output.output_name] = step.name
321+
322+
return callback_output_map
323+
324+
298325
def update_args(args: Dict[str, Any], **kwargs):
299326
"""Updates the request arguments dict with a value, if populated.
300327

0 commit comments

Comments
 (0)