Skip to content

Commit f46782f

Browse files
author
Dewen Qi
committed
fix: Allow StepCollection added in ConditionStep to be depended on
1 parent b4f05b8 commit f46782f

File tree

4 files changed

+206
-37
lines changed

4 files changed

+206
-37
lines changed

src/sagemaker/workflow/pipeline.py

+34-22
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3838
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3939
from sagemaker.workflow.properties import Properties
40-
from sagemaker.workflow.steps import Step
40+
from sagemaker.workflow.steps import Step, StepTypeEnum
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.workflow.utilities import list_to_request
@@ -79,6 +79,11 @@ class Pipeline(Entity):
7979

8080
_version: str = "2020-12-01"
8181
_metadata: Dict[str, Any] = dict()
82+
_step_map: Dict[str, Any] = dict()
83+
84+
def __attrs_post_init__(self):
85+
"""Set attributes post init"""
86+
_generate_step_map(self.steps, self._step_map)
8287

8388
def to_request(self) -> RequestType:
8489
"""Gets the request structure for workflow service calls."""
@@ -305,23 +310,27 @@ def definition(self) -> str:
305310

306311
return json.dumps(request_dict)
307312

308-
def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
313+
def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
309314
"""Insert step names as per `StepCollection` name in depends_on list
310315
311316
Args:
312-
step_requests (dict): The raw step request dict without any interpolation.
317+
step_requests (list): The list of raw step request dicts without any interpolation.
313318
"""
314-
step_name_map = {s.name: s for s in self.steps}
315319
for step_request in step_requests:
316-
if not step_request.get("DependsOn", None):
317-
continue
318320
depends_on = []
319-
for depend_step_name in step_request["DependsOn"]:
320-
if isinstance(step_name_map[depend_step_name], StepCollection):
321-
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
321+
for depend_step_name in step_request.get("DependsOn", []):
322+
if isinstance(self._step_map[depend_step_name], StepCollection):
323+
depends_on.extend([s.name for s in self._step_map[depend_step_name].steps])
322324
else:
323325
depends_on.append(depend_step_name)
324-
step_request["DependsOn"] = depends_on
326+
if depends_on:
327+
step_request["DependsOn"] = depends_on
328+
329+
if step_request["Type"] == StepTypeEnum.CONDITION.value:
330+
sub_step_requests = (
331+
step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
332+
)
333+
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
325334

326335

327336
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -448,6 +457,20 @@ def update_args(args: Dict[str, Any], **kwargs):
448457
args.update({key: value})
449458

450459

460+
def _generate_step_map(
461+
steps: Sequence[Union[Step, StepCollection]], step_map: dict
462+
) -> Dict[str, Any]:
463+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
464+
for step in steps:
465+
if step.name in step_map:
466+
raise ValueError("Pipeline steps cannot have duplicate names.")
467+
step_map[step.name] = step
468+
if isinstance(step, ConditionStep):
469+
_generate_step_map(step.if_steps + step.else_steps, step_map)
470+
if isinstance(step, StepCollection):
471+
_generate_step_map(step.steps, step_map)
472+
473+
451474
@attr.s
452475
class _PipelineExecution:
453476
"""Internal class for encapsulating pipeline execution instances.
@@ -547,22 +570,11 @@ class PipelineGraph:
547570

548571
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549572
self.step_map = {}
550-
self._generate_step_map(steps)
573+
_generate_step_map(steps, self.step_map)
551574
self.adjacency_list = self._initialize_adjacency_list()
552575
if self.is_cyclic():
553576
raise ValueError("Cycle detected in pipeline step graph.")
554577

555-
def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
556-
"""Helper method to create a mapping from Step/Step Collection name to itself."""
557-
for step in steps:
558-
if step.name in self.step_map:
559-
raise ValueError("Pipeline steps cannot have duplicate names.")
560-
self.step_map[step.name] = step
561-
if isinstance(step, ConditionStep):
562-
self._generate_step_map(step.if_steps + step.else_steps)
563-
if isinstance(step, StepCollection):
564-
self._generate_step_map(step.steps)
565-
566578
@classmethod
567579
def from_pipeline(cls, pipeline: Pipeline):
568580
"""Create a PipelineGraph object from the Pipeline object."""

tests/unit/sagemaker/workflow/test_pipeline.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
7878
pipeline = Pipeline(
7979
name="MyPipeline",
8080
parameters=[parameter],
81-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
81+
steps=_generate_large_pipeline_steps(parameter),
8282
sagemaker_session=sagemaker_session_mock,
8383
)
8484

@@ -132,7 +132,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
132132
pipeline = Pipeline(
133133
name="MyPipeline",
134134
parameters=[parameter],
135-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
135+
steps=_generate_large_pipeline_steps(parameter),
136136
sagemaker_session=sagemaker_session_mock,
137137
)
138138

@@ -437,3 +437,10 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
437437
PipelineExecutionArn="my:arn"
438438
)
439439
assert len(steps) == 1
440+
441+
442+
def _generate_large_pipeline_steps(input_data: object):
443+
steps = []
444+
for i in range(2000):
445+
steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data))
446+
return steps

tests/unit/sagemaker/workflow/test_pipeline_graph.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def role_arn():
4545
def test_pipeline_duplicate_step_name(sagemaker_session_mock):
4646
step1 = CustomStep(name="foo")
4747
step2 = CustomStep(name="foo")
48-
pipeline = Pipeline(
49-
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
50-
)
5148
with pytest.raises(ValueError) as error:
49+
pipeline = Pipeline(
50+
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
51+
)
5252
PipelineGraph.from_pipeline(pipeline)
5353
assert "Pipeline steps cannot have duplicate names." in str(error.value)
5454

@@ -61,25 +61,25 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock):
6161
condition_step = ConditionStep(
6262
name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2]
6363
)
64-
pipeline = Pipeline(
65-
name="MyPipeline",
66-
steps=[custom_step, condition_step],
67-
sagemaker_session=sagemaker_session_mock,
68-
)
6964
with pytest.raises(ValueError) as error:
65+
pipeline = Pipeline(
66+
name="MyPipeline",
67+
steps=[custom_step, condition_step],
68+
sagemaker_session=sagemaker_session_mock,
69+
)
7070
PipelineGraph.from_pipeline(pipeline)
7171
assert "Pipeline steps cannot have duplicate names." in str(error.value)
7272

7373

7474
def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock):
7575
custom_step = CustomStep(name="foo-1")
7676
custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step])
77-
pipeline = Pipeline(
78-
name="MyPipeline",
79-
steps=[custom_step, custom_step_collection],
80-
sagemaker_session=sagemaker_session_mock,
81-
)
8277
with pytest.raises(ValueError) as error:
78+
pipeline = Pipeline(
79+
name="MyPipeline",
80+
steps=[custom_step, custom_step_collection],
81+
sagemaker_session=sagemaker_session_mock,
82+
)
8383
PipelineGraph.from_pipeline(pipeline)
8484
assert "Pipeline steps cannot have duplicate names." in str(error.value)
8585

tests/unit/sagemaker/workflow/test_step_collections.py

+150
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121

2222
from sagemaker.drift_check_baselines import DriftCheckBaselines
23+
from sagemaker.workflow.condition_step import ConditionStep
24+
from sagemaker.workflow.conditions import ConditionEquals
2325
from sagemaker.workflow.model_step import (
2426
ModelStep,
2527
_CREATE_MODEL_NAME_BASE,
@@ -360,6 +362,154 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
360362
)
361363

362364

365+
def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session):
366+
custom_step1 = CustomStep(name="MyStep1")
367+
368+
# Define a step collection which will be inserted into the ConditionStep
369+
model_name = "MyModel"
370+
model = Model(
371+
name=model_name,
372+
image_uri=IMAGE_URI,
373+
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
374+
sagemaker_session=pipeline_session,
375+
entry_point=f"{DATA_DIR}/dummy_script.py",
376+
source_dir=f"{DATA_DIR}",
377+
role=ROLE,
378+
)
379+
step_args = model.create(
380+
instance_type="c4.4xlarge",
381+
accelerator_type="ml.eia1.medium",
382+
)
383+
model_step_name = "MyModelStep"
384+
model_step = ModelStep(
385+
name=model_step_name,
386+
step_args=step_args,
387+
)
388+
389+
# Define another step collection which will be inserted into the ConditionStep
390+
# This StepCollection object depends on a StepCollection object in the ConditionStep
391+
# And a normal step outside ConditionStep
392+
model.sagemaker_session = sagemaker_session
393+
register_model_name = "RegisterModelStep"
394+
register_model = RegisterModel(
395+
name=register_model_name,
396+
model=model,
397+
model_data="s3://",
398+
content_types=["content_type"],
399+
response_types=["response_type"],
400+
inference_instances=["inference_instance"],
401+
transform_instances=["transform_instance"],
402+
model_package_group_name="mpg",
403+
depends_on=["MyStep1", model_step],
404+
)
405+
406+
# StepCollection objects are depended on by a normal step in the ConditionStep
407+
custom_step2 = CustomStep(
408+
name="MyStep2", depends_on=["MyStep1", model_step, register_model_name]
409+
)
410+
# StepCollection objects are depended on by a normal step outside the ConditionStep
411+
custom_step3 = CustomStep(
412+
name="MyStep3", depends_on=[custom_step1, model_step_name, register_model]
413+
)
414+
415+
cond_step = ConditionStep(
416+
name="CondStep",
417+
conditions=[ConditionEquals(left=2, right=1)],
418+
if_steps=[],
419+
else_steps=[model_step, register_model, custom_step2],
420+
)
421+
422+
pipeline = Pipeline(
423+
name="MyPipeline",
424+
steps=[cond_step, custom_step1, custom_step3],
425+
)
426+
step_list = json.loads(pipeline.definition())["Steps"]
427+
assert len(step_list) == 3
428+
for step in step_list:
429+
if step["Name"] == "MyStep1":
430+
assert "DependsOn" not in step
431+
elif step["Name"] == "CondStep":
432+
assert not step["Arguments"]["IfSteps"]
433+
for sub_step in step["Arguments"]["ElseSteps"]:
434+
if sub_step["Name"] == f"{model_name}-RepackModel":
435+
assert set(sub_step["DependsOn"]) == {
436+
"MyStep1",
437+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
438+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
439+
}
440+
if sub_step["Name"] == "MyStep2":
441+
assert set(sub_step["DependsOn"]) == {
442+
"MyStep1",
443+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
444+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
445+
f"{model_name}-RepackModel",
446+
f"{register_model_name}-RegisterModel",
447+
}
448+
else:
449+
assert set(step["DependsOn"]) == {
450+
"MyStep1",
451+
f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
452+
f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
453+
f"{model_name}-RepackModel",
454+
f"{register_model_name}-RegisterModel",
455+
}
456+
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
457+
assert ordered(adjacency_list) == ordered(
458+
{
459+
"CondStep": ["MyModel-RepackModel", "MyModelStep-RepackModel-MyModel", "MyStep2"],
460+
"MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
461+
"MyStep2": [],
462+
"MyStep3": [],
463+
"MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"],
464+
"MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
465+
"MyModel-RepackModel": [],
466+
"RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"],
467+
}
468+
)
469+
470+
471+
def test_condition_step_depends_on_step_collection():
472+
step1 = CustomStep(name="MyStep1")
473+
step2 = CustomStep(name="MyStep2", input_data=step1.properties)
474+
step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2])
475+
cond_step = ConditionStep(
476+
name="MyConditionStep",
477+
depends_on=[step_collection],
478+
conditions=[ConditionEquals(left=2, right=1)],
479+
if_steps=[],
480+
else_steps=[],
481+
)
482+
pipeline = Pipeline(
483+
name="MyPipeline",
484+
steps=[step_collection, cond_step],
485+
)
486+
step_list = json.loads(pipeline.definition())["Steps"]
487+
assert len(step_list) == 3
488+
for step in step_list:
489+
if step["Name"] != "MyConditionStep":
490+
continue
491+
assert step == {
492+
"Name": "MyConditionStep",
493+
"Type": "Condition",
494+
"DependsOn": ["MyStep1", "MyStep2"],
495+
"Arguments": {
496+
"Conditions": [
497+
{
498+
"Type": "Equals",
499+
"LeftValue": 2,
500+
"RightValue": 1,
501+
},
502+
],
503+
"IfSteps": [],
504+
"ElseSteps": [],
505+
},
506+
}
507+
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
508+
assert ordered(adjacency_list) == ordered(
509+
[("MyConditionStep", []), ("MyStep1", ["MyStep2"]), ("MyStep2", ["MyConditionStep"])]
510+
)
511+
512+
363513
def test_register_model(estimator, model_metrics, drift_check_baselines):
364514
model_data = f"s3://{BUCKET}/model.tar.gz"
365515
register_model = RegisterModel(

0 commit comments

Comments
 (0)