Skip to content

fix: Allow StepCollection added in ConditionStep to be depended on #3261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 73 additions & 54 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,48 +37,58 @@
from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig
from sagemaker.workflow.parallelism_config import ParallelismConfiguration
from sagemaker.workflow.properties import Properties
from sagemaker.workflow.steps import Step
from sagemaker.workflow.steps import Step, StepTypeEnum
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.utilities import list_to_request

_DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig(
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
)


@attr.s
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the @attr.s and define __init__ directly, as the former is not thread safe.
With the use of @attr.s, multiple unit tests of the PR failed due to ValueError("Pipeline steps cannot have duplicate names."), when invoking _generate_step_map(self.steps, self._step_map)
The reason is, in this case, in multiple unit tests, _generate_step_map methods are building the same global _step_map. As a result, if a step name duplicates in different tests, it can cause the ValueError("Pipeline steps cannot have duplicate names.").

class Pipeline(Entity):
"""Pipeline for workflow.
"""Pipeline for workflow."""

Attributes:
name (str): The name of the pipeline.
parameters (Sequence[Parameter]): The list of the parameters.
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
the workflow will attempt to create an experiment and trial before
executing the steps. Creation will be skipped if an experiment or a trial with
the same name already exists. By default, pipeline name is used as
experiment name and execution id is used as the trial name.
If set to None, no experiment or trial will be created automatically.
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
associated with the pipeline. Any steps that are within the
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
pipeline. Of particular note, the workflow service rejects any pipeline definitions that
specify a step in the list of steps of a pipeline and that step in the `if_steps` or
`else_steps` of any `ConditionStep`.
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
pipeline creates one using the default AWS configuration chain.
"""
def __init__(
self,
name: str = "",
parameters: Optional[Sequence[Parameter]] = None,
pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG,
steps: Optional[Sequence[Union[Step, StepCollection]]] = None,
sagemaker_session: Optional[Session] = None,
):
"""Initialize a Pipeline

name: str = attr.ib(factory=str)
parameters: Sequence[Parameter] = attr.ib(factory=list)
pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib(
default=PipelineExperimentConfig(
ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID
)
)
steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list)
sagemaker_session: Session = attr.ib(factory=Session)
Args:
name (str): The name of the pipeline.
parameters (Sequence[Parameter]): The list of the parameters.
pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set,
the workflow will attempt to create an experiment and trial before
executing the steps. Creation will be skipped if an experiment or a trial with
the same name already exists. By default, pipeline name is used as
experiment name and execution id is used as the trial name.
If set to None, no experiment or trial will be created automatically.
steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps
associated with the pipeline. Any steps that are within the
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
pipeline. Of particular note, the workflow service rejects any pipeline definitions
that specify a step in the list of steps of a pipeline and that step in the
`if_steps` or `else_steps` of any `ConditionStep`.
sagemaker_session (sagemaker.session.Session): Session object that manages interactions
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
pipeline creates one using the default AWS configuration chain.
"""
self.name = name
self.parameters = parameters if parameters else []
self.pipeline_experiment_config = pipeline_experiment_config
self.steps = steps if steps else []
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()

_version: str = "2020-12-01"
_metadata: Dict[str, Any] = dict()
self._version = "2020-12-01"
self._metadata = dict()
self._step_map = dict()
_generate_step_map(self.steps, self._step_map)

def to_request(self) -> RequestType:
"""Gets the request structure for workflow service calls."""
Expand Down Expand Up @@ -193,6 +203,8 @@ def update(
Returns:
A response dict from the service.
"""
self._step_map = dict()
_generate_step_map(self.steps, self._step_map)
kwargs = self._create_args(role_arn, description, parallelism_config)
return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs)

Expand Down Expand Up @@ -305,23 +317,27 @@ def definition(self) -> str:

return json.dumps(request_dict)

def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict):
def _interpolate_step_collection_name_in_depends_on(self, step_requests: list):
"""Insert step names as per `StepCollection` name in depends_on list

Args:
step_requests (dict): The raw step request dict without any interpolation.
step_requests (list): The list of raw step request dicts without any interpolation.
"""
step_name_map = {s.name: s for s in self.steps}
for step_request in step_requests:
if not step_request.get("DependsOn", None):
continue
depends_on = []
for depend_step_name in step_request["DependsOn"]:
if isinstance(step_name_map[depend_step_name], StepCollection):
depends_on.extend([s.name for s in step_name_map[depend_step_name].steps])
for depend_step_name in step_request.get("DependsOn", []):
if isinstance(self._step_map[depend_step_name], StepCollection):
depends_on.extend([s.name for s in self._step_map[depend_step_name].steps])
else:
depends_on.append(depend_step_name)
step_request["DependsOn"] = depends_on
if depends_on:
step_request["DependsOn"] = depends_on

if step_request["Type"] == StepTypeEnum.CONDITION.value:
sub_step_requests = (
step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"]
)
self._interpolate_step_collection_name_in_depends_on(sub_step_requests)


def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -448,6 +464,20 @@ def update_args(args: Dict[str, Any], **kwargs):
args.update({key: value})


def _generate_step_map(
steps: Sequence[Union[Step, StepCollection]], step_map: dict
) -> Dict[str, Any]:
"""Helper method to create a mapping from Step/Step Collection name to itself."""
for step in steps:
if step.name in step_map:
raise ValueError("Pipeline steps cannot have duplicate names.")
step_map[step.name] = step
if isinstance(step, ConditionStep):
_generate_step_map(step.if_steps + step.else_steps, step_map)
if isinstance(step, StepCollection):
_generate_step_map(step.steps, step_map)


@attr.s
class _PipelineExecution:
"""Internal class for encapsulating pipeline execution instances.
Expand Down Expand Up @@ -547,22 +577,11 @@ class PipelineGraph:

def __init__(self, steps: Sequence[Union[Step, StepCollection]]):
self.step_map = {}
self._generate_step_map(steps)
_generate_step_map(steps, self.step_map)
self.adjacency_list = self._initialize_adjacency_list()
if self.is_cyclic():
raise ValueError("Cycle detected in pipeline step graph.")

def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]):
"""Helper method to create a mapping from Step/Step Collection name to itself."""
for step in steps:
if step.name in self.step_map:
raise ValueError("Pipeline steps cannot have duplicate names.")
self.step_map[step.name] = step
if isinstance(step, ConditionStep):
self._generate_step_map(step.if_steps + step.else_steps)
if isinstance(step, StepCollection):
self._generate_step_map(step.steps)

@classmethod
def from_pipeline(cls, pipeline: Pipeline):
"""Create a PipelineGraph object from the Pipeline object."""
Expand Down
33 changes: 31 additions & 2 deletions tests/unit/sagemaker/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mock import Mock

from sagemaker import s3
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionEquals
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
Expand All @@ -28,6 +30,7 @@
PipelineExperimentConfig,
PipelineExperimentConfigProperties,
)
from sagemaker.workflow.step_collections import StepCollection
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep


Expand Down Expand Up @@ -78,7 +81,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
pipeline = Pipeline(
name="MyPipeline",
parameters=[parameter],
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
steps=_generate_large_pipeline_steps(parameter),
sagemaker_session=sagemaker_session_mock,
)

Expand All @@ -105,6 +108,25 @@ def test_pipeline_update(sagemaker_session_mock, role_arn):
sagemaker_session=sagemaker_session_mock,
)
pipeline.update(role_arn=role_arn)
assert len(json.loads(pipeline.definition())["Steps"]) == 0
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
)

step1 = CustomStep(name="MyStep1")
step2 = CustomStep(name="MyStep2", input_data=step1.properties)
step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2])
cond_step = ConditionStep(
name="MyConditionStep",
depends_on=[],
conditions=[ConditionEquals(left=2, right=1)],
if_steps=[step_collection],
else_steps=[],
)
step3 = CustomStep(name="MyStep3", depends_on=[step_collection])
pipeline.steps = [cond_step, step3]
pipeline.update(role_arn=role_arn)
assert len(json.loads(pipeline.definition())["Steps"]) > 0
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
)
Expand Down Expand Up @@ -132,7 +154,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
pipeline = Pipeline(
name="MyPipeline",
parameters=[parameter],
steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000,
steps=_generate_large_pipeline_steps(parameter),
sagemaker_session=sagemaker_session_mock,
)

Expand Down Expand Up @@ -437,3 +459,10 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
PipelineExecutionArn="my:arn"
)
assert len(steps) == 1


def _generate_large_pipeline_steps(input_data: object):
steps = []
for i in range(2000):
steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data))
return steps
26 changes: 13 additions & 13 deletions tests/unit/sagemaker/workflow/test_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def role_arn():
def test_pipeline_duplicate_step_name(sagemaker_session_mock):
step1 = CustomStep(name="foo")
step2 = CustomStep(name="foo")
pipeline = Pipeline(
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
)
with pytest.raises(ValueError) as error:
pipeline = Pipeline(
name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock
)
PipelineGraph.from_pipeline(pipeline)
assert "Pipeline steps cannot have duplicate names." in str(error.value)

Expand All @@ -61,25 +61,25 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock):
condition_step = ConditionStep(
name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2]
)
pipeline = Pipeline(
name="MyPipeline",
steps=[custom_step, condition_step],
sagemaker_session=sagemaker_session_mock,
)
with pytest.raises(ValueError) as error:
pipeline = Pipeline(
name="MyPipeline",
steps=[custom_step, condition_step],
sagemaker_session=sagemaker_session_mock,
)
PipelineGraph.from_pipeline(pipeline)
assert "Pipeline steps cannot have duplicate names." in str(error.value)


def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock):
custom_step = CustomStep(name="foo-1")
custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step])
pipeline = Pipeline(
name="MyPipeline",
steps=[custom_step, custom_step_collection],
sagemaker_session=sagemaker_session_mock,
)
with pytest.raises(ValueError) as error:
pipeline = Pipeline(
name="MyPipeline",
steps=[custom_step, custom_step_collection],
sagemaker_session=sagemaker_session_mock,
)
PipelineGraph.from_pipeline(pipeline)
assert "Pipeline steps cannot have duplicate names." in str(error.value)

Expand Down
Loading