Skip to content

Commit 2d59111

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
authored
fix: Allow StepCollection added in ConditionStep to be depended on (#3261)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 6f72e3c commit 2d59111

File tree

4 files changed

+267
-69
lines changed

4 files changed

+267
-69
lines changed

src/sagemaker/workflow/pipeline.py

+73-54
Original file line numberDiff line numberDiff line change
@@ -37,48 +37,58 @@
3737
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
3838
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
3939
from sagemaker.workflow.properties import Properties
40-
from sagemaker.workflow.steps import Step
40+
from sagemaker.workflow.steps import Step, StepTypeEnum
4141
from sagemaker.workflow.step_collections import StepCollection
4242
from sagemaker.workflow.condition_step import ConditionStep
4343
from sagemaker.workflow.utilities import list_to_request
4444

45+
_DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig(
46+
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
47+
)
48+
4549

46-
@attr.s
4750
class Pipeline(Entity):
48-
"""Pipeline for workflow.
51+
"""Pipeline for workflow."""
4952

50-
Attributes:
51-
name (str): The name of the pipeline.
52-
parameters (Sequence[Parameter]): The list of the parameters.
53-
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
54-
the workflow will attempt to create an experiment and trial before
55-
executing the steps. Creation will be skipped if an experiment or a trial with
56-
the same name already exists. By default, pipeline name is used as
57-
experiment name and execution id is used as the trial name.
58-
If set to None, no experiment or trial will be created automatically.
59-
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
60-
associated with the pipeline. Any steps that are within the
61-
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
62-
pipeline. Of particular note, the workflow service rejects any pipeline definitions that
63-
specify a step in the list of steps of a pipeline and that step in the `if_steps` or
64-
`else_steps` of any `ConditionStep`.
65-
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
66-
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
67-
pipeline creates one using the default AWS configuration chain.
68-
"""
53+
def __init__(
54+
self,
55+
name: str = "",
56+
parameters: Optional[Sequence[Parameter]] = None,
57+
pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG,
58+
steps: Optional[Sequence[Union[Step, StepCollection]]] = None,
59+
sagemaker_session: Optional[Session] = None,
60+
):
61+
"""Initialize a Pipeline
6962
70-
name: str = attr.ib(factory=str)
71-
parameters: Sequence[Parameter] = attr.ib(factory=list)
72-
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
73-
default=PipelineExperimentConfig(
74-
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
75-
)
76-
)
77-
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
78-
sagemaker_session: Session = attr.ib(factory=Session)
63+
Args:
64+
name (str): The name of the pipeline.
65+
parameters (Sequence[Parameter]): The list of the parameters.
66+
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
67+
the workflow will attempt to create an experiment and trial before
68+
executing the steps. Creation will be skipped if an experiment or a trial with
69+
the same name already exists. By default, pipeline name is used as
70+
experiment name and execution id is used as the trial name.
71+
If set to None, no experiment or trial will be created automatically.
72+
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
73+
associated with the pipeline. Any steps that are within the
74+
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
75+
pipeline. Of particular note, the workflow service rejects any pipeline definitions
76+
that specify a step in the list of steps of a pipeline and that step in the
77+
`if_steps` or `else_steps` of any `ConditionStep`.
78+
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
79+
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
80+
pipeline creates one using the default AWS configuration chain.
81+
"""
82+
self.name = name
83+
self.parameters = parameters if parameters else []
84+
self.pipeline_experiment_config = pipeline_experiment_config
85+
self.steps = steps if steps else []
86+
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
7987

80-
_version: str = "2020-12-01"
81-
_metadata: Dict[str, Any] = dict()
88+
self._version = "2020-12-01"
89+
self._metadata = dict()
90+
self._step_map = dict()
91+
_generate_step_map(self.steps, self._step_map)
8292

8393
def to_request(self) -> RequestType:
8494
"""Gets the request structure for workflow service calls."""
@@ -193,6 +203,8 @@ def update(
193203
Returns:
194204
A response dict from the service.
195205
"""
206+
self._step_map = dict()
207+
_generate_step_map(self.steps, self._step_map)
196208
kwargs = self._create_args(role_arn, description, parallelism_config)
197209
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)
198210

@@ -305,23 +317,27 @@ def definition(self) -> str:
305317

306318
return json.dumps(request_dict)
307319

308-
def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
320+
def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
309321
"""Insert step names as per `StepCollection` name in depends_on list
310322
311323
Args:
312-
step_requests (dict): The raw step request dict without any interpolation.
324+
step_requests (list): The list of raw step request dicts without any interpolation.
313325
"""
314-
step_name_map = {s.name: s for s in self.steps}
315326
for step_request in step_requests:
316-
if not step_request.get("DependsOn", None):
317-
continue
318327
depends_on = []
319-
for depend_step_name in step_request["DependsOn"]:
320-
if isinstance(step_name_map[depend_step_name], StepCollection):
321-
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
328+
for depend_step_name in step_request.get("DependsOn", []):
329+
if isinstance(self._step_map[depend_step_name], StepCollection):
330+
depends_on.extend([s.name for s in self._step_map[depend_step_name].steps])
322331
else:
323332
depends_on.append(depend_step_name)
324-
step_request["DependsOn"] = depends_on
333+
if depends_on:
334+
step_request["DependsOn"] = depends_on
335+
336+
if step_request["Type"] == StepTypeEnum.CONDITION.value:
337+
sub_step_requests = (
338+
step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
339+
)
340+
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)
325341

326342

327343
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
@@ -448,6 +464,20 @@ def update_args(args: Dict[str, Any], **kwargs):
448464
args.update({key: value})
449465

450466

467+
def _generate_step_map(
468+
steps: Sequence[Union[Step, StepCollection]], step_map: dict
469+
) -> Dict[str, Any]:
470+
"""Helper method to create a mapping from Step/Step Collection name to itself."""
471+
for step in steps:
472+
if step.name in step_map:
473+
raise ValueError("Pipeline steps cannot have duplicate names.")
474+
step_map[step.name] = step
475+
if isinstance(step, ConditionStep):
476+
_generate_step_map(step.if_steps + step.else_steps, step_map)
477+
if isinstance(step, StepCollection):
478+
_generate_step_map(step.steps, step_map)
479+
480+
451481
@attr.s
452482
class _PipelineExecution:
453483
"""Internal class for encapsulating pipeline execution instances.
@@ -547,22 +577,11 @@ class PipelineGraph:
547577

548578
def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
549579
self.step_map = {}
550-
self._generate_step_map(steps)
580+
_generate_step_map(steps, self.step_map)
551581
self.adjacency_list = self._initialize_adjacency_list()
552582
if self.is_cyclic():
553583
raise ValueError("Cycle detected in pipeline step graph.")
554584

555-
def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
556-
"""Helper method to create a mapping from Step/Step Collection name to itself."""
557-
for step in steps:
558-
if step.name in self.step_map:
559-
raise ValueError("Pipeline steps cannot have duplicate names.")
560-
self.step_map[step.name] = step
561-
if isinstance(step, ConditionStep):
562-
self._generate_step_map(step.if_steps + step.else_steps)
563-
if isinstance(step, StepCollection):
564-
self._generate_step_map(step.steps)
565-
566585
@classmethod
567586
def from_pipeline(cls, pipeline: Pipeline):
568587
"""Create a PipelineGraph object from the Pipeline object."""

tests/unit/sagemaker/workflow/test_pipeline.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from mock import Mock
2121

2222
from sagemaker import s3
23+
from sagemaker.workflow.condition_step import ConditionStep
24+
from sagemaker.workflow.conditions import ConditionEquals
2325
from sagemaker.workflow.execution_variables import ExecutionVariables
2426
from sagemaker.workflow.parameters import ParameterString
2527
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
@@ -28,6 +30,7 @@
2830
PipelineExperimentConfig,
2931
PipelineExperimentConfigProperties,
3032
)
33+
from sagemaker.workflow.step_collections import StepCollection
3134
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep
3235

3336

@@ -78,7 +81,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
7881
pipeline = Pipeline(
7982
name="MyPipeline",
8083
parameters=[parameter],
81-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
84+
steps=_generate_large_pipeline_steps(parameter),
8285
sagemaker_session=sagemaker_session_mock,
8386
)
8487

@@ -105,6 +108,25 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
105108
sagemaker_session=sagemaker_session_mock,
106109
)
107110
pipeline.update(role_arn=role_arn)
111+
assert len(json.loads(pipeline.definition())["Steps"]) == 0
112+
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
113+
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
114+
)
115+
116+
step1 = CustomStep(name="MyStep1")
117+
step2 = CustomStep(name="MyStep2", input_data=step1.properties)
118+
step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2])
119+
cond_step = ConditionStep(
120+
name="MyConditionStep",
121+
depends_on=[],
122+
conditions=[ConditionEquals(left=2, right=1)],
123+
if_steps=[step_collection],
124+
else_steps=[],
125+
)
126+
step3 = CustomStep(name="MyStep3", depends_on=[step_collection])
127+
pipeline.steps = [cond_step, step3]
128+
pipeline.update(role_arn=role_arn)
129+
assert len(json.loads(pipeline.definition())["Steps"]) > 0
108130
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
109131
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
110132
)
@@ -132,7 +154,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
132154
pipeline = Pipeline(
133155
name="MyPipeline",
134156
parameters=[parameter],
135-
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
157+
steps=_generate_large_pipeline_steps(parameter),
136158
sagemaker_session=sagemaker_session_mock,
137159
)
138160

@@ -437,3 +459,10 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
437459
PipelineExecutionArn="my:arn"
438460
)
439461
assert len(steps) == 1
462+
463+
464+
def _generate_large_pipeline_steps(input_data: object):
465+
steps = []
466+
for i in range(2000):
467+
steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data))
468+
return steps

tests/unit/sagemaker/workflow/test_pipeline_graph.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def role_arn():
4545
def test_pipeline_duplicate_step_name(sagemaker_session_mock):
4646
step1 = CustomStep(name="foo")
4747
step2 = CustomStep(name="foo")
48-
pipeline = Pipeline(
49-
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
50-
)
5148
with pytest.raises(ValueError) as error:
49+
pipeline = Pipeline(
50+
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
51+
)
5252
PipelineGraph.from_pipeline(pipeline)
5353
assert "Pipeline steps cannot have duplicate names." in str(error.value)
5454

@@ -61,25 +61,25 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock):
6161
condition_step = ConditionStep(
6262
name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2]
6363
)
64-
pipeline = Pipeline(
65-
name="MyPipeline",
66-
steps=[custom_step, condition_step],
67-
sagemaker_session=sagemaker_session_mock,
68-
)
6964
with pytest.raises(ValueError) as error:
65+
pipeline = Pipeline(
66+
name="MyPipeline",
67+
steps=[custom_step, condition_step],
68+
sagemaker_session=sagemaker_session_mock,
69+
)
7070
PipelineGraph.from_pipeline(pipeline)
7171
assert "Pipeline steps cannot have duplicate names." in str(error.value)
7272

7373

7474
def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock):
7575
custom_step = CustomStep(name="foo-1")
7676
custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step])
77-
pipeline = Pipeline(
78-
name="MyPipeline",
79-
steps=[custom_step, custom_step_collection],
80-
sagemaker_session=sagemaker_session_mock,
81-
)
8277
with pytest.raises(ValueError) as error:
78+
pipeline = Pipeline(
79+
name="MyPipeline",
80+
steps=[custom_step, custom_step_collection],
81+
sagemaker_session=sagemaker_session_mock,
82+
)
8383
PipelineGraph.from_pipeline(pipeline)
8484
assert "Pipeline steps cannot have duplicate names." in str(error.value)
8585

0 commit comments

Comments
 (0)