Skip to content

feat: Selective Step Execution milestone 2 features #4158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 100 additions & 3 deletions src/sagemaker/workflow/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,10 @@ def start(
A `_PipelineExecution` instance, if successful.
"""
if selective_execution_config is not None:
if selective_execution_config.source_pipeline_execution_arn is None:
if (
selective_execution_config.source_pipeline_execution_arn is None
and selective_execution_config.reference_latest_execution
):
selective_execution_config.source_pipeline_execution_arn = (
self._get_latest_execution_arn()
)
Expand Down Expand Up @@ -425,8 +428,8 @@ def list_executions(
sort_by (str): The field by which to sort results(CreationTime/PipelineExecutionArn).
sort_order (str): The sort order for results (Ascending/Descending).
max_results (int): The maximum number of pipeline executions to return in the response.
next_token (str): If the result of the previous ListPipelineExecutions request was
truncated, the response includes a NextToken. To retrieve the next set of pipeline
next_token (str): If the result of the previous `ListPipelineExecutions` request was
truncated, the response includes a `NextToken`. To retrieve the next set of pipeline
executions, use the token in the next request.

Returns:
Expand Down Expand Up @@ -463,6 +466,76 @@ def _get_latest_execution_arn(self):
return response["PipelineExecutionSummaries"][0]["PipelineExecutionArn"]
return None

def build_parameters_from_execution(
self,
pipeline_execution_arn: str,
parameter_value_overrides: Dict[str, Union[str, bool, int, float]] = None,
) -> Dict[str, Union[str, bool, int, float]]:
"""Gets the parameters from an execution, update with optional parameter value overrides.

Args:
pipeline_execution_arn (str): The arn of the reference pipeline execution.
parameter_value_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict
to be updated with the parameters from the referenced execution.

Returns:
A parameter dict built from an execution and provided parameter value overrides.
"""
execution_parameters = self._get_parameters_for_execution(pipeline_execution_arn)
if parameter_value_overrides is not None:
self._validate_parameter_overrides(
pipeline_execution_arn, execution_parameters, parameter_value_overrides
)
execution_parameters.update(parameter_value_overrides)
return execution_parameters

def _get_parameters_for_execution(self, pipeline_execution_arn: str) -> Dict[str, str]:
"""Gets all the parameters from an execution.

Args:
pipeline_execution_arn (str): The arn of the pipeline execution.

Returns:
A parameter dict from the execution.
"""
pipeline_execution = _PipelineExecution(
arn=pipeline_execution_arn,
sagemaker_session=self.sagemaker_session,
)

response = pipeline_execution.list_parameters()
parameter_list = response["PipelineParameters"]
while response.get("NextToken") is not None:
response = pipeline_execution.list_parameters(next_token=response["NextToken"])
parameter_list.extend(response["PipelineParameters"])

return {parameter["Name"]: parameter["Value"] for parameter in parameter_list}

@staticmethod
def _validate_parameter_overrides(
pipeline_execution_arn: str,
execution_parameters: Dict[str, str],
parameter_overrides: Dict[str, Union[str, bool, int, float]],
):
"""Validates the parameter overrides are present in the execution parameters.

Args:
pipeline_execution_arn (str): The arn of the pipeline execution.
execution_parameters (Dict[str, str]): A parameter dict from the execution.
parameter_overrides (Dict[str, Union[str, bool, int, float]]): Parameter dict to be
updated in the parameters from the referenced execution.

Raises:
ValueError: If any parameters in parameter overrides is not present in the
execution parameters.
"""
invalid_parameters = set(parameter_overrides) - set(execution_parameters)
if invalid_parameters:
raise ValueError(
f"The following parameter overrides provided: {str(invalid_parameters)} "
+ f"are not present in the pipeline execution: {pipeline_execution_arn}"
)


def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Formats start parameter overrides as a list of dicts.
Expand Down Expand Up @@ -652,6 +725,30 @@ def list_steps(self):
)
return response["PipelineExecutionSteps"]

def list_parameters(self, max_results: int = None, next_token: str = None):
"""Gets a list of parameters for a pipeline execution.

Args:
max_results (int): The maximum number of parameters to return in the response.
next_token (str): If the result of the previous `ListPipelineParametersForExecution`
request was truncated, the response includes a `NextToken`. To retrieve the next
set of parameters, use the token in the next request.

Returns:
Information about the parameters of the pipeline execution. This function is also
a wrapper for `list_pipeline_parameters_for_execution
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_pipeline_parameters_for_execution>`_.
"""
kwargs = dict(PipelineExecutionArn=self.arn)
update_args(
kwargs,
MaxResults=max_results,
NextToken=next_token,
)
return self.sagemaker_session.sagemaker_client.list_pipeline_parameters_for_execution(
**kwargs
)

def wait(self, delay=30, max_attempts=60):
"""Waits for a pipeline execution.

Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/workflow/selective_execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ class SelectiveExecutionConfig:
another SageMaker pipeline run.
"""

def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str = None):
def __init__(
self,
selected_steps: List[str],
source_pipeline_execution_arn: str = None,
reference_latest_execution: bool = True,
):
"""Create a `SelectiveExecutionConfig`.

Args:
Expand All @@ -32,9 +37,12 @@ def __init__(self, selected_steps: List[str], source_pipeline_execution_arn: str
`Succeeded`.
selected_steps (List[str]): A list of pipeline steps to run. All step(s) in all
path(s) between two selected steps should be included.
reference_latest_execution (bool): Whether to reference the latest execution if
`source_pipeline_execution_arn` is not provided.
"""
self.source_pipeline_execution_arn = source_pipeline_execution_arn
self.selected_steps = selected_steps
self.reference_latest_execution = reference_latest_execution

def _build_selected_steps_from_list(self) -> RequestType:
"""Get the request structure for the list of selected steps."""
Expand Down
122 changes: 121 additions & 1 deletion tests/unit/sagemaker/workflow/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest

from mock import Mock, patch
from mock import Mock, call, patch

from sagemaker import s3
from sagemaker.session_settings import SessionSettings
Expand Down Expand Up @@ -492,8 +492,10 @@ def test_pipeline_start_selective_execution(sagemaker_session_mock):
"SourcePipelineExecutionArn": "foo-arn",
},
)
sagemaker_session_mock.reset_mock()

# Case 2: Start selective execution without SourcePipelineExecutionArn
# References latest execution by default.
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.return_value = {
"PipelineExecutionSummaries": [
{
Expand Down Expand Up @@ -523,6 +525,27 @@ def test_pipeline_start_selective_execution(sagemaker_session_mock):
"SourcePipelineExecutionArn": "my:latest:execution:arn",
},
)
sagemaker_session_mock.reset_mock()

# Case 3: Start selective execution without SourcePipelineExecutionArn
# Opts not to reference latest execution.
selective_execution_config = SelectiveExecutionConfig(
selected_steps=["step-1", "step-2", "step-3"],
reference_latest_execution=False,
)
pipeline.start(selective_execution_config=selective_execution_config)
sagemaker_session_mock.sagemaker_client.list_pipeline_executions.assert_not_called()
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.assert_called_with(
PipelineName="MyPipeline",
SelectiveExecutionConfig={
"SelectedSteps": [
{"StepName": "step-1"},
{"StepName": "step-2"},
{"StepName": "step-3"},
],
},
)
sagemaker_session_mock.reset_mock()


def test_pipeline_basic():
Expand Down Expand Up @@ -718,13 +741,99 @@ def test_pipeline_list_executions(sagemaker_session_mock):
assert executions["NextToken"] == "token"


def test_pipeline_build_parameters_from_execution(sagemaker_session_mock):
pipeline = Pipeline(
name="MyPipeline",
sagemaker_session=sagemaker_session_mock,
)
reference_execution_arn = "reference_execution_arn"
parameter_value_overrides = {"TestParameterName": "NewParameterValue"}
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
}
parameters = pipeline.build_parameters_from_execution(
pipeline_execution_arn=reference_execution_arn,
parameter_value_overrides=parameter_value_overrides,
)
assert (
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
PipelineExecutionArn=reference_execution_arn
)
)
assert len(parameters) == 1
assert parameters["TestParameterName"] == "NewParameterValue"


def test_pipeline_build_parameters_from_execution_with_invalid_overrides(sagemaker_session_mock):
pipeline = Pipeline(
name="MyPipeline",
sagemaker_session=sagemaker_session_mock,
)
reference_execution_arn = "reference_execution_arn"
invalid_parameter_value_overrides = {"InvalidParameterName": "Value"}
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}]
}
with pytest.raises(ValueError) as error:
pipeline.build_parameters_from_execution(
pipeline_execution_arn=reference_execution_arn,
parameter_value_overrides=invalid_parameter_value_overrides,
)
assert (
f"The following parameter overrides provided: {str(set(invalid_parameter_value_overrides.keys()))} "
+ f"are not present in the pipeline execution: {reference_execution_arn}"
in str(error)
)
assert (
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
PipelineExecutionArn=reference_execution_arn
)
)


def test_pipeline_build_parameters_from_execution_with_paginated_result(sagemaker_session_mock):
pipeline = Pipeline(
name="MyPipeline",
sagemaker_session=sagemaker_session_mock,
)
reference_execution_arn = "reference_execution_arn"
next_token = "token"
first_page_response = {
"PipelineParameters": [{"Name": "TestParameterName1", "Value": "TestParameterValue1"}],
"NextToken": next_token,
}
second_page_response = {
"PipelineParameters": [{"Name": "TestParameterName2", "Value": "TestParameterValue2"}],
}
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.side_effect = [
first_page_response,
second_page_response,
]
parameters = pipeline.build_parameters_from_execution(
pipeline_execution_arn=reference_execution_arn
)
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.assert_has_calls(
[
call(PipelineExecutionArn=reference_execution_arn),
call(PipelineExecutionArn=reference_execution_arn, NextToken=next_token),
]
)
assert len(parameters) == 2
assert parameters["TestParameterName1"] == "TestParameterValue1"
assert parameters["TestParameterName2"] == "TestParameterValue2"


def test_pipeline_execution_basics(sagemaker_session_mock):
sagemaker_session_mock.sagemaker_client.start_pipeline_execution.return_value = {
"PipelineExecutionArn": "my:arn"
}
sagemaker_session_mock.sagemaker_client.list_pipeline_execution_steps.return_value = {
"PipelineExecutionSteps": [Mock()]
}
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.return_value = {
"PipelineParameters": [{"Name": "TestParameterName", "Value": "TestParameterValue"}],
"NextToken": "token",
}
pipeline = Pipeline(
name="MyPipeline",
parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")],
Expand All @@ -745,6 +854,17 @@ def test_pipeline_execution_basics(sagemaker_session_mock):
PipelineExecutionArn="my:arn"
)
assert len(steps) == 1
list_parameters_response = execution.list_parameters()
assert (
sagemaker_session_mock.sagemaker_client.list_pipeline_parameters_for_execution.called_with(
PipelineExecutionArn="my:arn"
)
)
parameter_list = list_parameters_response["PipelineParameters"]
assert len(parameter_list) == 1
assert parameter_list[0]["Name"] == "TestParameterName"
assert parameter_list[0]["Value"] == "TestParameterValue"
assert list_parameters_response["NextToken"] == "token"


def _generate_large_pipeline_steps(input_data: object):
Expand Down