Skip to content

Commit f9560fc

Browse files
nmadan (Namrata Madan)
authored and committed
fix: pipelines local mode minor bug fixes
Co-authored-by: Namrata Madan <[email protected]>
1 parent 6e8939c commit f9560fc

File tree

7 files changed

+83
-75
lines changed

7 files changed

+83
-75
lines changed

src/sagemaker/local/entities.py

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -669,10 +669,8 @@ def start(self, **kwargs):
669669
execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs)
670670

671671
self._executions[execution_id] = execution
672-
logger.info(
673-
"Starting execution for pipeline %s. Execution ID is %s",
674-
self.pipeline.name,
675-
execution_id,
672+
print(
673+
f"Starting execution for pipeline {self.pipeline.name}. Execution ID is {execution_id}"
676674
)
677675
self.last_modified_time = datetime.datetime.now().timestamp()
678676

@@ -690,6 +688,8 @@ def __init__(
690688
PipelineExecutionDescription=None,
691689
PipelineExecutionDisplayName=None,
692690
):
691+
from sagemaker.workflow.pipeline import PipelineGraph
692+
693693
self.pipeline = pipeline
694694
self.pipeline_execution_name = execution_id
695695
self.pipeline_execution_description = PipelineExecutionDescription
@@ -699,7 +699,8 @@ def __init__(
699699
self.creation_time = datetime.datetime.now().timestamp()
700700
self.last_modified_time = self.creation_time
701701
self.step_execution = {}
702-
self._initialize_step_execution(self.pipeline.steps)
702+
self.pipeline_dag = PipelineGraph.from_pipeline(self.pipeline)
703+
self._initialize_step_execution(self.pipeline_dag.step_map.values())
703704
self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters)
704705
self._blocked_steps = {}
705706

@@ -732,37 +733,36 @@ def update_execution_success(self):
732733
"""Mark execution as succeeded."""
733734
self.status = _LocalExecutionStatus.SUCCEEDED.value
734735
self.last_modified_time = datetime.datetime.now().timestamp()
735-
logger.info("Pipeline execution %s SUCCEEDED", self.pipeline_execution_name)
736+
print(f"Pipeline execution {self.pipeline_execution_name} SUCCEEDED")
736737

737738
def update_execution_failure(self, step_name, failure_message):
738739
"""Mark execution as failed."""
739740
self.status = _LocalExecutionStatus.FAILED.value
740-
self.failure_reason = f"Step {step_name} failed with message: {failure_message}"
741+
self.failure_reason = f"Step '{step_name}' failed with message: {failure_message}"
741742
self.last_modified_time = datetime.datetime.now().timestamp()
742-
logger.info(
743-
"Pipeline execution %s FAILED because step %s failed.",
744-
self.pipeline_execution_name,
745-
step_name,
743+
print(
744+
f"Pipeline execution {self.pipeline_execution_name} FAILED because step "
745+
f"'{step_name}' failed."
746746
)
747747

748748
def update_step_properties(self, step_name, step_properties):
749749
"""Update pipeline step execution output properties."""
750750
self.step_execution.get(step_name).update_step_properties(step_properties)
751-
logger.info("Pipeline step %s SUCCEEDED.", step_name)
751+
print(f"Pipeline step '{step_name}' SUCCEEDED.")
752752

753753
def update_step_failure(self, step_name, failure_message):
754754
"""Mark step_name as failed."""
755+
print(f"Pipeline step '{step_name}' FAILED. Failure message is: {failure_message}")
755756
self.step_execution.get(step_name).update_step_failure(failure_message)
756-
logger.info("Pipeline step %s FAILED. Failure message is: %s", step_name, failure_message)
757757

758758
def mark_step_executing(self, step_name):
759759
"""Update pipelines step's status to EXECUTING and start_time to now."""
760-
logger.info("Starting pipeline step: %s", step_name)
760+
print(f"Starting pipeline step: '{step_name}'")
761761
self.step_execution.get(step_name).mark_step_executing()
762762

763763
def _initialize_step_execution(self, steps):
764764
"""Initialize step_execution dict."""
765-
from sagemaker.workflow.steps import StepTypeEnum
765+
from sagemaker.workflow.steps import StepTypeEnum, Step
766766

767767
supported_steps_types = (
768768
StepTypeEnum.TRAINING,
@@ -774,16 +774,17 @@ def _initialize_step_execution(self, steps):
774774
)
775775

776776
for step in steps:
777-
if step.step_type not in supported_steps_types:
778-
error_msg = self._construct_validation_exception_message(
779-
"Step type {} is not supported in local mode.".format(step.step_type.value)
777+
if isinstance(step, Step):
778+
if step.step_type not in supported_steps_types:
779+
error_msg = self._construct_validation_exception_message(
780+
"Step type {} is not supported in local mode.".format(step.step_type.value)
781+
)
782+
raise ClientError(error_msg, "start_pipeline_execution")
783+
self.step_execution[step.name] = _LocalPipelineExecutionStep(
784+
step.name, step.step_type, step.description, step.display_name
780785
)
781-
raise ClientError(error_msg, "start_pipeline_execution")
782-
self.step_execution[step.name] = _LocalPipelineExecutionStep(
783-
step.name, step.step_type, step.description, step.display_name
784-
)
785-
if step.step_type == StepTypeEnum.CONDITION:
786-
self._initialize_step_execution(step.if_steps + step.else_steps)
786+
if step.step_type == StepTypeEnum.CONDITION:
787+
self._initialize_step_execution(step.if_steps + step.else_steps)
787788

788789
def _initialize_and_validate_parameters(self, overridden_parameters):
789790
"""Initialize and validate pipeline parameters."""

src/sagemaker/local/local_session.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -448,6 +448,7 @@ def update_pipeline(
448448
}
449449
raise ClientError(error_response, "update_pipeline")
450450
LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description
451+
LocalSagemakerClient._pipelines[pipeline.name].pipeline = pipeline
451452
LocalSagemakerClient._pipelines[
452453
pipeline.name
453454
].last_modified_time = datetime.now().timestamp()

src/sagemaker/local/pipeline.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,12 @@
1717
import json
1818
from copy import deepcopy
1919
from datetime import datetime
20-
from typing import Dict, List
20+
from typing import Dict, List, Union
2121
from botocore.exceptions import ClientError
2222

2323
from sagemaker.workflow.conditions import ConditionTypeEnum
2424
from sagemaker.workflow.steps import StepTypeEnum, Step
25+
from sagemaker.workflow.step_collections import StepCollection
2526
from sagemaker.workflow.entities import PipelineVariable
2627
from sagemaker.workflow.parameters import Parameter
2728
from sagemaker.workflow.functions import Join, JsonGet, PropertyFile
@@ -256,8 +257,7 @@ def execute(self):
256257
return self.pipline_executor.local_sagemaker_client.describe_training_job(job_name)
257258
except Exception as e: # pylint: disable=W0703
258259
self.pipline_executor.execution.update_step_failure(
259-
self.step.name,
260-
f"Error when executing step {self.step.name} of type {type(self.step)}: {e}",
260+
self.step.name, f"{type(e).__name__}: {str(e)}"
261261
)
262262

263263

@@ -291,21 +291,18 @@ def execute(self):
291291

292292
except Exception as e: # pylint: disable=W0703
293293
self.pipline_executor.execution.update_step_failure(
294-
self.step.name,
295-
f"Error when executing step {self.step.name} of type {type(self.step)}: {e}",
294+
self.step.name, f"{type(e).__name__}: {str(e)}"
296295
)
297296

298297

299298
class _ConditionStepExecutor(_StepExecutor):
300299
"""Executor class to execute ConditionStep locally"""
301300

302301
def execute(self):
303-
def _block_all_downstream_steps(steps: List[Step]):
302+
def _block_all_downstream_steps(steps: List[Union[Step, StepCollection]]):
304303
steps_to_block = set()
305304
for step in steps:
306-
steps_to_block.update(
307-
self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step.name)
308-
)
305+
steps_to_block.update(self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step))
309306
self.pipline_executor._blocked_steps.update(steps_to_block)
310307

311308
if_steps = self.step.if_steps
@@ -469,8 +466,7 @@ def execute(self):
469466
return self.pipline_executor.local_sagemaker_client.describe_transform_job(job_name)
470467
except Exception as e: # pylint: disable=W0703
471468
self.pipline_executor.execution.update_step_failure(
472-
self.step.name,
473-
f"Error when executing step {self.step.name} of type {type(self.step)}: {e}",
469+
self.step.name, f"{type(e).__name__}: {str(e)}"
474470
)
475471

476472

@@ -485,8 +481,7 @@ def execute(self):
485481
return self.pipline_executor.local_sagemaker_client.describe_model(model_name)
486482
except Exception as e: # pylint: disable=W0703
487483
self.pipline_executor.execution.update_step_failure(
488-
self.step.name,
489-
f"Error when executing step {self.step.name} of type {type(self.step)}: {e}",
484+
self.step.name, f"{type(e).__name__}: {str(e)}"
490485
)
491486

492487

src/sagemaker/workflow/pipeline.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
from sagemaker import s3
2727
from sagemaker._studio import _append_project_tags
2828
from sagemaker.session import Session
29-
from sagemaker.local import LocalSession
3029
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
3130
from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
3231
from sagemaker.workflow.entities import (
@@ -162,9 +161,7 @@ def _create_args(
162161

163162
# If pipeline definition is large, upload to S3 bucket and
164163
# provide PipelineDefinitionS3Location to request instead.
165-
if len(pipeline_definition.encode("utf-8")) < 1024 * 100 or isinstance(
166-
self.sagemaker_session, LocalSession
167-
):
164+
if len(pipeline_definition.encode("utf-8")) < 1024 * 100:
168165
kwargs["PipelineDefinition"] = pipeline_definition
169166
else:
170167
desired_s3_uri = s3.s3_path_join(
@@ -660,19 +657,28 @@ def is_cyclic_helper(current_step):
660657
return True
661658
return False
662659

663-
def get_steps_in_sub_dag(self, current_step: str, steps: Set[str] = None) -> Set[str]:
660+
def get_steps_in_sub_dag(
661+
self, current_step: Union[Step, StepCollection], sub_dag_steps: Set[str] = None
662+
) -> Set[str]:
664663
"""Get names of all steps (including current step) in the sub dag of current step.
665664
666665
Returns a set of step names in the sub dag.
667666
"""
668-
if steps is None:
669-
steps = set()
670-
if current_step not in self.adjacency_list:
671-
raise ValueError("Step: %s does not exist in the pipeline." % current_step)
672-
steps.add(current_step)
673-
for step in self.adjacency_list[current_step]:
674-
self.get_steps_in_sub_dag(step, steps)
675-
return steps
667+
if sub_dag_steps is None:
668+
sub_dag_steps = set()
669+
670+
if isinstance(current_step, StepCollection):
671+
current_steps = current_step.steps
672+
else:
673+
current_steps = [current_step]
674+
675+
for step in current_steps:
676+
if step.name not in self.adjacency_list:
677+
raise ValueError("Step: %s does not exist in the pipeline." % step.name)
678+
sub_dag_steps.add(step.name)
679+
for sub_step in self.adjacency_list[step.name]:
680+
self.get_steps_in_sub_dag(self.step_map.get(sub_step), sub_dag_steps)
681+
return sub_dag_steps
676682

677683
def __iter__(self):
678684
"""Perform topological sort traversal of the Pipeline Graph."""

tests/integ/test_local_mode.py

Lines changed: 7 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -30,11 +30,11 @@
3030

3131
from sagemaker.model import Model
3232
from sagemaker.transformer import Transformer
33-
from sagemaker.inputs import CreateModelInput
3433
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
3534
from sagemaker.sklearn.processing import SKLearnProcessor
3635
from sagemaker.workflow.pipeline import Pipeline
37-
from sagemaker.workflow.steps import TrainingStep, ProcessingStep, TransformStep, CreateModelStep
36+
from sagemaker.workflow.steps import TrainingStep, ProcessingStep, TransformStep
37+
from sagemaker.workflow.model_step import ModelStep
3838
from sagemaker.workflow.parameters import ParameterInteger
3939
from sagemaker.workflow.condition_step import ConditionStep
4040
from sagemaker.workflow.fail_step import FailStep
@@ -546,8 +546,8 @@ def test_local_pipeline_with_training_and_transform_steps(
546546
mxnet_training_latest_py_version,
547547
tmpdir,
548548
):
549-
instance_count = ParameterInteger(name="InstanceCountParam")
550549
session = LocalPipelineNoS3Session()
550+
instance_count = ParameterInteger(name="InstanceCountParam")
551551
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
552552
script_path = os.path.join(data_path, "check_env.py")
553553
output_path = "file://%s" % (str(tmpdir))
@@ -587,19 +587,12 @@ def test_local_pipeline_with_training_and_transform_steps(
587587
)
588588

589589
# define create model step
590-
inputs = CreateModelInput(
591-
instance_type="local",
592-
accelerator_type="local",
593-
)
594-
create_model_step = CreateModelStep(
595-
name="mxnet_mnist_model",
596-
model=model,
597-
inputs=inputs,
598-
)
590+
model_step_args = model.create(instance_type="local", accelerator_type="local")
591+
model_step = ModelStep(name="mxnet_mnist_model", step_args=model_step_args)
599592

600593
# define transformer
601594
transformer = Transformer(
602-
model_name=create_model_step.properties.ModelName,
595+
model_name=model_step.properties.ModelName,
603596
instance_type="local",
604597
instance_count=instance_count,
605598
output_path=output_path,
@@ -619,7 +612,7 @@ def test_local_pipeline_with_training_and_transform_steps(
619612
pipeline = Pipeline(
620613
name="local_pipeline_training_transform",
621614
parameters=[instance_count],
622-
steps=[training_step, create_model_step, transform_step],
615+
steps=[training_step, model_step, transform_step],
623616
sagemaker_session=session,
624617
)
625618

tests/unit/sagemaker/local/test_local_session.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -887,16 +887,25 @@ def test_create_describe_update_pipeline():
887887
steps=[CustomStep(name="MyStep", input_data=parameter)],
888888
sagemaker_session=LocalSession(),
889889
)
890+
definition = pipeline.definition()
890891
pipeline.create("dummy-role", "pipeline-description")
891892

892893
pipeline_describe_response1 = pipeline.describe()
893894
assert pipeline_describe_response1["PipelineArn"] == "MyPipeline"
894-
assert pipeline_describe_response1["PipelineDefinition"] == pipeline.definition()
895+
assert pipeline_describe_response1["PipelineDefinition"] == definition
895896
assert pipeline_describe_response1["PipelineDescription"] == "pipeline-description"
896897

898+
pipeline = Pipeline(
899+
name="MyPipeline",
900+
parameters=[parameter],
901+
steps=[CustomStep(name="MyStepUpdated", input_data=parameter)],
902+
sagemaker_session=LocalSession(),
903+
)
904+
updated_definition = pipeline.definition()
897905
pipeline.update("dummy-role", "pipeline-description-2")
898906
pipeline_describe_response2 = pipeline.describe()
899907
assert pipeline_describe_response2["PipelineDescription"] == "pipeline-description-2"
908+
assert pipeline_describe_response2["PipelineDefinition"] == updated_definition
900909
assert (
901910
pipeline_describe_response2["CreationTime"]
902911
!= pipeline_describe_response2["LastModifiedTime"]

tests/unit/sagemaker/workflow/test_pipeline_graph.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def pipeline_graph_get_sub_dag(sagemaker_session_mock):
6161
step_f = CustomStep(name="stepF", depends_on=[step_c])
6262
step_g = CustomStep(name="stepG", depends_on=[step_e, step_d])
6363
step_h = CustomStep(name="stepH", depends_on=[step_g])
64-
step_i = CustomStep(name="stepI", depends_on=[step_h])
64+
step_i = CustomStepCollection(name="stepI", depends_on=[step_h])
6565
step_j = CustomStep(name="stepJ", depends_on=[step_h])
6666

6767
pipeline = Pipeline(
@@ -312,7 +312,8 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock):
312312
"stepF",
313313
"stepG",
314314
"stepH",
315-
"stepI",
315+
"stepI-0",
316+
"stepI-1",
316317
"stepJ",
317318
},
318319
),
@@ -326,22 +327,24 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock):
326327
"stepF",
327328
"stepG",
328329
"stepH",
329-
"stepI",
330+
"stepI-0",
331+
"stepI-1",
330332
"stepJ",
331333
},
332334
),
333-
("stepC", {"stepC", "stepE", "stepF", "stepG", "stepH", "stepI", "stepJ"}),
334-
("stepD", {"stepD", "stepG", "stepH", "stepI", "stepJ"}),
335-
("stepE", {"stepE", "stepG", "stepH", "stepI", "stepJ"}),
335+
("stepC", {"stepC", "stepE", "stepF", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}),
336+
("stepD", {"stepD", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}),
337+
("stepE", {"stepE", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}),
336338
("stepF", {"stepF"}),
337-
("stepG", {"stepG", "stepH", "stepI", "stepJ"}),
338-
("stepH", {"stepH", "stepI", "stepJ"}),
339-
("stepI", {"stepI"}),
339+
("stepG", {"stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}),
340+
("stepH", {"stepH", "stepI-0", "stepI-1", "stepJ"}),
341+
("stepI", {"stepI-0", "stepI-1"}),
340342
("stepJ", {"stepJ"}),
341343
],
342344
)
343345
def test_get_steps_in_sub_dag(pipeline_graph_get_sub_dag, step_name, expected_steps):
344-
sub_steps = pipeline_graph_get_sub_dag.get_steps_in_sub_dag(step_name)
346+
step = pipeline_graph_get_sub_dag.step_map.get(step_name)
347+
sub_steps = pipeline_graph_get_sub_dag.get_steps_in_sub_dag(step)
345348
assert sub_steps == expected_steps
346349

347350

0 commit comments

Comments
 (0)