Skip to content

Commit e87e988

Browse files
jayatalrahsan-z-khanshreyapandit
authored
feature: support displayName and description for pipeline steps (#2580)
Co-authored-by: Ahsan Khan <[email protected]> Co-authored-by: Shreya Pandit <[email protected]>
1 parent 5d89f62 commit e87e988

12 files changed

+173
-25
lines changed

src/sagemaker/workflow/_utils.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def __init__(
5555
role,
5656
model_data: str,
5757
entry_point: str,
58+
display_name: str = None,
59+
description: str = None,
5860
source_dir: str = None,
5961
dependencies: List = None,
6062
depends_on: Union[List[str], List[Step]] = None,
@@ -165,7 +167,12 @@ def __init__(
165167

166168
# super!
167169
super(_RepackModelStep, self).__init__(
168-
name=name, depends_on=depends_on, estimator=repacker, inputs=inputs
170+
name=name,
171+
display_name=display_name,
172+
description=description,
173+
depends_on=depends_on,
174+
estimator=repacker,
175+
inputs=inputs,
169176
)
170177

171178
def _prepare_for_repacking(self):
@@ -285,6 +292,7 @@ def __init__(
285292
approval_status="PendingManualApproval",
286293
image_uri=None,
287294
compile_model_family=None,
295+
display_name: str = None,
288296
description=None,
289297
depends_on: Union[List[str], List[Step]] = None,
290298
tags=None,
@@ -326,7 +334,9 @@ def __init__(
326334
this step depends on
327335
**kwargs: additional arguments to `create_model`.
328336
"""
329-
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
337+
super(_RegisterModelStep, self).__init__(
338+
name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on
339+
)
330340
self.estimator = estimator
331341
self.model_data = model_data
332342
self.content_types = content_types

src/sagemaker/workflow/callback_step.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(
8383
sqs_queue_url: str,
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
86+
display_name: str = None,
87+
description: str = None,
8688
cache_config: CacheConfig = None,
8789
depends_on: Union[List[str], List[Step]] = None,
8890
):
@@ -94,11 +96,15 @@ def __init__(
9496
inputs (dict): Input arguments that will be provided
9597
in the SQS message body of callback messages.
9698
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
99+
display_name (str): The display name of the callback step.
100+
description (str): The description of the callback step.
97101
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98102
depends_on (List[str] or List[Step]): A list of step names or step instances
99103
this `sagemaker.workflow.steps.CallbackStep` depends on
100104
"""
101-
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
105+
super(CallbackStep, self).__init__(
106+
name, display_name, description, StepTypeEnum.CALLBACK, depends_on
107+
)
102108
self.sqs_queue_url = sqs_queue_url
103109
self.outputs = outputs
104110
self.cache_config = cache_config

src/sagemaker/workflow/condition_step.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def __init__(
4242
self,
4343
name: str,
4444
depends_on: Union[List[str], List[Step]] = None,
45+
display_name: str = None,
46+
description: str = None,
4547
conditions: List[Condition] = None,
4648
if_steps: List[Union[Step, StepCollection]] = None,
4749
else_steps: List[Union[Step, StepCollection]] = None,
@@ -53,6 +55,9 @@ def __init__(
5355
execution.
5456
5557
Args:
58+
name (str): The name of the condition step.
59+
display_name (str): The display name of the condition step.
60+
description (str): The description of the condition step.
5661
conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition`
5762
instances.
5863
if_steps (List[Union[Step, StepCollection]]): A list of `sagemaker.workflow.steps.Step`
@@ -62,7 +67,9 @@ def __init__(
6267
or `sagemaker.workflow.step_collections.StepCollection` instances that are
6368
marked as ready for execution if the list of conditions evaluates to False.
6469
"""
65-
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION, depends_on)
70+
super(ConditionStep, self).__init__(
71+
name, display_name, description, StepTypeEnum.CONDITION, depends_on
72+
)
6673
self.conditions = conditions or []
6774
self.if_steps = if_steps or []
6875
self.else_steps = else_steps or []

src/sagemaker/workflow/lambda_step.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
self,
8383
name: str,
8484
lambda_func: Lambda,
85+
display_name: str = None,
86+
description: str = None,
8587
inputs: dict = None,
8688
outputs: List[LambdaOutput] = None,
8789
cache_config: CacheConfig = None,
@@ -91,6 +93,8 @@ def __init__(
9193
9294
Args:
9395
name (str): The name of the lambda step.
96+
display_name (str): The display name of the Lambda step.
97+
description (str): The description of the Lambda step.
9498
lambda_func (str): An instance of sagemaker.lambda_helper.Lambda.
9599
If lambda arn is specified in the instance, LambdaStep just invokes the function,
96100
else lambda function will be created while creating the pipeline.
@@ -101,7 +105,9 @@ def __init__(
101105
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep`
102106
depends on
103107
"""
104-
super(LambdaStep, self).__init__(name, StepTypeEnum.LAMBDA, depends_on)
108+
super(LambdaStep, self).__init__(
109+
name, display_name, description, StepTypeEnum.LAMBDA, depends_on
110+
)
105111
self.lambda_func = lambda_func
106112
self.outputs = outputs if outputs is not None else []
107113
self.cache_config = cache_config

src/sagemaker/workflow/step_collections.py

+14
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
approval_status=None,
6868
image_uri=None,
6969
compile_model_family=None,
70+
display_name=None,
7071
description=None,
7172
tags=None,
7273
model: Union[Model, PipelineModel] = None,
@@ -138,6 +139,8 @@ def __init__(
138139
tags=tags,
139140
subnets=subnets,
140141
security_group_ids=security_group_ids,
142+
description=description,
143+
display_name=display_name,
141144
**kwargs,
142145
)
143146
steps.append(repack_model_step)
@@ -179,6 +182,8 @@ def __init__(
179182
tags=tags,
180183
subnets=subnets,
181184
security_group_ids=security_group_ids,
185+
description=description,
186+
display_name=display_name,
182187
**kwargs,
183188
)
184189
steps.append(repack_model_step)
@@ -208,6 +213,7 @@ def __init__(
208213
image_uri=image_uri,
209214
compile_model_family=compile_model_family,
210215
description=description,
216+
display_name=display_name,
211217
tags=tags,
212218
container_def_list=self.container_def_list,
213219
**kwargs,
@@ -231,6 +237,8 @@ def __init__(
231237
instance_count,
232238
instance_type,
233239
transform_inputs,
240+
description: str = None,
241+
display_name: str = None,
234242
# model arguments
235243
image_uri=None,
236244
predictor_cls=None,
@@ -302,6 +310,8 @@ def __init__(
302310
tags=tags,
303311
subnets=estimator.subnets,
304312
security_group_ids=estimator.security_group_ids,
313+
description=description,
314+
display_name=display_name,
305315
)
306316
steps.append(repack_model_step)
307317
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -324,6 +334,8 @@ def predict_wrapper(endpoint, session):
324334
name=f"{name}CreateModelStep",
325335
model=model,
326336
inputs=model_inputs,
337+
description=description,
338+
display_name=display_name,
327339
)
328340
if "entry_point" not in kwargs and depends_on:
329341
# if the CreateModelStep is the first step in the collection
@@ -351,6 +363,8 @@ def predict_wrapper(endpoint, session):
351363
name=f"{name}TransformStep",
352364
transformer=transformer,
353365
inputs=transform_inputs,
366+
description=description,
367+
display_name=display_name,
354368
)
355369
steps.append(transform_step)
356370

src/sagemaker/workflow/steps.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,16 @@ class Step(Entity):
6363
6464
Attributes:
6565
name (str): The name of the step.
66+
display_name (str): The display name of the step.
67+
description (str): The description of the step.
6668
step_type (StepTypeEnum): The type of the step.
6769
depends_on (List[str] or List[Step]): The list of step names or step
6870
instances the current step depends on
6971
"""
7072

7173
name: str = attr.ib(factory=str)
74+
display_name: str = attr.ib(default=None)
75+
description: str = attr.ib(default=None)
7276
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
7377
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
7478

@@ -91,7 +95,10 @@ def to_request(self) -> RequestType:
9195
}
9296
if self.depends_on:
9397
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
94-
98+
if self.display_name:
99+
request_dict["DisplayName"] = self.display_name
100+
if self.description:
101+
request_dict["Description"] = self.description
95102
return request_dict
96103

97104
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
@@ -168,6 +175,8 @@ def __init__(
168175
self,
169176
name: str,
170177
estimator: EstimatorBase,
178+
display_name: str = None,
179+
description: str = None,
171180
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
172181
cache_config: CacheConfig = None,
173182
depends_on: Union[List[str], List[Step]] = None,
@@ -180,6 +189,8 @@ def __init__(
180189
Args:
181190
name (str): The name of the training step.
182191
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
192+
display_name (str): The display name of the training step.
193+
description (str): The description of the training step.
183194
inputs (Union[str, dict, TrainingInput, FileSystemInput]): Information
184195
about the training data. This can be one of three types:
185196
@@ -200,7 +211,9 @@ def __init__(
200211
depends_on (List[str] or List[Step]): A list of step names or step instances
201212
this `sagemaker.workflow.steps.TrainingStep` depends on
202213
"""
203-
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
214+
super(TrainingStep, self).__init__(
215+
name, display_name, description, StepTypeEnum.TRAINING, depends_on
216+
)
204217
self.estimator = estimator
205218
self.inputs = inputs
206219
self._properties = Properties(
@@ -248,6 +261,8 @@ def __init__(
248261
model: Model,
249262
inputs: CreateModelInput,
250263
depends_on: Union[List[str], List[Step]] = None,
264+
display_name: str = None,
265+
description: str = None,
251266
):
252267
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
253268
@@ -261,8 +276,12 @@ def __init__(
261276
Defaults to `None`.
262277
depends_on (List[str] or List[Step]): A list of step names or step instances
263278
this `sagemaker.workflow.steps.CreateModelStep` depends on
279+
display_name (str): The display name of the CreateModel step.
280+
description (str): The description of the CreateModel step.
264281
"""
265-
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
282+
super(CreateModelStep, self).__init__(
283+
name, display_name, description, StepTypeEnum.CREATE_MODEL, depends_on
284+
)
266285
self.model = model
267286
self.inputs = inputs or CreateModelInput()
268287

@@ -304,6 +323,8 @@ def __init__(
304323
name: str,
305324
transformer: Transformer,
306325
inputs: TransformInput,
326+
display_name: str = None,
327+
description: str = None,
307328
cache_config: CacheConfig = None,
308329
depends_on: Union[List[str], List[Step]] = None,
309330
):
@@ -317,10 +338,14 @@ def __init__(
317338
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
318339
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
319340
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
320-
depends_on (List[str] or List[Step]): A list of step names or step instances
321-
this `sagemaker.workflow.steps.TransformStep` depends on
341+
display_name (str): The display name of the transform step.
342+
description (str): The description of the transform step.
343+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
344+
depends on
322345
"""
323-
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
346+
super(TransformStep, self).__init__(
347+
name, display_name, description, StepTypeEnum.TRANSFORM, depends_on
348+
)
324349
self.transformer = transformer
325350
self.inputs = inputs
326351
self.cache_config = cache_config
@@ -375,6 +400,8 @@ def __init__(
375400
self,
376401
name: str,
377402
processor: Processor,
403+
display_name: str = None,
404+
description: str = None,
378405
inputs: List[ProcessingInput] = None,
379406
outputs: List[ProcessingOutput] = None,
380407
job_arguments: List[str] = None,
@@ -391,6 +418,8 @@ def __init__(
391418
Args:
392419
name (str): The name of the processing step.
393420
processor (Processor): A `sagemaker.processing.Processor` instance.
421+
display_name (str): The display name of the processing step.
422+
description (str): The description of the processing step.
394423
inputs (List[ProcessingInput]): A list of `sagemaker.processing.ProcessorInput`
395424
instances. Defaults to `None`.
396425
outputs (List[ProcessingOutput]): A list of `sagemaker.processing.ProcessorOutput`
@@ -405,7 +434,9 @@ def __init__(
405434
depends_on (List[str] or List[Step]): A list of step names or step instance
406435
this `sagemaker.workflow.steps.ProcessingStep` depends on
407436
"""
408-
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
437+
super(ProcessingStep, self).__init__(
438+
name, display_name, description, StepTypeEnum.PROCESSING, depends_on
439+
)
409440
self.processor = processor
410441
self.inputs = inputs
411442
self.outputs = outputs
@@ -468,6 +499,8 @@ def __init__(
468499
self,
469500
name: str,
470501
tuner: HyperparameterTuner,
502+
display_name: str = None,
503+
description: str = None,
471504
inputs=None,
472505
job_arguments: List[str] = None,
473506
cache_config: CacheConfig = None,
@@ -481,6 +514,8 @@ def __init__(
481514
Args:
482515
name (str): The name of the tuning step.
483516
tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
517+
display_name (str): The display name of the tuning step.
518+
description (str): The description of the tuning step.
484519
inputs: Information about the training data. Please refer to the
485520
``fit()`` method of the associated estimator, as this can take
486521
any of the following forms:
@@ -514,7 +549,9 @@ def __init__(
514549
depends_on (List[str] or List[Step]): A list of step names or step instance
515550
this `sagemaker.workflow.steps.ProcessingStep` depends on
516551
"""
517-
super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
552+
super(TuningStep, self).__init__(
553+
name, display_name, description, StepTypeEnum.TUNING, depends_on
554+
)
518555
self.tuner = tuner
519556
self.inputs = inputs
520557
self.job_arguments = job_arguments

0 commit comments

Comments
 (0)