Skip to content

Commit a7ef598

Browse files
author
Dewen Qi
committed
fix: Allow StepCollection added in ConditionStep to be depended on
1 parent b4f05b8 commit a7ef598

File tree

2 files changed

+184
-23
lines changed

2 files changed

+184
-23
lines changed

src/sagemaker/workflow/pipeline.py

+34-23
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3838
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3939
from sagemaker.workflow.properties import Properties
40-
from sagemaker.workflow.steps import Step
40+
from sagemaker.workflow.steps import Step, StepTypeEnum
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.workflow.utilities import list_to_request
@@ -79,6 +79,11 @@ class Pipeline(Entity):
7979

8080
_version: str = "2020-12-01"
8181
_metadata: Dict[str, Any] = dict()
82+
_step_map: Dict[str, Any] = dict()
83+
84+
def __attrs_post_init__(self):
85+
"""Set attributes post init"""
86+
self._step_map = _generate_step_map(self.steps)
8287

8388
def to_request(self) -> RequestType:
8489
"""Gets the request structure for workflow service calls."""
@@ -305,23 +310,27 @@ def definition(self) -> str:
305310

306311
return json.dumps(request_dict)
307312

308-
def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
313+
def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
309314
"""Insert step names as per `StepCollection` name in depends_on list
310315
311316
Args:
312-
step_requests (dict): The raw step request dict without any interpolation.
317+
step_requests (list): The list of raw step request dicts without any interpolation.
313318
"""
314-
step_name_map = {s.name: s for s in self.steps}
315319
for step_request in step_requests:
316-
if not step_request.get("DependsOn", None):
317-
continue
318320
depends_on = []
319-
for depend_step_name in step_request["DependsOn"]:
320-
if isinstance(step_name_map[depend_step_name], StepCollection):
321-
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
321+
for depend_step_name in step_request.get("DependsOn", []):
322+
if isinstance(self._step_map[depend_step_name], StepCollection):
323+
depends_on.extend([s.name for s in self._step_map[depend_step_name].steps])
322324
else:
323325
depends_on.append(depend_step_name)
324-
step_request["DependsOn"] = depends_on
326+
if depends_on:
327+
step_request["DependsOn"] = depends_on
328+
329+
if step_request["Type"] == StepTypeEnum.CONDITION.value:
330+
sub_step_requests = (
331+
step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
332+
)
333+
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
325334

326335

327336
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -448,6 +457,20 @@ def update_args(args: Dict[str, Any], **kwargs):
448457
args.update({key: value})
449458

450459

460+
def _generate_step_map(steps: Sequence[Union[Step, StepCollection]]) -> Dict[str, Any]:
461+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
462+
step_map = dict()
463+
for step in steps:
464+
if step.name in step_map:
465+
raise ValueError("Pipeline steps cannot have duplicate names.")
466+
step_map[step.name] = step
467+
if isinstance(step, ConditionStep):
468+
step_map.update(_generate_step_map(step.if_steps + step.else_steps))
469+
if isinstance(step, StepCollection):
470+
step_map.update(_generate_step_map(step.steps))
471+
return step_map
472+
473+
451474
@attr.s
452475
class _PipelineExecution:
453476
"""Internal class for encapsulating pipeline execution instances.
@@ -546,23 +569,11 @@ class PipelineGraph:
546569
"""
547570

548571
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549-
self.step_map = {}
550-
self._generate_step_map(steps)
572+
self.step_map = _generate_step_map(steps)
551573
self.adjacency_list = self._initialize_adjacency_list()
552574
if self.is_cyclic():
553575
raise ValueError("Cycle detected in pipeline step graph.")
554576

555-
def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
556-
"""Helper method to create a mapping from Step/Step Collection name to itself."""
557-
for step in steps:
558-
if step.name in self.step_map:
559-
raise ValueError("Pipeline steps cannot have duplicate names.")
560-
self.step_map[step.name] = step
561-
if isinstance(step, ConditionStep):
562-
self._generate_step_map(step.if_steps + step.else_steps)
563-
if isinstance(step, StepCollection):
564-
self._generate_step_map(step.steps)
565-
566577
@classmethod
567578
def from_pipeline(cls, pipeline: Pipeline):
568579
"""Create a PipelineGraph object from the Pipeline object."""

tests/unit/sagemaker/workflow/test_step_collections.py

+150
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121

2222
from sagemaker.drift_check_baselines import DriftCheckBaselines
23+
from sagemaker.workflow.condition_step import ConditionStep
24+
from sagemaker.workflow.conditions import ConditionEquals
2325
from sagemaker.workflow.model_step import (
2426
ModelStep,
2527
_CREATE_MODEL_NAME_BASE,
@@ -360,6 +362,154 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
360362
)
361363

362364

365+
def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session):
366+
custom_step1 = CustomStep(name="MyStep1")
367+
368+
# Define a step collection which will be inserted into the ConditionStep
369+
model_name = "MyModel"
370+
model = Model(
371+
name=model_name,
372+
image_uri=IMAGE_URI,
373+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
374+
sagemaker_session=pipeline_session,
375+
entry_point=f"{DATA_DIR}/dummy_script.py",
376+
source_dir=f"{DATA_DIR}",
377+
role=ROLE,
378+
)
379+
step_args = model.create(
380+
instance_type="c4.4xlarge",
381+
accelerator_type="ml.eia1.medium",
382+
)
383+
model_step_name = "MyModelStep"
384+
model_step = ModelStep(
385+
name=model_step_name,
386+
step_args=step_args,
387+
)
388+
389+
# Define another step collection which will be inserted into the ConditionStep
390+
# This StepCollection object depends on a StepCollection object in the ConditionStep
391+
# And a normal step outside ConditionStep
392+
model.sagemaker_session = sagemaker_session
393+
register_model_name = "RegisterModelStep"
394+
register_model = RegisterModel(
395+
name=register_model_name,
396+
model=model,
397+
model_data="s3://",
398+
content_types=["content_type"],
399+
response_types=["response_type"],
400+
inference_instances=["inference_instance"],
401+
transform_instances=["transform_instance"],
402+
model_package_group_name="mpg",
403+
depends_on=["MyStep1", model_step],
404+
)
405+
406+
# StepCollection objects are depended on by a normal step in the ConditionStep
407+
custom_step2 = CustomStep(
408+
name="MyStep2", depends_on=["MyStep1", model_step, register_model_name]
409+
)
410+
# StepCollection objects are depended on by a normal step outside the ConditionStep
411+
custom_step3 = CustomStep(
412+
name="MyStep3", depends_on=[custom_step1, model_step_name, register_model]
413+
)
414+
415+
cond_step = ConditionStep(
416+
name="CondStep",
417+
conditions=[ConditionEquals(left=2, right=1)],
418+
if_steps=[],
419+
else_steps=[model_step, register_model, custom_step2],
420+
)
421+
422+
pipeline = Pipeline(
423+
name="MyPipeline",
424+
steps=[cond_step, custom_step1, custom_step3],
425+
)
426+
step_list = json.loads(pipeline.definition())["Steps"]
427+
assert len(step_list) == 3
428+
for step in step_list:
429+
if step["Name"] == "MyStep1":
430+
assert "DependsOn" not in step
431+
elif step["Name"] == "CondStep":
432+
assert not step["Arguments"]["IfSteps"]
433+
for sub_step in step["Arguments"]["ElseSteps"]:
434+
if sub_step["Name"] == f"{model_name}-RepackModel":
435+
assert set(sub_step["DependsOn"]) == {
436+
"MyStep1",
437+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
438+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
439+
}
440+
if sub_step["Name"] == "MyStep2":
441+
assert set(sub_step["DependsOn"]) == {
442+
"MyStep1",
443+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
444+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
445+
f"{model_name}-RepackModel",
446+
f"{register_model_name}-RegisterModel",
447+
}
448+
else:
449+
assert set(step["DependsOn"]) == {
450+
"MyStep1",
451+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
452+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
453+
f"{model_name}-RepackModel",
454+
f"{register_model_name}-RegisterModel",
455+
}
456+
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
457+
assert ordered(adjacency_list) == ordered(
458+
{
459+
"CondStep": ["MyModel-RepackModel", "MyModelStep-RepackModel-MyModel", "MyStep2"],
460+
"MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
461+
"MyStep2": [],
462+
"MyStep3": [],
463+
"MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"],
464+
"MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
465+
"MyModel-RepackModel": [],
466+
"RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"],
467+
}
468+
)
469+
470+
471+
def test_condition_step_depends_on_step_collection():
472+
step1 = CustomStep(name="MyStep1")
473+
step2 = CustomStep(name="MyStep2", input_data=step1.properties)
474+
step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2])
475+
cond_step = ConditionStep(
476+
name="MyConditionStep",
477+
depends_on=[step_collection],
478+
conditions=[ConditionEquals(left=2, right=1)],
479+
if_steps=[],
480+
else_steps=[],
481+
)
482+
pipeline = Pipeline(
483+
name="MyPipeline",
484+
steps=[step_collection, cond_step],
485+
)
486+
step_list = json.loads(pipeline.definition())["Steps"]
487+
assert len(step_list) == 3
488+
for step in step_list:
489+
if step["Name"] != "MyConditionStep":
490+
continue
491+
assert step == {
492+
"Name": "MyConditionStep",
493+
"Type": "Condition",
494+
"DependsOn": ["MyStep1", "MyStep2"],
495+
"Arguments": {
496+
"Conditions": [
497+
{
498+
"Type": "Equals",
499+
"LeftValue": 2,
500+
"RightValue": 1,
501+
},
502+
],
503+
"IfSteps": [],
504+
"ElseSteps": [],
505+
},
506+
}
507+
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
508+
assert ordered(adjacency_list) == ordered(
509+
[("MyConditionStep", []), ("MyStep1", ["MyStep2"]), ("MyStep2", ["MyConditionStep"])]
510+
)
511+
512+
363513
def test_register_model(estimator, model_metrics, drift_check_baselines):
364514
model_data = f"s3://{BUCKET}/model.tar.gz"
365515
register_model = RegisterModel(

0 commit comments

Comments
 (0)