feat: Support selective pipeline execution between function step and regular step #4392

Merged 1 commit on Feb 29, 2024
src/sagemaker/workflow/function_step.py (4 additions, 11 deletions)
@@ -33,12 +33,11 @@
     PipelineVariable,
 )

-from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.properties import Properties
 from sagemaker.workflow.retry import RetryPolicy
 from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
 from sagemaker.workflow.step_collections import StepCollection
-from sagemaker.workflow.step_outputs import StepOutput
+from sagemaker.workflow.step_outputs import StepOutput, get_step
 from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context

 from sagemaker.s3_utils import s3_path_join
Expand Down Expand Up @@ -277,14 +276,12 @@ def _to_json_get(self) -> JsonGet:
"""Expression structure for workflow service calls using JsonGet resolution."""
from sagemaker.remote_function.core.stored_function import (
JSON_SERIALIZED_RESULT_KEY,
RESULTS_FOLDER,
JSON_RESULTS_FILE,
)

if not self._step.name:
raise ValueError("Step name is not defined.")

s3_root_uri = self._step._job_settings.s3_root_uri
# Resolve json path --
# Deserializer will be able to resolve a JsonGet using path "Return[1]" to
# access value 10 from following serialized JSON:
@@ -308,13 +305,9 @@

         return JsonGet(
             s3_uri=Join(
-                "/",
-                [
-                    s3_root_uri,
-                    ExecutionVariables.PIPELINE_NAME,
-                    ExecutionVariables.PIPELINE_EXECUTION_ID,
-                    self._step.name,
-                    RESULTS_FOLDER,
+                on="/",
+                values=[
+                    get_step(self)._properties.OutputDataConfig.S3OutputPath,
                     JSON_RESULTS_FILE,
                 ],
             ),
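
The substance of this file's change: a function step's DelayedReturn now locates its serialized result through the owning step's OutputDataConfig.S3OutputPath property, which the workflow service resolves at run time, instead of re-deriving the path from ExecutionVariables. A minimal sketch of the expression _to_json_get now builds, assuming a results file named "results.json" (the JSON_RESULTS_FILE constant imported above); step_properties is a hypothetical stand-in for get_step(self)._properties:

    from sagemaker.workflow.functions import Join, JsonGet

    # step_properties stands in for get_step(self)._properties.
    result_ref = JsonGet(
        s3_uri=Join(
            on="/",
            values=[
                step_properties.OutputDataConfig.S3OutputPath,  # resolved per execution
                "results.json",  # JSON_RESULTS_FILE
            ],
        ),
        json_path="Return[1]",  # e.g. the second element of a tuple return
    )

Because the S3 path now comes from the producing step's own properties, the reference stays resolvable even when that step is skipped under selective execution and its outputs are carried over from the source execution.
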
src/sagemaker/workflow/functions.py (4 additions, 4 deletions)
@@ -21,7 +21,7 @@
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.execution_variables import ExecutionVariable
 from sagemaker.workflow.parameters import Parameter
-from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.properties import PropertyFile, Properties

 if TYPE_CHECKING:
     from sagemaker.workflow.steps import Step
@@ -172,9 +172,9 @@ def _validate_json_get_s3_uri(self):
         for join_arg in s3_uri.values:
             if not is_pipeline_variable(join_arg):
                 continue
-            if not isinstance(join_arg, (Parameter, ExecutionVariable)):
+            if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)):
                 raise ValueError(
                     f"Invalid JsonGet function {self.expr}. "
-                    f"The Join values in JsonGet's s3_uri can only be a primitive object "
-                    f"or Parameter or ExecutionVariable."
+                    f"The Join values in JsonGet's s3_uri can only be a primitive object, "
+                    f"Parameter, ExecutionVariable or Properties."
                 )
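
For context, _validate_json_get_s3_uri only inspects Join values that are pipeline variables, and the allow-list now includes Properties so that the step-property reference built in function_step.py passes definition-time validation. A hedged sketch of what is now accepted (my_step is a hypothetical training step defined elsewhere; the file name is illustrative):

    from sagemaker.workflow.execution_variables import ExecutionVariables
    from sagemaker.workflow.functions import Join, JsonGet

    JsonGet(
        s3_uri=Join(
            on="/",
            values=[
                my_step.properties.OutputDataConfig.S3OutputPath,  # Properties: newly allowed
                ExecutionVariables.PIPELINE_EXECUTION_ID,  # ExecutionVariable: allowed
                "results.json",  # primitive: allowed
            ],
        ),
        json_path="Return",
    )

Any other pipeline variable in values, for example a nested Join, still raises the ValueError above when the pipeline definition is built.
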
src/sagemaker/workflow/pipeline.py (1 addition, 26 deletions)
@@ -41,7 +41,6 @@
     RESOURCE_NOT_FOUND_EXCEPTION,
     EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
 )
-from sagemaker.workflow.function_step import DelayedReturn
 from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
 from sagemaker.workflow.entities import (
     Expression,
@@ -725,10 +724,7 @@ def _interpolate(
         pipeline_name (str): The name of the pipeline to be interpolated.
     """
     if isinstance(obj, (Expression, Parameter, Properties, StepOutput)):
-        updated_obj = _replace_pipeline_name_in_json_get_s3_uri(
-            obj=obj, pipeline_name=pipeline_name
-        )
-        return updated_obj.expr
+        return obj.expr

     if isinstance(obj, CallbackOutput):
         step_name = callback_output_to_step_map[obj.output_name]
@@ -760,27 +756,6 @@
     return new


-# TODO: we should remove this once the ExecutionVariables.PIPELINE_NAME is fixed in backend
-def _replace_pipeline_name_in_json_get_s3_uri(obj: Union[RequestType, Any], pipeline_name: str):
-    """Replace the ExecutionVariables.PIPELINE_NAME in DelayedReturn's JsonGet s3_uri
-
-    with the pipeline_name, because ExecutionVariables.PIPELINE_NAME
-    is parsed as all lower-cased str in backend.
-    """
-    if not isinstance(obj, DelayedReturn):
-        return obj
-
-    json_get = obj._to_json_get()
-
-    if not json_get.s3_uri:
-        return obj
-    # the s3 uri has to be a Join, which has been validated in JsonGet init
-    for i in range(len(json_get.s3_uri.values)):
-        if json_get.s3_uri.values[i] == ExecutionVariables.PIPELINE_NAME:
-            json_get.s3_uri.values[i] = pipeline_name
-    return json_get
-
-
 def _map_callback_outputs(steps: List[Step]):
     """Iterate over the provided steps, building a map of callback output parameters to step names.
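
Why the workaround can be dropped: the old s3_uri embedded ExecutionVariables.PIPELINE_NAME, which the backend resolves to a lower-cased string, so _interpolate had to splice the literal pipeline name into every DelayedReturn's JsonGet at compile time. The new s3_uri starts from a step property, which is unaffected by that quirk. A simplified before/after of the URI shape (segment names illustrative, not taken from a real compilation):

    # Before: compile-time rewrite needed; PIPELINE_NAME is lower-cased by the backend.
    #   {s3_root_uri}/{PIPELINE_NAME}/{PIPELINE_EXECUTION_ID}/{step_name}/results/results.json
    # After: anchored on a runtime-resolved step property; no rewrite needed.
    #   {Steps.<step_name>.OutputDataConfig.S3OutputPath}/results.json
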
tests/integ/sagemaker/workflow/helpers.py (2 additions, 2 deletions)
@@ -33,7 +33,7 @@ def create_and_execute_pipeline(
     region_name,
     role,
     no_of_steps,
-    last_step_name,
+    last_step_name_prefix,
     execution_parameters,
     step_status,
     step_result_type=None,
@@ -66,7 +66,7 @@
         len(execution_steps) == no_of_steps
     ), f"Expected {no_of_steps}, instead found {len(execution_steps)}"

-    assert last_step_name in execution_steps[0]["StepName"]
+    assert last_step_name_prefix in execution_steps[0]["StepName"]
     assert execution_steps[0]["StepStatus"] == step_status
     if step_result_type:
         result = execution.result(execution_steps[0]["StepName"])
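
The rename reflects what the helper actually checks: a containment match on the newest execution step's name rather than an exact match, since compiled step names may carry a suffix (for example, a @step function named "sum" can surface as something like "sum-<suffix>"). A minimal illustration of the assertion's semantics, with hypothetical values:

    execution_steps = [{"StepName": "sum-a1b2c3", "StepStatus": "Succeeded"}]
    last_step_name_prefix = "sum"
    assert last_step_name_prefix in execution_steps[0]["StepName"]  # substring match
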
tests/integ/sagemaker/workflow/test_selective_execution.py (105 additions, 5 deletions)
@@ -16,6 +16,7 @@

 import pytest

+from sagemaker.processing import ProcessingInput
 from tests.integ import DATA_DIR
 from sagemaker.sklearn import SKLearnProcessor
 from sagemaker.workflow.step_outputs import get_step
@@ -84,7 +85,7 @@ def sum(a, b):
         region_name=region_name,
         role=role,
         no_of_steps=2,
-        last_step_name="sum",
+        last_step_name_prefix="sum",
         execution_parameters=dict(),
         step_status="Succeeded",
         step_result_type=int,
@@ -97,7 +98,7 @@ def sum(a, b):
         region_name=region_name,
         role=role,
         no_of_steps=2,
-        last_step_name="sum",
+        last_step_name_prefix="sum",
         execution_parameters=dict(),
         step_status="Succeeded",
         step_result_type=int,
@@ -115,7 +116,7 @@ def sum(a, b):
             pass


-def test_selective_execution_of_regular_step_depended_by_function_step(
+def test_selective_execution_of_regular_step_referenced_by_function_step(
     sagemaker_session,
     role,
     pipeline_name,
@@ -168,7 +169,7 @@ def func_2(arg):
         region_name=region_name,
         role=role,
         no_of_steps=2,
-        last_step_name="func",
+        last_step_name_prefix="func",
         execution_parameters=dict(),
         step_status="Succeeded",
         step_result_type=str,
@@ -182,7 +183,7 @@ def func_2(arg):
         region_name=region_name,
         role=role,
         no_of_steps=2,
-        last_step_name="func",
+        last_step_name_prefix="func",
         execution_parameters=dict(),
         step_status="Succeeded",
         step_result_type=str,
@@ -199,3 +200,102 @@ def func_2(arg):
             pipeline.delete()
         except Exception:
             pass
+
+
+def test_selective_execution_of_function_step_referenced_by_regular_step(
+    pipeline_session,
+    role,
+    pipeline_name,
+    region_name,
+    dummy_container_without_error,
+    sklearn_latest_version,
+):
+    # Test Selective Pipeline Execution on function step -> [select: regular step]
+    os.environ["AWS_DEFAULT_REGION"] = region_name
+    processing_job_instance_counts = 2
+
+    @step(
+        name="step1",
+        role=role,
+        image_uri=dummy_container_without_error,
+        instance_type=INSTANCE_TYPE,
+        keep_alive_period_in_seconds=60,
+    )
+    def func(var: int):
+        return 1, var
+
+    step_output = func(processing_job_instance_counts)
+
+    script_path = os.path.join(DATA_DIR, "dummy_script.py")
+    input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
+    inputs = [
+        ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"),
+    ]
+
+    sklearn_processor = SKLearnProcessor(
+        framework_version=sklearn_latest_version,
+        role=role,
+        instance_type=INSTANCE_TYPE,
+        instance_count=step_output[1],
+        command=["python3"],
+        sagemaker_session=pipeline_session,
+        base_job_name="test-sklearn",
+    )
+
+    step_args = sklearn_processor.run(
+        inputs=inputs,
+        code=script_path,
+    )
+    process_step = ProcessingStep(
+        name="MyProcessStep",
+        step_args=step_args,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        steps=[process_step],
+        sagemaker_session=pipeline_session,
+    )
+
+    try:
+        execution, _ = create_and_execute_pipeline(
+            pipeline=pipeline,
+            pipeline_name=pipeline_name,
+            region_name=region_name,
+            role=role,
+            no_of_steps=2,
+            last_step_name_prefix=process_step.name,
+            execution_parameters=dict(),
+            step_status="Succeeded",
+            wait_duration=1000,  # seconds
+        )
+
+        _, execution_steps2 = create_and_execute_pipeline(
+            pipeline=pipeline,
+            pipeline_name=pipeline_name,
+            region_name=region_name,
+            role=role,
+            no_of_steps=2,
+            last_step_name_prefix=process_step.name,
+            execution_parameters=dict(),
+            step_status="Succeeded",
+            wait_duration=1000,  # seconds
+            selective_execution_config=SelectiveExecutionConfig(
+                source_pipeline_execution_arn=execution.arn,
+                selected_steps=[process_step.name],
+            ),
+        )
+
+        execution_proc_job = pipeline_session.describe_processing_job(
+            execution_steps2[0]["Metadata"]["ProcessingJob"]["Arn"].split("/")[-1]
+        )
+        assert (
+            execution_proc_job["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
+            == processing_job_instance_counts
+        )
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
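
The coupling this new test pins down is instance_count=step_output[1]: the processing step's resource count references the function step's second return value. In the second run, SelectiveExecutionConfig re-executes only MyProcessStep, so that reference must resolve from the source execution's stored result, and the final assertion confirms the processing job really received InstanceCount == 2. A hypothetical shape of the describe response that the assertion reads (only the relevant field shown):

    execution_proc_job = {
        "ProcessingResources": {"ClusterConfig": {"InstanceCount": 2}},
        # ...other DescribeProcessingJob fields elided...
    }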