Skip to content

Commit 63f5e61

Browse files
committed
feat: Support selective pipeline execution between function step and regular step
1 parent 985f7bc commit 63f5e61

File tree

5 files changed

+115
-42
lines changed

5 files changed

+115
-42
lines changed

src/sagemaker/workflow/function_step.py

+4-17
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@
3333
PipelineVariable,
3434
)
3535

36-
from sagemaker.workflow.execution_variables import ExecutionVariables
3736
from sagemaker.workflow.properties import Properties
3837
from sagemaker.workflow.retry import RetryPolicy
3938
from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
4039
from sagemaker.workflow.step_collections import StepCollection
41-
from sagemaker.workflow.step_outputs import StepOutput
40+
from sagemaker.workflow.step_outputs import StepOutput, get_step
4241
from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context
4342

4443
from sagemaker.s3_utils import s3_path_join
@@ -275,16 +274,11 @@ def expr(self) -> RequestType:
275274

276275
def _to_json_get(self) -> JsonGet:
277276
"""Expression structure for workflow service calls using JsonGet resolution."""
278-
from sagemaker.remote_function.core.stored_function import (
279-
JSON_SERIALIZED_RESULT_KEY,
280-
RESULTS_FOLDER,
281-
JSON_RESULTS_FILE,
282-
)
277+
from sagemaker.remote_function.core.stored_function import JSON_SERIALIZED_RESULT_KEY
283278

284279
if not self._step.name:
285280
raise ValueError("Step name is not defined.")
286281

287-
s3_root_uri = self._step._job_settings.s3_root_uri
288282
# Resolve json path --
289283
# Deserializer will be able to resolve a JsonGet using path "Return[1]" to
290284
# access value 10 from following serialized JSON:
@@ -308,15 +302,8 @@ def _to_json_get(self) -> JsonGet:
308302

309303
return JsonGet(
310304
s3_uri=Join(
311-
"/",
312-
[
313-
s3_root_uri,
314-
ExecutionVariables.PIPELINE_NAME,
315-
ExecutionVariables.PIPELINE_EXECUTION_ID,
316-
self._step.name,
317-
RESULTS_FOLDER,
318-
JSON_RESULTS_FILE,
319-
],
305+
on="",
306+
values=[get_step(self)._properties.OutputDataConfig.S3OutputPath],
320307
),
321308
json_path=_resolved_reference_path,
322309
step=self._step,

src/sagemaker/workflow/functions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.workflow.entities import PipelineVariable
2222
from sagemaker.workflow.execution_variables import ExecutionVariable
2323
from sagemaker.workflow.parameters import Parameter
24-
from sagemaker.workflow.properties import PropertyFile
24+
from sagemaker.workflow.properties import PropertyFile, Properties
2525

2626
if TYPE_CHECKING:
2727
from sagemaker.workflow.steps import Step
@@ -172,9 +172,9 @@ def _validate_json_get_s3_uri(self):
172172
for join_arg in s3_uri.values:
173173
if not is_pipeline_variable(join_arg):
174174
continue
175-
if not isinstance(join_arg, (Parameter, ExecutionVariable)):
175+
if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)):
176176
raise ValueError(
177177
f"Invalid JsonGet function {self.expr}. "
178-
f"The Join values in JsonGet's s3_uri can only be a primitive object "
179-
f"or Parameter or ExecutionVariable."
178+
f"The Join values in JsonGet's s3_uri can only be a primitive object, "
179+
f"Parameter, ExecutionVariable or Properties."
180180
)

tests/integ/sagemaker/workflow/test_selective_execution.py

+100
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818

19+
from sagemaker.processing import ProcessingInput
1920
from tests.integ import DATA_DIR
2021
from sagemaker.sklearn import SKLearnProcessor
2122
from sagemaker.workflow.step_outputs import get_step
@@ -199,3 +200,102 @@ def func_2(arg):
199200
pipeline.delete()
200201
except Exception:
201202
pass
203+
204+
205+
def test_decorator_step_data_referenced_by_other_steps(
206+
pipeline_session,
207+
role,
208+
pipeline_name,
209+
region_name,
210+
dummy_container_without_error,
211+
sklearn_latest_version,
212+
):
213+
# Test Selective Pipeline Execution on function step -> [select: regular step]
214+
os.environ["AWS_DEFAULT_REGION"] = region_name
215+
processing_job_instance_counts = 2
216+
217+
@step(
218+
name="step1",
219+
role=role,
220+
image_uri=dummy_container_without_error,
221+
instance_type=INSTANCE_TYPE,
222+
keep_alive_period_in_seconds=60,
223+
)
224+
def func(var: int):
225+
return 1, var
226+
227+
step_output = func(processing_job_instance_counts)
228+
229+
script_path = os.path.join(DATA_DIR, "dummy_script.py")
230+
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
231+
inputs = [
232+
ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"),
233+
]
234+
235+
sklearn_processor = SKLearnProcessor(
236+
framework_version=sklearn_latest_version,
237+
role=role,
238+
instance_type=INSTANCE_TYPE,
239+
instance_count=step_output[1],
240+
command=["python3"],
241+
sagemaker_session=pipeline_session,
242+
base_job_name="test-sklearn",
243+
)
244+
245+
step_args = sklearn_processor.run(
246+
inputs=inputs,
247+
code=script_path,
248+
)
249+
process_step = ProcessingStep(
250+
name="MyProcessStep",
251+
step_args=step_args,
252+
)
253+
254+
pipeline = Pipeline(
255+
name=pipeline_name,
256+
steps=[process_step],
257+
sagemaker_session=pipeline_session,
258+
)
259+
260+
try:
261+
execution, _ = create_and_execute_pipeline(
262+
pipeline=pipeline,
263+
pipeline_name=pipeline_name,
264+
region_name=region_name,
265+
role=role,
266+
no_of_steps=2,
267+
last_step_name=process_step.name,
268+
execution_parameters=dict(),
269+
step_status="Succeeded",
270+
wait_duration=1000, # seconds
271+
)
272+
273+
_, execution_steps2 = create_and_execute_pipeline(
274+
pipeline=pipeline,
275+
pipeline_name=pipeline_name,
276+
region_name=region_name,
277+
role=role,
278+
no_of_steps=2,
279+
last_step_name=process_step.name,
280+
execution_parameters=dict(),
281+
step_status="Succeeded",
282+
wait_duration=1000, # seconds
283+
selective_execution_config=SelectiveExecutionConfig(
284+
source_pipeline_execution_arn=execution.arn,
285+
selected_steps=[process_step.name],
286+
),
287+
)
288+
289+
execution_proc_job = pipeline_session.describe_processing_job(
290+
execution_steps2[0]["Metadata"]["ProcessingJob"]["Arn"].split("/")[-1]
291+
)
292+
assert (
293+
execution_proc_job["ProcessingResources"]["ClusterConfig"]["InstanceCount"]
294+
== processing_job_instance_counts
295+
)
296+
297+
finally:
298+
try:
299+
pipeline.delete()
300+
except Exception:
301+
pass

tests/unit/sagemaker/workflow/test_function_step.py

+4-18
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,8 @@ def func():
328328
"Path": "Result['some_key'][0]",
329329
"S3Uri": {
330330
"Std:Join": {
331-
"On": "/",
332-
"Values": [
333-
"s3://bucket",
334-
{"Get": "Execution.PipelineName"},
335-
{"Get": "Execution.PipelineExecutionId"},
336-
"step_name",
337-
"results",
338-
"results.json",
339-
],
331+
"On": "",
332+
"Values": [{"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}],
340333
}
341334
},
342335
}
@@ -352,15 +345,8 @@ def func():
352345
"Path": "Result[0]['some_key']",
353346
"S3Uri": {
354347
"Std:Join": {
355-
"On": "/",
356-
"Values": [
357-
"s3://bucket",
358-
{"Get": "Execution.PipelineName"},
359-
{"Get": "Execution.PipelineExecutionId"},
360-
"step_name",
361-
"results",
362-
"results.json",
363-
],
348+
"On": "",
349+
"Values": [{"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}],
364350
}
365351
},
366352
}

tests/unit/sagemaker/workflow/test_functions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ def test_json_get_invalid_s3_uri_not_join():
279279

280280
def test_json_get_invalid_s3_uri_with_invalid_pipeline_variable(sagemaker_session):
281281
with pytest.raises(ValueError) as e:
282-
JsonGet(s3_uri=Join(on="/", values=["s3:/", Properties(step_name="test")]))
282+
JsonGet(s3_uri=Join(on="/", values=["s3:/", Join()]))
283283
assert (
284-
"The Join values in JsonGet's s3_uri can only be a primitive object or Parameter or ExecutionVariable."
285-
in str(e.value)
284+
"The Join values in JsonGet's s3_uri can only be a primitive object, "
285+
"Parameter, ExecutionVariable or Properties." in str(e.value)
286286
)

0 commit comments

Comments
 (0)