diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index f560945752..275d952f81 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -37,48 +37,58 @@ from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig from sagemaker.workflow.parallelism_config import ParallelismConfiguration from sagemaker.workflow.properties import Properties -from sagemaker.workflow.steps import Step +from sagemaker.workflow.steps import Step, StepTypeEnum from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.utilities import list_to_request +_DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig( + ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID +) + -@attr.s class Pipeline(Entity): - """Pipeline for workflow. + """Pipeline for workflow.""" - Attributes: - name (str): The name of the pipeline. - parameters (Sequence[Parameter]): The list of the parameters. - pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set, - the workflow will attempt to create an experiment and trial before - executing the steps. Creation will be skipped if an experiment or a trial with - the same name already exists. By default, pipeline name is used as - experiment name and execution id is used as the trial name. - If set to None, no experiment or trial will be created automatically. - steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps - associated with the pipeline. Any steps that are within the - `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a - pipeline. Of particular note, the workflow service rejects any pipeline definitions that - specify a step in the list of steps of a pipeline and that step in the `if_steps` or - `else_steps` of any `ConditionStep`. - sagemaker_session (sagemaker.session.Session): Session object that manages interactions - with Amazon SageMaker APIs and any other AWS services needed. If not specified, the - pipeline creates one using the default AWS configuration chain. - """ + def __init__( + self, + name: str = "", + parameters: Optional[Sequence[Parameter]] = None, + pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG, + steps: Optional[Sequence[Union[Step, StepCollection]]] = None, + sagemaker_session: Optional[Session] = None, + ): + """Initialize a Pipeline - name: str = attr.ib(factory=str) - parameters: Sequence[Parameter] = attr.ib(factory=list) - pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib( - default=PipelineExperimentConfig( - ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID - ) - ) - steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list) - sagemaker_session: Session = attr.ib(factory=Session) + Args: + name (str): The name of the pipeline. + parameters (Sequence[Parameter]): The list of the parameters. + pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set, + the workflow will attempt to create an experiment and trial before + executing the steps. Creation will be skipped if an experiment or a trial with + the same name already exists. By default, pipeline name is used as + experiment name and execution id is used as the trial name. + If set to None, no experiment or trial will be created automatically. + steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps + associated with the pipeline. Any steps that are within the + `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a + pipeline. Of particular note, the workflow service rejects any pipeline definitions + that specify a step in the list of steps of a pipeline and that step in the + `if_steps` or `else_steps` of any `ConditionStep`. + sagemaker_session (sagemaker.session.Session): Session object that manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + pipeline creates one using the default AWS configuration chain. + """ + self.name = name + self.parameters = parameters if parameters else [] + self.pipeline_experiment_config = pipeline_experiment_config + self.steps = steps if steps else [] + self.sagemaker_session = sagemaker_session if sagemaker_session else Session() - _version: str = "2020-12-01" - _metadata: Dict[str, Any] = dict() + self._version = "2020-12-01" + self._metadata = dict() + self._step_map = dict() + _generate_step_map(self.steps, self._step_map) def to_request(self) -> RequestType: """Gets the request structure for workflow service calls.""" @@ -193,6 +203,8 @@ def update( Returns: A response dict from the service. """ + self._step_map = dict() + _generate_step_map(self.steps, self._step_map) kwargs = self._create_args(role_arn, description, parallelism_config) return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) @@ -305,23 +317,27 @@ def definition(self) -> str: return json.dumps(request_dict) - def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict): + def _interpolate_step_collection_name_in_depends_on(self, step_requests: list): """Insert step names as per `StepCollection` name in depends_on list Args: - step_requests (dict): The raw step request dict without any interpolation. + step_requests (list): The list of raw step request dicts without any interpolation. """ - step_name_map = {s.name: s for s in self.steps} for step_request in step_requests: - if not step_request.get("DependsOn", None): - continue depends_on = [] - for depend_step_name in step_request["DependsOn"]: - if isinstance(step_name_map[depend_step_name], StepCollection): - depends_on.extend([s.name for s in step_name_map[depend_step_name].steps]) + for depend_step_name in step_request.get("DependsOn", []): + if isinstance(self._step_map[depend_step_name], StepCollection): + depends_on.extend([s.name for s in self._step_map[depend_step_name].steps]) else: depends_on.append(depend_step_name) - step_request["DependsOn"] = depends_on + if depends_on: + step_request["DependsOn"] = depends_on + + if step_request["Type"] == StepTypeEnum.CONDITION.value: + sub_step_requests = ( + step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"] + ) + self._interpolate_step_collection_name_in_depends_on(sub_step_requests) def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]: @@ -448,6 +464,20 @@ def update_args(args: Dict[str, Any], **kwargs): args.update({key: value}) +def _generate_step_map( + steps: Sequence[Union[Step, StepCollection]], step_map: dict +) -> Dict[str, Any]: + """Helper method to create a mapping from Step/Step Collection name to itself.""" + for step in steps: + if step.name in step_map: + raise ValueError("Pipeline steps cannot have duplicate names.") + step_map[step.name] = step + if isinstance(step, ConditionStep): + _generate_step_map(step.if_steps + step.else_steps, step_map) + if isinstance(step, StepCollection): + _generate_step_map(step.steps, step_map) + + @attr.s class _PipelineExecution: """Internal class for encapsulating pipeline execution instances. @@ -547,22 +577,11 @@ class PipelineGraph: def __init__(self, steps: Sequence[Union[Step, StepCollection]]): self.step_map = {} - self._generate_step_map(steps) + _generate_step_map(steps, self.step_map) self.adjacency_list = self._initialize_adjacency_list() if self.is_cyclic(): raise ValueError("Cycle detected in pipeline step graph.") - def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]): - """Helper method to create a mapping from Step/Step Collection name to itself.""" - for step in steps: - if step.name in self.step_map: - raise ValueError("Pipeline steps cannot have duplicate names.") - self.step_map[step.name] = step - if isinstance(step, ConditionStep): - self._generate_step_map(step.if_steps + step.else_steps) - if isinstance(step, StepCollection): - self._generate_step_map(step.steps) - @classmethod def from_pipeline(cls, pipeline: Pipeline): """Create a PipelineGraph object from the Pipeline object.""" diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index a9e9474013..5cd94dd76a 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -20,6 +20,8 @@ from mock import Mock from sagemaker import s3 +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline, PipelineGraph @@ -28,6 +30,7 @@ PipelineExperimentConfig, PipelineExperimentConfigProperties, ) +from sagemaker.workflow.step_collections import StepCollection from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep @@ -78,7 +81,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + steps=_generate_large_pipeline_steps(parameter), sagemaker_session=sagemaker_session_mock, ) @@ -105,6 +108,25 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): sagemaker_session=sagemaker_session_mock, ) pipeline.update(role_arn=role_arn) + assert len(json.loads(pipeline.definition())["Steps"]) == 0 + assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn + ) + + step1 = CustomStep(name="MyStep1") + step2 = CustomStep(name="MyStep2", input_data=step1.properties) + step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2]) + cond_step = ConditionStep( + name="MyConditionStep", + depends_on=[], + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[step_collection], + else_steps=[], + ) + step3 = CustomStep(name="MyStep3", depends_on=[step_collection]) + pipeline.steps = [cond_step, step3] + pipeline.update(role_arn=role_arn) + assert len(json.loads(pipeline.definition())["Steps"]) > 0 assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -132,7 +154,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + steps=_generate_large_pipeline_steps(parameter), sagemaker_session=sagemaker_session_mock, ) @@ -437,3 +459,10 @@ def test_pipeline_execution_basics(sagemaker_session_mock): PipelineExecutionArn="my:arn" ) assert len(steps) == 1 + + +def _generate_large_pipeline_steps(input_data: object): + steps = [] + for i in range(2000): + steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data)) + return steps diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py index b7d69e617a..003dd8d048 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_graph.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -45,10 +45,10 @@ def role_arn(): def test_pipeline_duplicate_step_name(sagemaker_session_mock): step1 = CustomStep(name="foo") step2 = CustomStep(name="foo") - pipeline = Pipeline( - name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." in str(error.value) @@ -61,12 +61,12 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): condition_step = ConditionStep( name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2] ) - pipeline = Pipeline( - name="MyPipeline", - steps=[custom_step, condition_step], - sagemaker_session=sagemaker_session_mock, - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, condition_step], + sagemaker_session=sagemaker_session_mock, + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." in str(error.value) @@ -74,12 +74,12 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock): custom_step = CustomStep(name="foo-1") custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step]) - pipeline = Pipeline( - name="MyPipeline", - steps=[custom_step, custom_step_collection], - sagemaker_session=sagemaker_session_mock, - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, custom_step_collection], + sagemaker_session=sagemaker_session_mock, + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." in str(error.value) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index d3b2a19fe3..d3d1ab022b 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -20,6 +20,8 @@ import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.model_step import ( ModelStep, _CREATE_MODEL_NAME_BASE, @@ -360,6 +362,154 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): ) +def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session): + custom_step1 = CustomStep(name="MyStep1") + + # Define a step collection which will be inserted into the ConditionStep + model_name = "MyModel" + model = Model( + name=model_name, + image_uri=IMAGE_URI, + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=ROLE, + ) + step_args = model.create( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + model_step_name = "MyModelStep" + model_step = ModelStep( + name=model_step_name, + step_args=step_args, + ) + + # Define another step collection which will be inserted into the ConditionStep + # This StepCollection object depends on a StepCollection object in the ConditionStep + # And a normal step outside ConditionStep + model.sagemaker_session = sagemaker_session + register_model_name = "RegisterModelStep" + register_model = RegisterModel( + name=register_model_name, + model=model, + model_data="s3://", + content_types=["content_type"], + response_types=["response_type"], + inference_instances=["inference_instance"], + transform_instances=["transform_instance"], + model_package_group_name="mpg", + depends_on=["MyStep1", model_step], + ) + + # StepCollection objects are depended on by a normal step in the ConditionStep + custom_step2 = CustomStep( + name="MyStep2", depends_on=["MyStep1", model_step, register_model_name] + ) + # StepCollection objects are depended on by a normal step outside the ConditionStep + custom_step3 = CustomStep( + name="MyStep3", depends_on=[custom_step1, model_step_name, register_model] + ) + + cond_step = ConditionStep( + name="CondStep", + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[], + else_steps=[model_step, register_model, custom_step2], + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step, custom_step1, custom_step3], + ) + step_list = json.loads(pipeline.definition())["Steps"] + assert len(step_list) == 3 + for step in step_list: + if step["Name"] == "MyStep1": + assert "DependsOn" not in step + elif step["Name"] == "CondStep": + assert not step["Arguments"]["IfSteps"] + for sub_step in step["Arguments"]["ElseSteps"]: + if sub_step["Name"] == f"{model_name}-RepackModel": + assert set(sub_step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + } + if sub_step["Name"] == "MyStep2": + assert set(sub_step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", + } + else: + assert set(step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", + } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "CondStep": ["MyModel-RepackModel", "MyModelStep-RepackModel-MyModel", "MyStep2"], + "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyStep2": [], + "MyStep3": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyModel-RepackModel": [], + "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"], + } + ) + + +def test_condition_step_depends_on_step_collection(): + step1 = CustomStep(name="MyStep1") + step2 = CustomStep(name="MyStep2", input_data=step1.properties) + step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2]) + cond_step = ConditionStep( + name="MyConditionStep", + depends_on=[step_collection], + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[], + else_steps=[], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step_collection, cond_step], + ) + step_list = json.loads(pipeline.definition())["Steps"] + assert len(step_list) == 3 + for step in step_list: + if step["Name"] != "MyConditionStep": + continue + assert step == { + "Name": "MyConditionStep", + "Type": "Condition", + "DependsOn": ["MyStep1", "MyStep2"], + "Arguments": { + "Conditions": [ + { + "Type": "Equals", + "LeftValue": 2, + "RightValue": 1, + }, + ], + "IfSteps": [], + "ElseSteps": [], + }, + } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + [("MyConditionStep", []), ("MyStep1", ["MyStep2"]), ("MyStep2", ["MyConditionStep"])] + ) + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel(