Skip to content

Commit e00a38f

Browse files
fix: RegisterModel step and custom dependency support (#2262)
Co-authored-by: Ahsan Khan <[email protected]>
1 parent 00dec4d commit e00a38f

File tree

9 files changed

+434
-15
lines changed

9 files changed

+434
-15
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
entry_point: str,
6060
source_dir: str = None,
6161
dependencies: List = None,
62+
depends_on: List[str] = None,
6263
):
6364
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6465
@@ -102,7 +103,9 @@ def __init__(
102103
inputs = TrainingInput(self._model_prefix)
103104

104105
# super!
105-
super(_RepackModelStep, self).__init__(name=name, estimator=repacker, inputs=inputs)
106+
super(_RepackModelStep, self).__init__(
107+
name=name, depends_on=depends_on, estimator=repacker, inputs=inputs
108+
)
106109

107110
def _prepare_for_repacking(self):
108111
"""Prepares the source for the estimator."""
@@ -221,6 +224,7 @@ def __init__(
221224
image_uri=None,
222225
compile_model_family=None,
223226
description=None,
227+
depends_on: List[str] = None,
224228
**kwargs,
225229
):
226230
"""Constructor of a register model step.
@@ -248,9 +252,11 @@ def __init__(
248252
compile_model_family (str): Instance family for compiled model, if specified, a compiled
249253
model will be used (default: None).
250254
description (str): Model Package description (default: None).
255+
depends_on (List[str]): A list of step names this `sagemaker.workflow._utils._RegisterModelStep`
256+
depends on
251257
**kwargs: additional arguments to `create_model`.
252258
"""
253-
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL)
259+
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
254260
self.estimator = estimator
255261
self.model_data = model_data
256262
self.content_types = content_types

src/sagemaker/workflow/condition_step.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ConditionStep(Step):
4040
def __init__(
4141
self,
4242
name: str,
43+
depends_on: List[str] = None,
4344
conditions: List[Condition] = None,
4445
if_steps: List[Union[Step, StepCollection]] = None,
4546
else_steps: List[Union[Step, StepCollection]] = None,
@@ -60,7 +61,7 @@ def __init__(
6061
and `sagemaker.workflow.step_collections.StepCollection` instances that are
6162
marked as ready for execution if the list of conditions evaluates to False.
6263
"""
63-
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION)
64+
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION, depends_on)
6465
self.conditions = conditions or []
6566
self.if_steps = if_steps or []
6667
self.else_steps = else_steps or []

src/sagemaker/workflow/step_collections.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63+
depends_on: List[str] = None,
6364
model_package_group_name=None,
6465
model_metrics=None,
6566
approval_status=None,
@@ -80,6 +81,8 @@ def __init__(
8081
generate inferences in real-time (default: None).
8182
transform_instances (list): A list of the instance types on which a transformation
8283
job can be run or on which an endpoint can be deployed (default: None).
84+
depends_on (List[str]): The list of step names the first step in the collection
85+
depends on
8386
model_package_group_name (str): The Model Package Group name, exclusive to
8487
`model_package_name`, using `model_package_group_name` makes the Model Package
8588
versioned (default: None).
@@ -94,12 +97,15 @@ def __init__(
9497
**kwargs: additional arguments to `create_model`.
9598
"""
9699
steps: List[Step] = []
100+
repack_model = False
97101
if "entry_point" in kwargs:
102+
repack_model = True
98103
entry_point = kwargs["entry_point"]
99104
source_dir = kwargs.get("source_dir")
100105
dependencies = kwargs.get("dependencies")
101106
repack_model_step = _RepackModelStep(
102107
name=f"{name}RepackModel",
108+
depends_on=depends_on,
103109
estimator=estimator,
104110
model_data=model_data,
105111
entry_point=entry_point,
@@ -109,6 +115,11 @@ def __init__(
109115
steps.append(repack_model_step)
110116
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
111117

118+
# remove kwargs consumed by model repacking step
119+
kwargs.pop("entry_point", None)
120+
kwargs.pop("source_dir", None)
121+
kwargs.pop("dependencies", None)
122+
112123
register_model_step = _RegisterModelStep(
113124
name=name,
114125
estimator=estimator,
@@ -125,6 +136,9 @@ def __init__(
125136
description=description,
126137
**kwargs,
127138
)
139+
if not repack_model:
140+
register_model_step.add_depends_on(depends_on)
141+
128142
steps.append(register_model_step)
129143
self.steps = steps
130144

@@ -155,6 +169,7 @@ def __init__(
155169
max_payload=None,
156170
tags=None,
157171
volume_kms_key=None,
172+
depends_on: List[str] = None,
158173
**kwargs,
159174
):
160175
"""Construct steps required for a Transformer step collection:
@@ -191,6 +206,8 @@ def __init__(
191206
it will be the format of the batch transform output.
192207
env (dict): The Environment variables to be set for use during the
193208
transform job (default: None).
209+
depends_on (List[str]): The list of step names the first step in
210+
the collection depends on
194211
"""
195212
steps = []
196213
if "entry_point" in kwargs:
@@ -199,6 +216,7 @@ def __init__(
199216
dependencies = kwargs.get("dependencies")
200217
repack_model_step = _RepackModelStep(
201218
name=f"{name}RepackModel",
219+
depends_on=depends_on,
202220
estimator=estimator,
203221
model_data=model_data,
204222
entry_point=entry_point,
@@ -227,6 +245,9 @@ def predict_wrapper(endpoint, session):
227245
model=model,
228246
inputs=model_inputs,
229247
)
248+
if "entry_point" not in kwargs and depends_on:
249+
# if the CreateModelStep is the first step in the collection
250+
model_step.add_depends_on(depends_on)
230251
steps.append(model_step)
231252

232253
transformer = Transformer(

src/sagemaker/workflow/steps.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ class Step(Entity):
6363
Attributes:
6464
name (str): The name of the step.
6565
step_type (StepTypeEnum): The type of the step.
66+
depends_on (List[str]): The list of step names the current step depends on
6667
"""
6768

6869
name: str = attr.ib(factory=str)
6970
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
71+
depends_on: List[str] = attr.ib(default=None)
7072

7173
@property
7274
@abc.abstractmethod
@@ -80,11 +82,22 @@ def properties(self):
8082

8183
def to_request(self) -> RequestType:
8284
"""Gets the request structure for workflow service calls."""
83-
return {
85+
request_dict = {
8486
"Name": self.name,
8587
"Type": self.step_type.value,
8688
"Arguments": self.arguments,
8789
}
90+
if self.depends_on:
91+
request_dict["DependsOn"] = self.depends_on
92+
return request_dict
93+
94+
def add_depends_on(self, step_names: List[str]):
95+
"""Add step names to the current step's depends-on list."""
96+
if not step_names:
97+
return
98+
if not self.depends_on:
99+
self.depends_on = []
100+
self.depends_on.extend(step_names)
88101

89102
@property
90103
def ref(self) -> Dict[str, str]:
@@ -133,6 +146,7 @@ def __init__(
133146
estimator: EstimatorBase,
134147
inputs: TrainingInput = None,
135148
cache_config: CacheConfig = None,
149+
depends_on: List[str] = None,
136150
):
137151
"""Construct a TrainingStep, given an `EstimatorBase` instance.
138152
@@ -144,8 +158,10 @@ def __init__(
144158
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
145159
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146160
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
161+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep`
162+
depends on
147163
"""
148-
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
164+
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
149165
self.estimator = estimator
150166
self.inputs = inputs
151167
self._properties = Properties(
@@ -188,10 +204,7 @@ class CreateModelStep(Step):
188204
"""CreateModel step for workflow."""
189205

190206
def __init__(
191-
self,
192-
name: str,
193-
model: Model,
194-
inputs: CreateModelInput,
207+
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
195208
):
196209
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
197210
@@ -203,8 +216,10 @@ def __init__(
203216
model (Model): A `sagemaker.model.Model` instance.
204217
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
205218
Defaults to `None`.
219+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep`
220+
depends on
206221
"""
207-
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL)
222+
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
208223
self.model = model
209224
self.inputs = inputs or CreateModelInput()
210225

@@ -247,6 +262,7 @@ def __init__(
247262
transformer: Transformer,
248263
inputs: TransformInput,
249264
cache_config: CacheConfig = None,
265+
depends_on: List[str] = None,
250266
):
251267
"""Constructs a TransformStep, given an `Transformer` instance.
252268
@@ -258,8 +274,10 @@ def __init__(
258274
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
259275
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
260276
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
277+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
278+
depends on
261279
"""
262-
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
280+
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
263281
self.transformer = transformer
264282
self.inputs = inputs
265283
self.cache_config = cache_config
@@ -320,6 +338,7 @@ def __init__(
320338
code: str = None,
321339
property_files: List[PropertyFile] = None,
322340
cache_config: CacheConfig = None,
341+
depends_on: List[str] = None,
323342
):
324343
"""Construct a ProcessingStep, given a `Processor` instance.
325344
@@ -340,8 +359,10 @@ def __init__(
340359
property_files (List[PropertyFile]): A list of property files that workflow looks
341360
for and resolves from the configured processing output list.
342361
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
362+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep`
363+
depends on
343364
"""
344-
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
365+
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
345366
self.processor = processor
346367
self.inputs = inputs
347368
self.outputs = outputs

0 commit comments

Comments
 (0)