Skip to content

Commit 9828101

Browse files
authored
feat: Support selective pipeline execution for function step (aws#4372)
1 parent 8aaaf51 commit 9828101

File tree

13 files changed

+439
-87
lines changed

13 files changed

+439
-87
lines changed

src/sagemaker/remote_function/core/pipeline_variables.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from sagemaker.s3 import s3_path_join
2121
from sagemaker.remote_function.core.serialization import deserialize_obj_from_s3
22+
from sagemaker.workflow.step_outputs import get_step
2223

2324

2425
@dataclass
@@ -92,7 +93,7 @@ class _S3BaseUriIdentifier:
9293
class _DelayedReturn:
9394
"""Delayed return from a function."""
9495

95-
uri: List[Union[str, _Parameter, _ExecutionVariable]]
96+
uri: Union[_Properties, List[Union[str, _Parameter, _ExecutionVariable]]]
9697
reference_path: Tuple = field(default_factory=tuple)
9798

9899

@@ -164,6 +165,7 @@ def __init__(
164165
self,
165166
delayed_returns: List[_DelayedReturn],
166167
hmac_key: str,
168+
properties_resolver: _PropertiesResolver,
167169
parameter_resolver: _ParameterResolver,
168170
execution_variable_resolver: _ExecutionVariableResolver,
169171
s3_base_uri: str,
@@ -174,6 +176,7 @@ def __init__(
174176
Args:
175177
delayed_returns: list of delayed returns to resolve.
176178
hmac_key: key used to encrypt serialized and deserialized function and arguments.
179+
properties_resolver: resolver used to resolve step properties.
177180
parameter_resolver: resolver used to pipeline parameters.
178181
execution_variable_resolver: resolver used to resolve execution variables.
179182
s3_base_uri (str): the s3 base uri of the function step that
@@ -184,6 +187,7 @@ def __init__(
184187
self._s3_base_uri = s3_base_uri
185188
self._parameter_resolver = parameter_resolver
186189
self._execution_variable_resolver = execution_variable_resolver
190+
self._properties_resolver = properties_resolver
187191
# different delayed returns can have the same uri, so we need to dedupe
188192
uris = {
189193
self._resolve_delayed_return_uri(delayed_return) for delayed_return in delayed_returns
@@ -214,7 +218,10 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:
214218

215219
def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
216220
"""Resolve the s3 uri of the delayed return."""
221+
if isinstance(delayed_return.uri, _Properties):
222+
return self._properties_resolver.resolve(delayed_return.uri)
217223

224+
# Keep the following old resolution logics to keep backward compatible
218225
uri = []
219226
for component in delayed_return.uri:
220227
if isinstance(component, _Parameter):
@@ -274,6 +281,7 @@ def resolve_pipeline_variables(
274281
delayed_return_resolver = _DelayedReturnResolver(
275282
delayed_returns=delayed_returns,
276283
hmac_key=hmac_key,
284+
properties_resolver=properties_resolver,
277285
parameter_resolver=parameter_resolver,
278286
execution_variable_resolver=execution_variable_resolver,
279287
s3_base_uri=s3_base_uri,
@@ -325,27 +333,12 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
325333

326334
from sagemaker.workflow.entities import PipelineVariable
327335

328-
from sagemaker.workflow.execution_variables import ExecutionVariables
329-
330336
from sagemaker.workflow.function_step import DelayedReturn
331337

332-
# Notes:
333-
# 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334-
# when defining function steps. After step-level arg serialization,
335-
# it's hard to update the s3_base_uri in pipeline compile time.
336-
# Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337-
# 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338-
# the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339-
# should be retrieved from the pipeline's sagemaker_session.
340338
def convert(arg):
341339
if isinstance(arg, DelayedReturn):
342340
return _DelayedReturn(
343-
uri=[
344-
_S3BaseUriIdentifier(),
345-
ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
346-
arg._step.name,
347-
"results",
348-
],
341+
uri=get_step(arg)._properties.OutputDataConfig.S3OutputPath._pickleable,
349342
reference_path=arg._reference_path,
350343
)
351344

src/sagemaker/remote_function/job.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from sagemaker import vpc_utils
6161
from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
6262
from sagemaker.remote_function.core.pipeline_variables import Context
63+
6364
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
6465
RuntimeEnvironmentManager,
6566
_DependencySettings,
@@ -72,6 +73,8 @@
7273
copy_workdir,
7374
resolve_custom_file_filter_from_config_file,
7475
)
76+
from sagemaker.workflow.function_step import DelayedReturn
77+
from sagemaker.workflow.step_outputs import get_step
7578

7679
if TYPE_CHECKING:
7780
from sagemaker.workflow.entities import PipelineVariable
@@ -701,6 +704,7 @@ def compile(
701704
"""Build the artifacts and generate the training job request."""
702705
from sagemaker.workflow.properties import Properties
703706
from sagemaker.workflow.parameters import Parameter
707+
from sagemaker.workflow.functions import Join
704708
from sagemaker.workflow.execution_variables import ExecutionVariables, ExecutionVariable
705709
from sagemaker.workflow.utilities import load_step_compilation_context
706710

@@ -760,7 +764,19 @@ def compile(
760764
job_settings=job_settings, s3_base_uri=s3_base_uri
761765
)
762766

763-
output_config = {"S3OutputPath": s3_base_uri}
767+
if step_compilation_context:
768+
s3_output_path = Join(
769+
on="/",
770+
values=[
771+
s3_base_uri,
772+
ExecutionVariables.PIPELINE_EXECUTION_ID,
773+
step_compilation_context.step_name,
774+
"results",
775+
],
776+
)
777+
output_config = {"S3OutputPath": s3_output_path}
778+
else:
779+
output_config = {"S3OutputPath": s3_base_uri}
764780
if job_settings.s3_kms_key is not None:
765781
output_config["KmsKeyId"] = job_settings.s3_kms_key
766782
request_dict["OutputDataConfig"] = output_config
@@ -804,6 +820,11 @@ def compile(
804820
if isinstance(arg, (Parameter, ExecutionVariable, Properties)):
805821
container_args.extend([arg.expr["Get"], arg.to_string()])
806822

823+
if isinstance(arg, DelayedReturn):
824+
# The uri is a Properties object
825+
uri = get_step(arg)._properties.OutputDataConfig.S3OutputPath
826+
container_args.extend([uri.expr["Get"], uri.to_string()])
827+
807828
if run_info is not None:
808829
container_args.extend(["--run_in_context", json.dumps(dataclasses.asdict(run_info))])
809830
elif _RunContext.get_current_run() is not None:

src/sagemaker/remote_function/runtime_environment/runtime_environment_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _is_file_exists(self, dependencies):
252252

253253
def _install_requirements_txt(self, local_path, python_executable):
254254
"""Install requirements.txt file"""
255-
cmd = f"{python_executable} -m pip install -r {local_path}"
255+
cmd = f"{python_executable} -m pip install -r {local_path} -U"
256256
logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd())
257257
_run_shell_cmd(cmd)
258258
logger.info("Command %s ran successfully", cmd)
@@ -268,7 +268,7 @@ def _create_conda_env(self, env_name, local_path):
268268
def _install_req_txt_in_conda_env(self, env_name, local_path):
269269
"""Install requirements.txt in the given conda environment"""
270270

271-
cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path}"
271+
cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U"
272272
logger.info("Activating conda env and installing requirements: %s", cmd)
273273
_run_shell_cmd(cmd)
274274
logger.info("Requirements installed successfully in conda env %s", env_name)

src/sagemaker/workflow/function_step.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535

3636
from sagemaker.workflow.execution_variables import ExecutionVariables
37+
from sagemaker.workflow.properties import Properties
3738
from sagemaker.workflow.retry import RetryPolicy
3839
from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum
3940
from sagemaker.workflow.step_collections import StepCollection
@@ -101,6 +102,12 @@ def __init__(
101102

102103
self.__job_settings = None
103104

105+
# It's for internal usage to retrieve execution id from the properties.
106+
# However, we won't expose the properties of function step to customers.
107+
self._properties = Properties(
108+
step_name=name, step=self, shape_name="DescribeTrainingJobResponse"
109+
)
110+
104111
(
105112
self._converted_func_args,
106113
self._converted_func_kwargs,

src/sagemaker/workflow/pipeline.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,11 +1039,19 @@ def get_function_step_result(
10391039
raise ValueError(_ERROR_MSG_OF_WRONG_STEP_TYPE)
10401040
s3_output_path = describe_training_job_response["OutputDataConfig"]["S3OutputPath"]
10411041

1042+
s3_uri_suffix = s3_path_join(execution_id, step_name, RESULTS_FOLDER)
1043+
if s3_output_path.endswith(s3_uri_suffix) or s3_output_path[0:-1].endswith(s3_uri_suffix):
1044+
s3_uri = s3_output_path
1045+
else:
1046+
# This is the obsoleted version of s3_output_path
1047+
# Keeping it for backward compatible
1048+
s3_uri = s3_path_join(s3_output_path, s3_uri_suffix)
1049+
10421050
job_status = describe_training_job_response["TrainingJobStatus"]
10431051
if job_status == "Completed":
10441052
return deserialize_obj_from_s3(
10451053
sagemaker_session=sagemaker_session,
1046-
s3_uri=s3_path_join(s3_output_path, execution_id, step_name, RESULTS_FOLDER),
1054+
s3_uri=s3_uri,
10471055
hmac_key=describe_training_job_response["Environment"]["REMOTE_FUNCTION_SECRET_KEY"],
10481056
)
10491057

tests/integ/sagemaker/workflow/helpers.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,24 @@ def create_and_execute_pipeline(
3939
step_result_type=None,
4040
step_result_value=None,
4141
wait_duration=400, # seconds
42+
selective_execution_config=None,
4243
):
43-
response = pipeline.create(role)
44-
45-
create_arn = response["PipelineArn"]
46-
assert re.match(
47-
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
48-
create_arn,
44+
create_arn = None
45+
if not selective_execution_config:
46+
response = pipeline.create(role)
47+
create_arn = response["PipelineArn"]
48+
assert re.match(
49+
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
50+
create_arn,
51+
)
52+
53+
execution = pipeline.start(
54+
parameters=execution_parameters, selective_execution_config=selective_execution_config
4955
)
5056

51-
execution = pipeline.start(parameters=execution_parameters)
52-
response = execution.describe()
53-
assert response["PipelineArn"] == create_arn
57+
if create_arn:
58+
response = execution.describe()
59+
assert response["PipelineArn"] == create_arn
5460

5561
wait_pipeline_execution(execution=execution, delay=20, max_attempts=int(wait_duration / 20))
5662

@@ -71,6 +77,16 @@ def create_and_execute_pipeline(
7177
if step_result_value:
7278
result = execution.result(execution_steps[0]["StepName"])
7379
assert result == step_result_value, f"Expected {step_result_value}, instead found {result}"
80+
81+
if selective_execution_config:
82+
for exe_step in execution_steps:
83+
if exe_step["StepName"] in selective_execution_config.selected_steps:
84+
continue
85+
assert (
86+
exe_step["SelectiveExecutionResult"]["SourcePipelineExecutionArn"]
87+
== selective_execution_config.source_pipeline_execution_arn
88+
)
89+
7490
return execution, execution_steps
7591

7692

0 commit comments

Comments
 (0)