Skip to content

Commit bb48c73

Browse files
qidewenwhen authored and bencrabtree committed
feat: Support selective pipeline execution between function step and regular step (aws#4392)
1 parent 2f1bed0 commit bb48c73

File tree

10 files changed

+137
-108
lines changed

10 files changed

+137
-108
lines changed

src/sagemaker/workflow/function_step.py

+4-11
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@
3333
PipelineVariable,
3434
)
3535

36-
from sagemaker.workflow.execution_variables import ExecutionVariables
3736
from sagemaker.workflow.properties import Properties
3837
from sagemaker.workflow.retry import RetryPolicy
3938
from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
4039
from sagemaker.workflow.step_collections import StepCollection
41-
from sagemaker.workflow.step_outputs import StepOutput
40+
from sagemaker.workflow.step_outputs import StepOutput, get_step
4241
from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context
4342

4443
from sagemaker.s3_utils import s3_path_join
@@ -277,14 +276,12 @@ def _to_json_get(self) -> JsonGet:
277276
"""Expression structure for workflow service calls using JsonGet resolution."""
278277
from sagemaker.remote_function.core.stored_function import (
279278
JSON_SERIALIZED_RESULT_KEY,
280-
RESULTS_FOLDER,
281279
JSON_RESULTS_FILE,
282280
)
283281

284282
if not self._step.name:
285283
raise ValueError("Step name is not defined.")
286284

287-
s3_root_uri = self._step._job_settings.s3_root_uri
288285
# Resolve json path --
289286
# Deserializer will be able to resolve a JsonGet using path "Return[1]" to
290287
# access value 10 from following serialized JSON:
@@ -308,13 +305,9 @@ def _to_json_get(self) -> JsonGet:
308305

309306
return JsonGet(
310307
s3_uri=Join(
311-
"/",
312-
[
313-
s3_root_uri,
314-
ExecutionVariables.PIPELINE_NAME,
315-
ExecutionVariables.PIPELINE_EXECUTION_ID,
316-
self._step.name,
317-
RESULTS_FOLDER,
308+
on="/",
309+
values=[
310+
get_step(self)._properties.OutputDataConfig.S3OutputPath,
318311
JSON_RESULTS_FILE,
319312
],
320313
),

src/sagemaker/workflow/functions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.workflow.entities import PipelineVariable
2222
from sagemaker.workflow.execution_variables import ExecutionVariable
2323
from sagemaker.workflow.parameters import Parameter
24-
from sagemaker.workflow.properties import PropertyFile
24+
from sagemaker.workflow.properties import PropertyFile, Properties
2525

2626
if TYPE_CHECKING:
2727
from sagemaker.workflow.steps import Step
@@ -172,9 +172,9 @@ def _validate_json_get_s3_uri(self):
172172
for join_arg in s3_uri.values:
173173
if not is_pipeline_variable(join_arg):
174174
continue
175-
if not isinstance(join_arg, (Parameter, ExecutionVariable)):
175+
if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)):
176176
raise ValueError(
177177
f"Invalid JsonGet function {self.expr}. "
178-
f"The Join values in JsonGet's s3_uri can only be a primitive object "
179-
f"or Parameter or ExecutionVariable."
178+
f"The Join values in JsonGet's s3_uri can only be a primitive object, "
179+
f"Parameter, ExecutionVariable or Properties."
180180
)

src/sagemaker/workflow/pipeline.py

+1-26
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
RESOURCE_NOT_FOUND_EXCEPTION,
4242
EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
4343
)
44-
from sagemaker.workflow.function_step import DelayedReturn
4544
from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
4645
from sagemaker.workflow.entities import (
4746
Expression,
@@ -725,10 +724,7 @@ def _interpolate(
725724
pipeline_name (str): The name of the pipeline to be interpolated.
726725
"""
727726
if isinstance(obj, (Expression, Parameter, Properties, StepOutput)):
728-
updated_obj = _replace_pipeline_name_in_json_get_s3_uri(
729-
obj=obj, pipeline_name=pipeline_name
730-
)
731-
return updated_obj.expr
727+
return obj.expr
732728

733729
if isinstance(obj, CallbackOutput):
734730
step_name = callback_output_to_step_map[obj.output_name]
@@ -760,27 +756,6 @@ def _interpolate(
760756
return new
761757

762758

763-
# TODO: we should remove this once the ExecutionVariables.PIPELINE_NAME is fixed in backend
764-
def _replace_pipeline_name_in_json_get_s3_uri(obj: Union[RequestType, Any], pipeline_name: str):
765-
"""Replace the ExecutionVariables.PIPELINE_NAME in DelayedReturn's JsonGet s3_uri
766-
767-
with the pipeline_name, because ExecutionVariables.PIPELINE_NAME
768-
is parsed as all lower-cased str in backend.
769-
"""
770-
if not isinstance(obj, DelayedReturn):
771-
return obj
772-
773-
json_get = obj._to_json_get()
774-
775-
if not json_get.s3_uri:
776-
return obj
777-
# the s3 uri has to be a Join, which has been validated in JsonGet init
778-
for i in range(len(json_get.s3_uri.values)):
779-
if json_get.s3_uri.values[i] == ExecutionVariables.PIPELINE_NAME:
780-
json_get.s3_uri.values[i] = pipeline_name
781-
return json_get
782-
783-
784759
def _map_callback_outputs(steps: List[Step]):
785760
"""Iterate over the provided steps, building a map of callback output parameters to step names.
786761

tests/integ/sagemaker/workflow/helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def create_and_execute_pipeline(
3333
region_name,
3434
role,
3535
no_of_steps,
36-
last_step_name,
36+
last_step_name_prefix,
3737
execution_parameters,
3838
step_status,
3939
step_result_type=None,
@@ -66,7 +66,7 @@ def create_and_execute_pipeline(
6666
len(execution_steps) == no_of_steps
6767
), f"Expected {no_of_steps}, instead found {len(execution_steps)}"
6868

69-
assert last_step_name in execution_steps[0]["StepName"]
69+
assert last_step_name_prefix in execution_steps[0]["StepName"]
7070
assert execution_steps[0]["StepStatus"] == step_status
7171
if step_result_type:
7272
result = execution.result(execution_steps[0]["StepName"])

tests/integ/sagemaker/workflow/test_selective_execution.py

+105-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818

19+
from sagemaker.processing import ProcessingInput
1920
from tests.integ import DATA_DIR
2021
from sagemaker.sklearn import SKLearnProcessor
2122
from sagemaker.workflow.step_outputs import get_step
@@ -84,7 +85,7 @@ def sum(a, b):
8485
region_name=region_name,
8586
role=role,
8687
no_of_steps=2,
87-
last_step_name="sum",
88+
last_step_name_prefix="sum",
8889
execution_parameters=dict(),
8990
step_status="Succeeded",
9091
step_result_type=int,
@@ -97,7 +98,7 @@ def sum(a, b):
9798
region_name=region_name,
9899
role=role,
99100
no_of_steps=2,
100-
last_step_name="sum",
101+
last_step_name_prefix="sum",
101102
execution_parameters=dict(),
102103
step_status="Succeeded",
103104
step_result_type=int,
@@ -115,7 +116,7 @@ def sum(a, b):
115116
pass
116117

117118

118-
def test_selective_execution_of_regular_step_depended_by_function_step(
119+
def test_selective_execution_of_regular_step_referenced_by_function_step(
119120
sagemaker_session,
120121
role,
121122
pipeline_name,
@@ -168,7 +169,7 @@ def func_2(arg):
168169
region_name=region_name,
169170
role=role,
170171
no_of_steps=2,
171-
last_step_name="func",
172+
last_step_name_prefix="func",
172173
execution_parameters=dict(),
173174
step_status="Succeeded",
174175
step_result_type=str,
@@ -182,7 +183,7 @@ def func_2(arg):
182183
region_name=region_name,
183184
role=role,
184185
no_of_steps=2,
185-
last_step_name="func",
186+
last_step_name_prefix="func",
186187
execution_parameters=dict(),
187188
step_status="Succeeded",
188189
step_result_type=str,
@@ -199,3 +200,102 @@ def func_2(arg):
199200
pipeline.delete()
200201
except Exception:
201202
pass
203+
204+
205+
def test_selective_execution_of_function_step_referenced_by_regular_step(
206+
pipeline_session,
207+
role,
208+
pipeline_name,
209+
region_name,
210+
dummy_container_without_error,
211+
sklearn_latest_version,
212+
):
213+
# Test Selective Pipeline Execution on function step -> [select: regular step]
214+
os.environ["AWS_DEFAULT_REGION"] = region_name
215+
processing_job_instance_counts = 2
216+
217+
@step(
218+
name="step1",
219+
role=role,
220+
image_uri=dummy_container_without_error,
221+
instance_type=INSTANCE_TYPE,
222+
keep_alive_period_in_seconds=60,
223+
)
224+
def func(var: int):
225+
return 1, var
226+
227+
step_output = func(processing_job_instance_counts)
228+
229+
script_path = os.path.join(DATA_DIR, "dummy_script.py")
230+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
231+
inputs = [
232+
ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"),
233+
]
234+
235+
sklearn_processor = SKLearnProcessor(
236+
framework_version=sklearn_latest_version,
237+
role=role,
238+
instance_type=INSTANCE_TYPE,
239+
instance_count=step_output[1],
240+
command=["python3"],
241+
sagemaker_session=pipeline_session,
242+
base_job_name="test-sklearn",
243+
)
244+
245+
step_args = sklearn_processor.run(
246+
inputs=inputs,
247+
code=script_path,
248+
)
249+
process_step = ProcessingStep(
250+
name="MyProcessStep",
251+
step_args=step_args,
252+
)
253+
254+
pipeline = Pipeline(
255+
name=pipeline_name,
256+
steps=[process_step],
257+
sagemaker_session=pipeline_session,
258+
)
259+
260+
try:
261+
execution, _ = create_and_execute_pipeline(
262+
pipeline=pipeline,
263+
pipeline_name=pipeline_name,
264+
region_name=region_name,
265+
role=role,
266+
no_of_steps=2,
267+
last_step_name_prefix=process_step.name,
268+
execution_parameters=dict(),
269+
step_status="Succeeded",
270+
wait_duration=1000, # seconds
271+
)
272+
273+
_, execution_steps2 = create_and_execute_pipeline(
274+
pipeline=pipeline,
275+
pipeline_name=pipeline_name,
276+
region_name=region_name,
277+
role=role,
278+
no_of_steps=2,
279+
last_step_name_prefix=process_step.name,
280+
execution_parameters=dict(),
281+
step_status="Succeeded",
282+
wait_duration=1000, # seconds
283+
selective_execution_config=SelectiveExecutionConfig(
284+
source_pipeline_execution_arn=execution.arn,
285+
selected_steps=[process_step.name],
286+
),
287+
)
288+
289+
execution_proc_job = pipeline_session.describe_processing_job(
290+
execution_steps2[0]["Metadata"]["ProcessingJob"]["Arn"].split("/")[-1]
291+
)
292+
assert (
293+
execution_proc_job["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
294+
== processing_job_instance_counts
295+
)
296+
297+
finally:
298+
try:
299+
pipeline.delete()
300+
except Exception:
301+
pass

0 commit comments

Comments (0)