Skip to content

Commit 6874e78

Browse files
author
Dewen Qi
committed
fix: Allow StepCollection added in ConditionStep to be depended on
1 parent b4f05b8 commit 6874e78

File tree

2 files changed

+176
-4
lines changed

2 files changed

+176
-4
lines changed

src/sagemaker/workflow/pipeline.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3838
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3939
from sagemaker.workflow.properties import Properties
40-
from sagemaker.workflow.steps import Step
40+
from sagemaker.workflow.steps import Step, StepTypeEnum
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.workflow.utilities import list_to_request
@@ -311,16 +311,36 @@ def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
311311
Args:
312312
step_requests (dict): The raw step request dict without any interpolation.
313313
"""
314-
step_name_map = {s.name: s for s in self.steps}
314+
step_name_map = dict()
315+
for step in self.steps:
316+
if step.name in step_name_map:
317+
raise Exception()
318+
step_name_map[step.name] = step
319+
if not isinstance(step, ConditionStep):
320+
continue
321+
sub_steps = step.if_steps + step.else_steps
322+
for sub_step in sub_steps:
323+
step_name_map[sub_step.name] = sub_step
324+
315325
for step_request in step_requests:
316-
if not step_request.get("DependsOn", None):
326+
is_condition_step = step_request["Type"] == StepTypeEnum.CONDITION.value
327+
if not is_condition_step and not step_request.get("DependsOn", None):
317328
continue
318329
depends_on = []
319-
for depend_step_name in step_request["DependsOn"]:
330+
for depend_step_name in step_request.get("DependsOn", []):
320331
if isinstance(step_name_map[depend_step_name], StepCollection):
321332
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
322333
else:
323334
depends_on.append(depend_step_name)
335+
336+
if is_condition_step:
337+
sub_step_requests = step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
338+
for sub_step_request in sub_step_requests:
339+
for depend_step_name in sub_step_request.get("DependsOn", []):
340+
if isinstance(step_name_map[depend_step_name], StepCollection):
341+
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
342+
else:
343+
depends_on.append(depend_step_name)
324344
step_request["DependsOn"] = depends_on
325345

326346

tests/unit/sagemaker/workflow/test_step_collections.py

+152
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121

2222
from sagemaker.drift_check_baselines import DriftCheckBaselines
23+
from sagemaker.workflow.condition_step import ConditionStep
24+
from sagemaker.workflow.conditions import ConditionEquals
2325
from sagemaker.workflow.model_step import (
2426
ModelStep,
2527
_CREATE_MODEL_NAME_BASE,
@@ -360,6 +362,156 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session):
360362
)
361363

362364

365+
def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session):
    """Check DependsOn handling for step collections nested in a ConditionStep branch.

    A ModelStep and a RegisterModel collection are placed in the else branch of a
    ConditionStep. They are depended on by a step inside the same branch (MyStep2)
    and by a step outside the ConditionStep (MyStep3). The test asserts on the
    rendered pipeline definition and on the PipelineGraph adjacency list, both of
    which are expected to name the collections' member steps rather than the
    collection names.

    Args:
        pipeline_session: session fixture passed to the Model used by the ModelStep.
        sagemaker_session: session fixture assigned to the model before it is
            handed to RegisterModel.
    """
    custom_step1 = CustomStep(name="MyStep1")

    # Define a step collection which will be inserted into the ConditionStep
    model_name = "MyModel"
    model = Model(
        name=model_name,
        image_uri=IMAGE_URI,
        model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
        sagemaker_session=pipeline_session,
        entry_point=f"{DATA_DIR}/dummy_script.py",
        source_dir=f"{DATA_DIR}",
        role=ROLE,
    )
    step_args = model.create(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    model_step_name = "MyModelStep"
    model_step = ModelStep(
        name=model_step_name,
        step_args=step_args,
    )

    # Define another step collection which will be inserted into the ConditionStep
    # This StepCollection object depends on a StepCollection object in the ConditionStep
    # And a normal step outside ConditionStep
    model.sagemaker_session = sagemaker_session
    register_model_name = "RegisterModelStep"
    register_model = RegisterModel(
        name=register_model_name,
        model=model,
        model_data="s3://",
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        depends_on=["MyStep1", model_step],
    )

    # StepCollection objects are depended on by a normal step in the ConditionStep
    custom_step2 = CustomStep(
        name="MyStep2", depends_on=["MyStep1", model_step, register_model_name]
    )
    # StepCollection objects are depended on by a normal step outside the ConditionStep
    custom_step3 = CustomStep(
        name="MyStep3", depends_on=[custom_step1, model_step_name, register_model]
    )

    cond_step = ConditionStep(
        name="CondStep",
        conditions=[ConditionEquals(left=2, right=1)],
        if_steps=[],
        else_steps=[model_step, register_model, custom_step2],
    )

    pipeline = Pipeline(
        name="MyPipeline",
        steps=[cond_step, custom_step1, custom_step3],
    )

    # Dependencies contributed by MyStep1 plus the two members of the ModelStep.
    model_step_dependencies = {
        "MyStep1",
        f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}",
        f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}",
    }
    # The same set extended with the two members of the RegisterModel collection.
    all_dependencies = model_step_dependencies | {
        f"{model_name}-RepackModel",
        f"{register_model_name}-RegisterModel",
    }

    step_list = json.loads(pipeline.definition())["Steps"]
    # Only the top-level steps are listed; branch steps live under the ConditionStep.
    assert len(step_list) == 3
    for step in step_list:
        if step["Name"] == "MyStep1":
            assert "DependsOn" not in step
        elif step["Name"] == "CondStep":
            assert not step["Arguments"]["IfSteps"]
            for sub_step in step["Arguments"]["ElseSteps"]:
                if sub_step["Name"] == f"{model_name}-RepackModel":
                    assert set(sub_step["DependsOn"]) == model_step_dependencies
                if sub_step["Name"] == "MyStep2":
                    # NOTE(review): this inspects the ConditionStep request's own
                    # DependsOn, not sub_step["DependsOn"] -- confirm this is intended.
                    assert set(step["DependsOn"]) == all_dependencies
        else:
            # The only remaining top-level step is MyStep3.
            assert set(step["DependsOn"]) == all_dependencies

    adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
    assert ordered(adjacency_list) == ordered(
        {
            "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
            "MyStep2": [],
            "MyStep3": [],
            "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"],
            "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"],
            "MyModel-RepackModel": [],
            "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"],
        }
    )
469+
470+
471+
def test_condition_step_depends_on_step_collection():
    """Check a ConditionStep that declares a StepCollection in ``depends_on``.

    The rendered ConditionStep request is expected to list the collection's member
    steps (MyStep1, MyStep2) in ``DependsOn``, and the PipelineGraph adjacency list
    is expected to place the ConditionStep after MyStep2.
    """
    first_step = CustomStep(name="MyStep1")
    second_step = CustomStep(name="MyStep2", input_data=first_step.properties)
    collection = StepCollection(name="MyStepCollection", steps=[first_step, second_step])
    condition_step = ConditionStep(
        name="MyConditionStep",
        depends_on=[collection],
        conditions=[ConditionEquals(left=2, right=1)],
        if_steps=[],
        else_steps=[],
    )
    pipeline = Pipeline(name="MyPipeline", steps=[collection, condition_step])

    rendered_steps = json.loads(pipeline.definition())["Steps"]
    # Two collection members plus the condition step itself.
    assert len(rendered_steps) == 3

    expected_condition_request = {
        "Name": "MyConditionStep",
        "Type": "Condition",
        "DependsOn": ["MyStep1", "MyStep2"],
        "Arguments": {
            "Conditions": [
                {
                    "Type": "Equals",
                    "LeftValue": 2,
                    "RightValue": 1,
                },
            ],
            "IfSteps": [],
            "ElseSteps": [],
        },
    }
    # Only the condition step's request is compared; other steps are ignored.
    for rendered in rendered_steps:
        if rendered["Name"] == "MyConditionStep":
            assert rendered == expected_condition_request

    expected_graph = [
        ("MyConditionStep", []),
        ("MyStep1", ["MyStep2"]),
        ("MyStep2", ["MyConditionStep"]),
    ]
    graph = PipelineGraph.from_pipeline(pipeline).adjacency_list
    assert ordered(graph) == ordered(expected_graph)
513+
514+
363515
def test_register_model(estimator, model_metrics, drift_check_baselines):
364516
model_data = f"s3://{BUCKET}/model.tar.gz"
365517
register_model = RegisterModel(

0 commit comments

Comments
 (0)