Skip to content

Commit 4839eef

Browse files
shenlongtang, stacicho, and zuoyuanh
authored
feat: Selective Step Execution milestone 2 features (#4158)
* feature: method to build pipeline parameters from existing execution … (#951)
* feature: method to build pipeline parameters from existing execution with optional value overrides
* fix style check
* assert error message in unit test
* feature: allow opt out from referencing latest execution in the selec… (#1004)
* fix: Update pipeline.py and selective_execution_config.py with small fixes (#1099)
---------
Co-authored-by: stacicho <[email protected]>
Co-authored-by: Zuoyuan Huang <[email protected]>
1 parent f646180 commit 4839eef

File tree

3 files changed

+230
-5
lines changed

3 files changed

+230
-5
lines changed

src/sagemaker/workflow/pipeline.py

+100-3
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,10 @@ def start(
344344
A `_PipelineExecution` instance, if successful.
345345
"""
346346
if selective_execution_config is not None:
347-
if selective_execution_config.source_pipeline_execution_arn is None:
347+
if (
348+
selective_execution_config.source_pipeline_execution_arn is None
349+
and selective_execution_config.reference_latest_execution
350+
):
348351
selective_execution_config.source_pipeline_execution_arn = (
349352
self._get_latest_execution_arn()
350353
)
@@ -425,8 +428,8 @@ def list_executions(
425428
sort_by (str): The field by which to sort results(CreationTime/PipelineExecutionArn).
426429
sort_order (str): The sort order for results (Ascending/Descending).
427430
max_results (int): The maximum number of pipeline executions to return in the response.
428-
next_token (str): If the result of the previous ListPipelineExecutions request was
429-
truncated, the response includes a NextToken. To retrieve the next set of pipeline
431+
next_token (str): If the result of the previous `ListPipelineExecutions` request was
432+
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
430433
executions, use the token in the next request.
431434
432435
Returns:
@@ -463,6 +466,76 @@ def _get_latest_execution_arn(self):
463466
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
464467
return None
465468

469+
def build_parameters_from_execution(
470+
self,
471+
pipeline_execution_arn: str,
472+
parameter_value_overrides: Dict[str, Union[str, bool, int, float]] = None,
473+
) -> Dict[str, Union[str, bool, int, float]]:
474+
"""Gets the parameters from an execution, updated with optional parameter value overrides.
475+
476+
Args:
477+
pipeline_execution_arn (str): The arn of the reference pipeline execution.
478+
parameter_value_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict
479+
whose values override the parameters from the referenced execution.
480+
481+
Returns:
482+
A parameter dict built from an execution and provided parameter value overrides.
483+
"""
484+
execution_parameters = self._get_parameters_for_execution(pipeline_execution_arn)
485+
if parameter_value_overrides is not None:
486+
self._validate_parameter_overrides(
487+
pipeline_execution_arn, execution_parameters, parameter_value_overrides
488+
)
489+
execution_parameters.update(parameter_value_overrides)
490+
return execution_parameters
491+
492+
def _get_parameters_for_execution(self, pipeline_execution_arn: str) -> Dict[str, str]:
493+
"""Gets all the parameters from an execution.
494+
495+
Args:
496+
pipeline_execution_arn (str): The arn of the pipeline execution.
497+
498+
Returns:
499+
A parameter dict from the execution.
500+
"""
501+
pipeline_execution = _PipelineExecution(
502+
arn=pipeline_execution_arn,
503+
sagemaker_session=self.sagemaker_session,
504+
)
505+
506+
response = pipeline_execution.list_parameters()
507+
parameter_list = response["PipelineParameters"]
508+
while response.get("NextToken") is not None:
509+
response = pipeline_execution.list_parameters(next_token=response["NextToken"])
510+
parameter_list.extend(response["PipelineParameters"])
511+
512+
return {parameter["Name"]: parameter["Value"] for parameter in parameter_list}
513+
514+
@staticmethod
515+
def _validate_parameter_overrides(
516+
pipeline_execution_arn: str,
517+
execution_parameters: Dict[str, str],
518+
parameter_overrides: Dict[str, Union[str, bool, int, float]],
519+
):
520+
"""Validates the parameter overrides are present in the execution parameters.
521+
522+
Args:
523+
pipeline_execution_arn (str): The arn of the pipeline execution.
524+
execution_parameters (Dict[str, str]): A parameter dict from the execution.
525+
parameter_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict to be
526+
updated in the parameters from the referenced execution.
527+
528+
Raises:
529+
ValueError: If any parameter in the parameter overrides is not present in the
530+
execution parameters.
531+
"""
532+
invalid_parameters = set(parameter_overrides) - set(execution_parameters)
533+
if invalid_parameters:
534+
raise ValueError(
535+
f"The following parameter overrides provided: {str(invalid_parameters)} "
536+
+ f"are not present in the pipeline execution: {pipeline_execution_arn}"
537+
)
538+
466539

467540
def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
468541
"""Formats start parameter overrides as a list of dicts.
@@ -652,6 +725,30 @@ def list_steps(self):
652725
)
653726
return response["PipelineExecutionSteps"]
654727

728+
def list_parameters(self, max_results: int = None, next_token: str = None):
729+
"""Gets a list of parameters for a pipeline execution.
730+
731+
Args:
732+
max_results (int): The maximum number of parameters to return in the response.
733+
next_token (str): If the result of the previous `ListPipelineParametersForExecution`
734+
request was truncated, the response includes a `NextToken`. To retrieve the next
735+
set of parameters, use the token in the next request.
736+
737+
Returns:
738+
Information about the parameters of the pipeline execution. This function is also
739+
a wrapper for `list_pipeline_parameters_for_execution
740+
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_parameters_for_execution>`_.
741+
"""
742+
kwargs = dict(PipelineExecutionArn=self.arn)
743+
update_args(
744+
kwargs,
745+
MaxResults=max_results,
746+
NextToken=next_token,
747+
)
748+
return self.sagemaker_session.sagemaker_client.list_pipeline_parameters_for_execution(
749+
**kwargs
750+
)
751+
655752
def wait(self, delay=30, max_attempts=60):
656753
"""Waits for a pipeline execution.
657754

src/sagemaker/workflow/selective_execution_config.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ class SelectiveExecutionConfig:
2222
another SageMaker pipeline run.
2323
"""
2424

25-
def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str = None):
25+
def __init__(
26+
self,
27+
selected_steps: List[str],
28+
source_pipeline_execution_arn: str = None,
29+
reference_latest_execution: bool = True,
30+
):
2631
"""Create a `SelectiveExecutionConfig`.
2732
2833
Args:
@@ -32,9 +37,12 @@ def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str
3237
`Succeeded`.
3338
selected_steps (List[str]): A list of pipeline steps to run. All step(s) in all
3439
path(s) between two selected steps should be included.
40+
reference_latest_execution (bool): Whether to reference the latest execution if
41+
`source_pipeline_execution_arn` is not provided.
3542
"""
3643
self.source_pipeline_execution_arn = source_pipeline_execution_arn
3744
self.selected_steps = selected_steps
45+
self.reference_latest_execution = reference_latest_execution
3846

3947
def _build_selected_steps_from_list(self) -> RequestType:
4048
"""Get the request structure for the list of selected steps."""

tests/unit/sagemaker/workflow/test_pipeline.py

+121-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919

20-
from mock import Mock, patch
20+
from mock import Mock, call, patch
2121

2222
from sagemaker import s3
2323
from sagemaker.session_settings import SessionSettings
@@ -492,8 +492,10 @@ def test_pipeline_start_selective_execution(sagemaker_session_mock):
492492
"SourcePipelineExecutionArn": "foo-arn",
493493
},
494494
)
495+
sagemaker_session_mock.reset_mock()
495496

496497
# Case 2: Start selective execution without SourcePipelineExecutionArn
498+
# References latest execution by default.
497499
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
498500
"PipelineExecutionSummaries": [
499501
{
@@ -523,6 +525,27 @@ def test_pipeline_start_selective_execution(sagemaker_session_mock):
523525
"SourcePipelineExecutionArn": "my:latest:execution:arn",
524526
},
525527
)
528+
sagemaker_session_mock.reset_mock()
529+
530+
# Case 3: Start selective execution without SourcePipelineExecutionArn
531+
# Opts not to reference latest execution.
532+
selective_execution_config = SelectiveExecutionConfig(
533+
selected_steps=["step-1", "step-2", "step-3"],
534+
reference_latest_execution=False,
535+
)
536+
pipeline.start(selective_execution_config=selective_execution_config)
537+
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.assert_not_called()
538+
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
539+
PipelineName="MyPipeline",
540+
SelectiveExecutionConfig={
541+
"SelectedSteps": [
542+
{"StepName": "step-1"},
543+
{"StepName": "step-2"},
544+
{"StepName": "step-3"},
545+
],
546+
},
547+
)
548+
sagemaker_session_mock.reset_mock()
526549

527550

528551
def test_pipeline_basic():
@@ -718,13 +741,99 @@ def test_pipeline_list_executions(sagemaker_session_mock):
718741
assert executions["NextToken"] == "token"
719742

720743

744+
def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
745+
pipeline = Pipeline(
746+
name="MyPipeline",
747+
sagemaker_session=sagemaker_session_mock,
748+
)
749+
reference_execution_arn = "reference_execution_arn"
750+
parameter_value_overrides = {"TestParameterName": "NewParameterValue"}
751+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
752+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
753+
}
754+
parameters = pipeline.build_parameters_from_execution(
755+
pipeline_execution_arn=reference_execution_arn,
756+
parameter_value_overrides=parameter_value_overrides,
757+
)
758+
assert (
759+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
760+
PipelineExecutionArn=reference_execution_arn
761+
)
762+
)
763+
assert len(parameters) == 1
764+
assert parameters["TestParameterName"] == "NewParameterValue"
765+
766+
767+
def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemaker_session_mock):
768+
pipeline = Pipeline(
769+
name="MyPipeline",
770+
sagemaker_session=sagemaker_session_mock,
771+
)
772+
reference_execution_arn = "reference_execution_arn"
773+
invalid_parameter_value_overrides = {"InvalidParameterName": "Value"}
774+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
775+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
776+
}
777+
with pytest.raises(ValueError) as error:
778+
pipeline.build_parameters_from_execution(
779+
pipeline_execution_arn=reference_execution_arn,
780+
parameter_value_overrides=invalid_parameter_value_overrides,
781+
)
782+
assert (
783+
f"The following parameter overrides provided: {str(set(invalid_parameter_value_overrides.keys()))} "
784+
+ f"are not present in the pipeline execution: {reference_execution_arn}"
785+
in str(error)
786+
)
787+
assert (
788+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
789+
PipelineExecutionArn=reference_execution_arn
790+
)
791+
)
792+
793+
794+
def test_pipeline_build_parameters_from_execution_with_paginated_result(sagemaker_session_mock):
795+
pipeline = Pipeline(
796+
name="MyPipeline",
797+
sagemaker_session=sagemaker_session_mock,
798+
)
799+
reference_execution_arn = "reference_execution_arn"
800+
next_token = "token"
801+
first_page_response = {
802+
"PipelineParameters": [{"Name": "TestParameterName1", "Value": "TestParameterValue1"}],
803+
"NextToken": next_token,
804+
}
805+
second_page_response = {
806+
"PipelineParameters": [{"Name": "TestParameterName2", "Value": "TestParameterValue2"}],
807+
}
808+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.side_effect = [
809+
first_page_response,
810+
second_page_response,
811+
]
812+
parameters = pipeline.build_parameters_from_execution(
813+
pipeline_execution_arn=reference_execution_arn
814+
)
815+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_has_calls(
816+
[
817+
call(PipelineExecutionArn=reference_execution_arn),
818+
call(PipelineExecutionArn=reference_execution_arn, NextToken=next_token),
819+
]
820+
)
821+
assert len(parameters) == 2
822+
assert parameters["TestParameterName1"] == "TestParameterValue1"
823+
assert parameters["TestParameterName2"] == "TestParameterValue2"
824+
825+
721826
def test_pipeline_execution_basics(sagemaker_session_mock):
722827
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
723828
"PipelineExecutionArn": "my:arn"
724829
}
725830
sagemaker_session_mock.sagemaker_client.list_pipeline_execution_steps.return_value = {
726831
"PipelineExecutionSteps": [Mock()]
727832
}
833+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
834+
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}],
835+
"NextToken": "token",
836+
}
728837
pipeline = Pipeline(
729838
name="MyPipeline",
730839
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
@@ -745,6 +854,17 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
745854
PipelineExecutionArn="my:arn"
746855
)
747856
assert len(steps) == 1
857+
list_parameters_response = execution.list_parameters()
858+
assert (
859+
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
860+
PipelineExecutionArn="my:arn"
861+
)
862+
)
863+
parameter_list = list_parameters_response["PipelineParameters"]
864+
assert len(parameter_list) == 1
865+
assert parameter_list[0]["Name"] == "TestParameterName"
866+
assert parameter_list[0]["Value"] == "TestParameterValue"
867+
assert list_parameters_response["NextToken"] == "token"
748868

749869

750870
def _generate_large_pipeline_steps(input_data: object):

0 commit comments

Comments
 (0)