Skip to content

Commit 382d9dd

Browse files
committed
feature: support displayName and description for pipeline steps
1 parent 9ca12bc commit 382d9dd

12 files changed

+173
-25
lines changed

src/sagemaker/workflow/_utils.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __init__(
5858
role,
5959
model_data: str,
6060
entry_point: str,
61+
display_name: str = None,
62+
description: str = None,
6163
source_dir: str = None,
6264
dependencies: List = None,
6365
depends_on: Union[List[str], List[Step]] = None,
@@ -108,7 +110,12 @@ def __init__(
108110

109111
# super!
110112
super(_RepackModelStep, self).__init__(
111-
name=name, depends_on=depends_on, estimator=repacker, inputs=inputs
113+
name=name,
114+
display_name=display_name,
115+
description=description,
116+
depends_on=depends_on,
117+
estimator=repacker,
118+
inputs=inputs,
112119
)
113120

114121
def _prepare_for_repacking(self):
@@ -228,6 +235,7 @@ def __init__(
228235
approval_status="PendingManualApproval",
229236
image_uri=None,
230237
compile_model_family=None,
238+
display_name: str = None,
231239
description=None,
232240
depends_on: Union[List[str], List[Step]] = None,
233241
tags=None,
@@ -269,7 +277,9 @@ def __init__(
269277
this step depends on
270278
**kwargs: additional arguments to `create_model`.
271279
"""
272-
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
280+
super(_RegisterModelStep, self).__init__(
281+
name, display_name, description, StepTypeEnum.REGISTER_MODEL, depends_on
282+
)
273283
self.estimator = estimator
274284
self.model_data = model_data
275285
self.content_types = content_types

src/sagemaker/workflow/callback_step.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def __init__(
8383
sqs_queue_url: str,
8484
inputs: dict,
8585
outputs: List[CallbackOutput],
86+
display_name: str = None,
87+
description: str = None,
8688
cache_config: CacheConfig = None,
8789
depends_on: Union[List[str], List[Step]] = None,
8890
):
@@ -94,11 +96,15 @@ def __init__(
9496
inputs (dict): Input arguments that will be provided
9597
in the SQS message body of callback messages.
9698
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
99+
display_name (str): The display name of the callback step.
100+
description (str): The description of the callback step.
97101
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
98102
depends_on (List[str] or List[Step]): A list of step names or step instances
99103
this `sagemaker.workflow.steps.CallbackStep` depends on
100104
"""
101-
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
105+
super(CallbackStep, self).__init__(
106+
name, display_name, description, StepTypeEnum.CALLBACK, depends_on
107+
)
102108
self.sqs_queue_url = sqs_queue_url
103109
self.outputs = outputs
104110
self.cache_config = cache_config

src/sagemaker/workflow/condition_step.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def __init__(
4242
self,
4343
name: str,
4444
depends_on: Union[List[str], List[Step]] = None,
45+
display_name: str = None,
46+
description: str = None,
4547
conditions: List[Condition] = None,
4648
if_steps: List[Union[Step, StepCollection]] = None,
4749
else_steps: List[Union[Step, StepCollection]] = None,
@@ -53,6 +55,9 @@ def __init__(
5355
execution.
5456
5557
Args:
58+
name (str): The name of the condition step.
59+
display_name (str): The display name of the condition step.
60+
description (str): The description of the condition step.
5661
conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition`
5762
instances.
5863
if_steps (List[Union[Step, StepCollection]]): A list of `sagemaker.workflow.steps.Step`
@@ -62,7 +67,9 @@ def __init__(
6267
or `sagemaker.workflow.step_collections.StepCollection` instances that are
6368
marked as ready for execution if the list of conditions evaluates to False.
6469
"""
65-
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION, depends_on)
70+
super(ConditionStep, self).__init__(
71+
name, display_name, description, StepTypeEnum.CONDITION, depends_on
72+
)
6673
self.conditions = conditions or []
6774
self.if_steps = if_steps or []
6875
self.else_steps = else_steps or []

src/sagemaker/workflow/lambda_step.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
self,
8383
name: str,
8484
lambda_func: Lambda,
85+
display_name: str = None,
86+
description: str = None,
8587
inputs: dict = None,
8688
outputs: List[LambdaOutput] = None,
8789
cache_config: CacheConfig = None,
@@ -91,6 +93,8 @@ def __init__(
9193
9294
Args:
9395
name (str): The name of the lambda step.
96+
display_name (str): The display name of the Lambda step.
97+
description (str): The description of the Lambda step.
9498
lambda_func (str): An instance of sagemaker.lambda_helper.Lambda.
9599
If lambda arn is specified in the instance, LambdaStep just invokes the function,
96100
else lambda function will be created while creating the pipeline.
@@ -101,7 +105,9 @@ def __init__(
101105
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep`
102106
depends on
103107
"""
104-
super(LambdaStep, self).__init__(name, StepTypeEnum.LAMBDA, depends_on)
108+
super(LambdaStep, self).__init__(
109+
name, display_name, description, StepTypeEnum.LAMBDA, depends_on
110+
)
105111
self.lambda_func = lambda_func
106112
self.outputs = outputs if outputs is not None else []
107113
self.cache_config = cache_config

src/sagemaker/workflow/step_collections.py

+14
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
approval_status=None,
6868
image_uri=None,
6969
compile_model_family=None,
70+
display_name=None,
7071
description=None,
7172
tags=None,
7273
model=None,
@@ -125,6 +126,8 @@ def __init__(
125126
entry_point=entry_point,
126127
source_dir=source_dir,
127128
dependencies=dependencies,
129+
description=description,
130+
display_name=display_name,
128131
**kwargs,
129132
)
130133
steps.append(repack_model_step)
@@ -163,6 +166,8 @@ def __init__(
163166
entry_point=entry_point,
164167
source_dir=source_dir,
165168
dependencies=dependencies,
169+
description=description,
170+
display_name=display_name,
166171
**kwargs,
167172
)
168173
steps.append(repack_model_step)
@@ -192,6 +197,7 @@ def __init__(
192197
image_uri=image_uri,
193198
compile_model_family=compile_model_family,
194199
description=description,
200+
display_name=display_name,
195201
tags=tags,
196202
container_def_list=self.container_def_list,
197203
**kwargs,
@@ -215,6 +221,8 @@ def __init__(
215221
instance_count,
216222
instance_type,
217223
transform_inputs,
224+
description: str = None,
225+
display_name: str = None,
218226
# model arguments
219227
image_uri=None,
220228
predictor_cls=None,
@@ -283,6 +291,8 @@ def __init__(
283291
entry_point=entry_point,
284292
source_dir=source_dir,
285293
dependencies=dependencies,
294+
description=description,
295+
display_name=display_name,
286296
)
287297
steps.append(repack_model_step)
288298
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
@@ -305,6 +315,8 @@ def predict_wrapper(endpoint, session):
305315
name=f"{name}CreateModelStep",
306316
model=model,
307317
inputs=model_inputs,
318+
description=description,
319+
display_name=display_name,
308320
)
309321
if "entry_point" not in kwargs and depends_on:
310322
# if the CreateModelStep is the first step in the collection
@@ -332,6 +344,8 @@ def predict_wrapper(endpoint, session):
332344
name=f"{name}TransformStep",
333345
transformer=transformer,
334346
inputs=transform_inputs,
347+
description=description,
348+
display_name=display_name,
335349
)
336350
steps.append(transform_step)
337351

src/sagemaker/workflow/steps.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,16 @@ class Step(Entity):
6363
6464
Attributes:
6565
name (str): The name of the step.
66+
display_name (str): The display name of the step.
67+
description (str): The description of the step.
6668
step_type (StepTypeEnum): The type of the step.
6769
depends_on (List[str] or List[Step]): The list of step names or step
6870
instances the current step depends on
6971
"""
7072

7173
name: str = attr.ib(factory=str)
74+
display_name: str = attr.ib(default=None)
75+
description: str = attr.ib(default=None)
7276
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
7377
depends_on: Union[List[str], List["Step"]] = attr.ib(default=None)
7478

@@ -91,7 +95,10 @@ def to_request(self) -> RequestType:
9195
}
9296
if self.depends_on:
9397
request_dict["DependsOn"] = self._resolve_depends_on(self.depends_on)
94-
98+
if self.display_name:
99+
request_dict["DisplayName"] = self.display_name
100+
if self.description:
101+
request_dict["Description"] = self.description
95102
return request_dict
96103

97104
def add_depends_on(self, step_names: Union[List[str], List["Step"]]):
@@ -168,6 +175,8 @@ def __init__(
168175
self,
169176
name: str,
170177
estimator: EstimatorBase,
178+
display_name: str = None,
179+
description: str = None,
171180
inputs: Union[TrainingInput, dict, str, FileSystemInput] = None,
172181
cache_config: CacheConfig = None,
173182
depends_on: Union[List[str], List[Step]] = None,
@@ -180,6 +189,8 @@ def __init__(
180189
Args:
181190
name (str): The name of the training step.
182191
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
192+
display_name (str): The display name of the training step.
193+
description (str): The description of the training step.
183194
inputs (Union[str, dict, TrainingInput, FileSystemInput]): Information
184195
about the training data. This can be one of three types:
185196
@@ -200,7 +211,9 @@ def __init__(
200211
depends_on (List[str] or List[Step]): A list of step names or step instances
201212
this `sagemaker.workflow.steps.TrainingStep` depends on
202213
"""
203-
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
214+
super(TrainingStep, self).__init__(
215+
name, display_name, description, StepTypeEnum.TRAINING, depends_on
216+
)
204217
self.estimator = estimator
205218
self.inputs = inputs
206219
self._properties = Properties(
@@ -248,6 +261,8 @@ def __init__(
248261
model: Model,
249262
inputs: CreateModelInput,
250263
depends_on: Union[List[str], List[Step]] = None,
264+
display_name: str = None,
265+
description: str = None,
251266
):
252267
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
253268
@@ -261,8 +276,12 @@ def __init__(
261276
Defaults to `None`.
262277
depends_on (List[str] or List[Step]): A list of step names or step instances
263278
this `sagemaker.workflow.steps.CreateModelStep` depends on
279+
display_name (str): The display name of the CreateModel step.
280+
description (str): The description of the CreateModel step.
264281
"""
265-
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
282+
super(CreateModelStep, self).__init__(
283+
name, display_name, description, StepTypeEnum.CREATE_MODEL, depends_on
284+
)
266285
self.model = model
267286
self.inputs = inputs or CreateModelInput()
268287

@@ -304,6 +323,8 @@ def __init__(
304323
name: str,
305324
transformer: Transformer,
306325
inputs: TransformInput,
326+
display_name: str = None,
327+
description: str = None,
307328
cache_config: CacheConfig = None,
308329
depends_on: Union[List[str], List[Step]] = None,
309330
):
@@ -317,10 +338,14 @@ def __init__(
317338
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
318339
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
319340
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
320-
depends_on (List[str] or List[Step]): A list of step names or step instances
321-
this `sagemaker.workflow.steps.TransformStep` depends on
341+
display_name (str): The display name of the transform step.
342+
description (str): The description of the transform step.
343+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
344+
depends on
322345
"""
323-
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
346+
super(TransformStep, self).__init__(
347+
name, display_name, description, StepTypeEnum.TRANSFORM, depends_on
348+
)
324349
self.transformer = transformer
325350
self.inputs = inputs
326351
self.cache_config = cache_config
@@ -375,6 +400,8 @@ def __init__(
375400
self,
376401
name: str,
377402
processor: Processor,
403+
display_name: str = None,
404+
description: str = None,
378405
inputs: List[ProcessingInput] = None,
379406
outputs: List[ProcessingOutput] = None,
380407
job_arguments: List[str] = None,
@@ -391,6 +418,8 @@ def __init__(
391418
Args:
392419
name (str): The name of the processing step.
393420
processor (Processor): A `sagemaker.processing.Processor` instance.
421+
display_name (str): The display name of the processing step.
422+
description (str): The description of the processing step.
394423
inputs (List[ProcessingInput]): A list of `sagemaker.processing.ProcessorInput`
395424
instances. Defaults to `None`.
396425
outputs (List[ProcessingOutput]): A list of `sagemaker.processing.ProcessorOutput`
@@ -405,7 +434,9 @@ def __init__(
405434
depends_on (List[str] or List[Step]): A list of step names or step instance
406435
this `sagemaker.workflow.steps.ProcessingStep` depends on
407436
"""
408-
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
437+
super(ProcessingStep, self).__init__(
438+
name, display_name, description, StepTypeEnum.PROCESSING, depends_on
439+
)
409440
self.processor = processor
410441
self.inputs = inputs
411442
self.outputs = outputs
@@ -468,6 +499,8 @@ def __init__(
468499
self,
469500
name: str,
470501
tuner: HyperparameterTuner,
502+
display_name: str = None,
503+
description: str = None,
471504
inputs=None,
472505
job_arguments: List[str] = None,
473506
cache_config: CacheConfig = None,
@@ -481,6 +514,8 @@ def __init__(
481514
Args:
482515
name (str): The name of the tuning step.
483516
tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
517+
display_name (str): The display name of the tuning step.
518+
description (str): The description of the tuning step.
484519
inputs: Information about the training data. Please refer to the
485520
``fit()`` method of the associated estimator, as this can take
486521
any of the following forms:
@@ -514,7 +549,9 @@ def __init__(
514549
depends_on (List[str] or List[Step]): A list of step names or step instance
515550
this `sagemaker.workflow.steps.ProcessingStep` depends on
516551
"""
517-
super(TuningStep, self).__init__(name, StepTypeEnum.TUNING, depends_on)
552+
super(TuningStep, self).__init__(
553+
name, display_name, description, StepTypeEnum.TUNING, depends_on
554+
)
518555
self.tuner = tuner
519556
self.inputs = inputs
520557
self.job_arguments = job_arguments

0 commit comments

Comments
 (0)