diff --git a/src/sagemaker/workflow/function_step.py b/src/sagemaker/workflow/function_step.py index 55e7eac90c..32353ece07 100644 --- a/src/sagemaker/workflow/function_step.py +++ b/src/sagemaker/workflow/function_step.py @@ -33,12 +33,11 @@ PipelineVariable, ) -from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.properties import Properties from sagemaker.workflow.retry import RetryPolicy from sagemaker.workflow.steps import Step, ConfigurableRetryStep, StepTypeEnum from sagemaker.workflow.step_collections import StepCollection -from sagemaker.workflow.step_outputs import StepOutput +from sagemaker.workflow.step_outputs import StepOutput, get_step from sagemaker.workflow.utilities import trim_request_dict, load_step_compilation_context from sagemaker.s3_utils import s3_path_join @@ -277,14 +276,12 @@ def _to_json_get(self) -> JsonGet: """Expression structure for workflow service calls using JsonGet resolution.""" from sagemaker.remote_function.core.stored_function import ( JSON_SERIALIZED_RESULT_KEY, - RESULTS_FOLDER, JSON_RESULTS_FILE, ) if not self._step.name: raise ValueError("Step name is not defined.") - s3_root_uri = self._step._job_settings.s3_root_uri # Resolve json path -- # Deserializer will be able to resolve a JsonGet using path "Return[1]" to # access value 10 from following serialized JSON: @@ -308,13 +305,9 @@ def _to_json_get(self) -> JsonGet: return JsonGet( s3_uri=Join( - "/", - [ - s3_root_uri, - ExecutionVariables.PIPELINE_NAME, - ExecutionVariables.PIPELINE_EXECUTION_ID, - self._step.name, - RESULTS_FOLDER, + on="/", + values=[ + get_step(self)._properties.OutputDataConfig.S3OutputPath, JSON_RESULTS_FILE, ], ), diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 4f63c4651b..947578d433 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -21,7 +21,7 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.execution_variables import ExecutionVariable from sagemaker.workflow.parameters import Parameter -from sagemaker.workflow.properties import PropertyFile +from sagemaker.workflow.properties import PropertyFile, Properties if TYPE_CHECKING: from sagemaker.workflow.steps import Step @@ -172,9 +172,9 @@ def _validate_json_get_s3_uri(self): for join_arg in s3_uri.values: if not is_pipeline_variable(join_arg): continue - if not isinstance(join_arg, (Parameter, ExecutionVariable)): + if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)): raise ValueError( f"Invalid JsonGet function {self.expr}. " - f"The Join values in JsonGet's s3_uri can only be a primitive object " - f"or Parameter or ExecutionVariable." + f"The Join values in JsonGet's s3_uri can only be a primitive object, " + f"Parameter, ExecutionVariable or Properties." ) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 6800f2a3ac..510ccd76bf 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -41,7 +41,6 @@ RESOURCE_NOT_FOUND_EXCEPTION, EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, ) -from sagemaker.workflow.function_step import DelayedReturn from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( Expression, @@ -725,10 +724,7 @@ def _interpolate( pipeline_name (str): The name of the pipeline to be interpolated. """ if isinstance(obj, (Expression, Parameter, Properties, StepOutput)): - updated_obj = _replace_pipeline_name_in_json_get_s3_uri( - obj=obj, pipeline_name=pipeline_name - ) - return updated_obj.expr + return obj.expr if isinstance(obj, CallbackOutput): step_name = callback_output_to_step_map[obj.output_name] @@ -760,27 +756,6 @@ def _interpolate( return new -# TODO: we should remove this once the ExecutionVariables.PIPELINE_NAME is fixed in backend -def _replace_pipeline_name_in_json_get_s3_uri(obj: Union[RequestType, Any], pipeline_name: str): - """Replace the ExecutionVariables.PIPELINE_NAME in DelayedReturn's JsonGet s3_uri - - with the pipeline_name, because ExecutionVariables.PIPELINE_NAME - is parsed as all lower-cased str in backend. - """ - if not isinstance(obj, DelayedReturn): - return obj - - json_get = obj._to_json_get() - - if not json_get.s3_uri: - return obj - # the s3 uri has to be a Join, which has been validated in JsonGet init - for i in range(len(json_get.s3_uri.values)): - if json_get.s3_uri.values[i] == ExecutionVariables.PIPELINE_NAME: - json_get.s3_uri.values[i] = pipeline_name - return json_get - - def _map_callback_outputs(steps: List[Step]): """Iterate over the provided steps, building a map of callback output parameters to step names. diff --git a/tests/integ/sagemaker/workflow/helpers.py b/tests/integ/sagemaker/workflow/helpers.py index 48e1e95734..20365ef169 100644 --- a/tests/integ/sagemaker/workflow/helpers.py +++ b/tests/integ/sagemaker/workflow/helpers.py @@ -33,7 +33,7 @@ def create_and_execute_pipeline( region_name, role, no_of_steps, - last_step_name, + last_step_name_prefix, execution_parameters, step_status, step_result_type=None, @@ -66,7 +66,7 @@ def create_and_execute_pipeline( len(execution_steps) == no_of_steps ), f"Expected {no_of_steps}, instead found {len(execution_steps)}" - assert last_step_name in execution_steps[0]["StepName"] + assert last_step_name_prefix in execution_steps[0]["StepName"] assert execution_steps[0]["StepStatus"] == step_status if step_result_type: result = execution.result(execution_steps[0]["StepName"]) diff --git a/tests/integ/sagemaker/workflow/test_selective_execution.py b/tests/integ/sagemaker/workflow/test_selective_execution.py index a2c0286c6a..a584c095d5 100644 --- a/tests/integ/sagemaker/workflow/test_selective_execution.py +++ b/tests/integ/sagemaker/workflow/test_selective_execution.py @@ -16,6 +16,7 @@ import pytest +from sagemaker.processing import ProcessingInput from tests.integ import DATA_DIR from sagemaker.sklearn import SKLearnProcessor from sagemaker.workflow.step_outputs import get_step @@ -84,7 +85,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -97,7 +98,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -115,7 +116,7 @@ def sum(a, b): pass -def test_selective_execution_of_regular_step_depended_by_function_step( +def test_selective_execution_of_regular_step_referenced_by_function_step( sagemaker_session, role, pipeline_name, @@ -168,7 +169,7 @@ def func_2(arg): region_name=region_name, role=role, no_of_steps=2, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(), step_status="Succeeded", step_result_type=str, @@ -182,7 +183,7 @@ def func_2(arg): region_name=region_name, role=role, no_of_steps=2, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(), step_status="Succeeded", step_result_type=str, @@ -199,3 +200,102 @@ def func_2(arg): pipeline.delete() except Exception: pass + + +def test_selective_execution_of_function_step_referenced_by_regular_step( + pipeline_session, + role, + pipeline_name, + region_name, + dummy_container_without_error, + sklearn_latest_version, +): + # Test Selective Pipeline Execution on function step -> [select: regular step] + os.environ["AWS_DEFAULT_REGION"] = region_name + processing_job_instance_counts = 2 + + @step( + name="step1", + role=role, + image_uri=dummy_container_without_error, + instance_type=INSTANCE_TYPE, + keep_alive_period_in_seconds=60, + ) + def func(var: int): + return 1, var + + step_output = func(processing_job_instance_counts) + + script_path = os.path.join(DATA_DIR, "dummy_script.py") + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + inputs = [ + ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/"), + ] + + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_latest_version, + role=role, + instance_type=INSTANCE_TYPE, + instance_count=step_output[1], + command=["python3"], + sagemaker_session=pipeline_session, + base_job_name="test-sklearn", + ) + + step_args = sklearn_processor.run( + inputs=inputs, + code=script_path, + ) + process_step = ProcessingStep( + name="MyProcessStep", + step_args=step_args, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[process_step], + sagemaker_session=pipeline_session, + ) + + try: + execution, _ = create_and_execute_pipeline( + pipeline=pipeline, + pipeline_name=pipeline_name, + region_name=region_name, + role=role, + no_of_steps=2, + last_step_name_prefix=process_step.name, + execution_parameters=dict(), + step_status="Succeeded", + wait_duration=1000, # seconds + ) + + _, execution_steps2 = create_and_execute_pipeline( + pipeline=pipeline, + pipeline_name=pipeline_name, + region_name=region_name, + role=role, + no_of_steps=2, + last_step_name_prefix=process_step.name, + execution_parameters=dict(), + step_status="Succeeded", + wait_duration=1000, # seconds + selective_execution_config=SelectiveExecutionConfig( + source_pipeline_execution_arn=execution.arn, + selected_steps=[process_step.name], + ), + ) + + execution_proc_job = pipeline_session.describe_processing_job( + execution_steps2[0]["Metadata"]["ProcessingJob"]["Arn"].split("/")[-1] + ) + assert ( + execution_proc_job["ProcessingResources"]["ClusterConfig"]["InstanceCount"] + == processing_job_instance_counts + ) + + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/integ/sagemaker/workflow/test_step_decorator.py b/tests/integ/sagemaker/workflow/test_step_decorator.py index bdd18a16f2..3c19a37cc3 100644 --- a/tests/integ/sagemaker/workflow/test_step_decorator.py +++ b/tests/integ/sagemaker/workflow/test_step_decorator.py @@ -159,7 +159,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -203,7 +203,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -252,7 +252,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=2, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -297,7 +297,7 @@ def sum(a, b): region_name=region_name, role=role, no_of_steps=1, - last_step_name="sum", + last_step_name_prefix="sum", execution_parameters=dict(TrainingInstanceCount="ml.m5.xlarge"), step_status="Succeeded", step_result_type=int, @@ -386,7 +386,7 @@ def func_2(*args): region_name=region_name, role=role, no_of_steps=3, - last_step_name="func", + last_step_name_prefix="func", execution_parameters=dict(param_a=3), step_status="Succeeded", step_result_type=tuple, @@ -438,7 +438,7 @@ def validate_file_exists(files_exists, files_does_not_exist): region_name=region_name, role=role, no_of_steps=1, - last_step_name="validate_file_exists", + last_step_name_prefix="validate_file_exists", execution_parameters=dict(), step_status="Succeeded", ) @@ -493,7 +493,7 @@ def train(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="train", + last_step_name_prefix="train", execution_parameters=dict(), step_status="Succeeded", step_result_type=int, @@ -539,7 +539,7 @@ def cuberoot(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="cuberoot", + last_step_name_prefix="cuberoot", execution_parameters=dict(), step_status="Succeeded", step_result_type=numpy.float64, @@ -585,7 +585,7 @@ def divide(x, y): region_name=region_name, role=role, no_of_steps=1, - last_step_name="divide", + last_step_name_prefix="divide", execution_parameters=dict(), step_status="Failed", ) @@ -661,7 +661,7 @@ def func3(): region_name=region_name, role=role, no_of_steps=4, # The FailStep in else branch is not executed - last_step_name="MyConditionStep", + last_step_name_prefix="MyConditionStep", execution_parameters=dict(), step_status="Succeeded", ) @@ -733,7 +733,7 @@ def func(var: int): region_name=region_name, role=role, no_of_steps=2, - last_step_name=process_step.name, + last_step_name_prefix=process_step.name, execution_parameters=dict(), step_status="Succeeded", wait_duration=1000, # seconds @@ -846,7 +846,7 @@ def cuberoot(x): region_name=region_name, role=role, no_of_steps=1, - last_step_name="cuberoot", + last_step_name_prefix="cuberoot", execution_parameters=dict(), step_status="Succeeded", step_result_type=numpy.float64, @@ -890,7 +890,7 @@ def my_func(): region_name=region_name, role=role, no_of_steps=1, - last_step_name=get_step(step_a).name, + last_step_name_prefix=get_step(step_a).name, execution_parameters=dict(), step_status="Failed", ) @@ -950,7 +950,7 @@ def func_with_collision(var: str): region_name=region_name, role=role, no_of_steps=2, - last_step_name=get_step(step_output_b).name, + last_step_name_prefix=get_step(step_output_b).name, execution_parameters=dict(), step_status="Succeeded", step_result_type=str, diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index 315d549cce..019b5561ca 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -626,11 +626,7 @@ def _get_expected_jsonget_expr(step_name: str, path: str): "Std:Join": { "On": "/", "Values": [ - "s3://s3_bucket/test-prefix", - "MyPipeline", - {"Get": "Execution.PipelineExecutionId"}, - step_name, - "results", + {"Get": f"Steps.{step_name}.OutputDataConfig.S3OutputPath"}, "results.json", ], } diff --git a/tests/unit/sagemaker/workflow/test_function_step.py b/tests/unit/sagemaker/workflow/test_function_step.py index 888635ae02..25109fdc97 100644 --- a/tests/unit/sagemaker/workflow/test_function_step.py +++ b/tests/unit/sagemaker/workflow/test_function_step.py @@ -303,15 +303,7 @@ def func() -> type_hint: @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("sagemaker.remote_function.job._JobSettings") -def test_step_function_with_no_hint_on_return_values(mock_job_settings_ctr): - s3_root_uri = "s3://bucket" - mock_job_settings = Mock() - mock_job_settings.s3_root_uri = s3_root_uri - mock_job_settings.sagemaker_session = MOCKED_PIPELINE_CONFIG.sagemaker_session - - mock_job_settings_ctr.return_value = mock_job_settings - +def test_step_function_with_no_hint_on_return_values(): @step(name="step_name") def func(): return 1, 2, 3 @@ -330,11 +322,7 @@ def func(): "Std:Join": { "On": "/", "Values": [ - "s3://bucket", - {"Get": "Execution.PipelineName"}, - {"Get": "Execution.PipelineExecutionId"}, - "step_name", - "results", + {"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}, "results.json", ], } @@ -354,11 +342,7 @@ def func(): "Std:Join": { "On": "/", "Values": [ - "s3://bucket", - {"Get": "Execution.PipelineName"}, - {"Get": "Execution.PipelineExecutionId"}, - "step_name", - "results", + {"Get": "Steps.step_name.OutputDataConfig.S3OutputPath"}, "results.json", ], } @@ -366,8 +350,6 @@ def func(): } } - mock_job_settings_ctr.assert_called_once() - with pytest.raises(NotImplementedError): for _ in step_output: pass diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 040899d883..61e1424bbf 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -279,8 +279,8 @@ def test_json_get_invalid_s3_uri_not_join(): def test_json_get_invalid_s3_uri_with_invalid_pipeline_variable(sagemaker_session): with pytest.raises(ValueError) as e: - JsonGet(s3_uri=Join(on="/", values=["s3:/", Properties(step_name="test")])) + JsonGet(s3_uri=Join(on="/", values=["s3:/", Join()])) assert ( - "The Join values in JsonGet's s3_uri can only be a primitive object or Parameter or ExecutionVariable." - in str(e.value) + "The Join values in JsonGet's s3_uri can only be a primitive object, " + "Parameter, ExecutionVariable or Properties." in str(e.value) ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index d658455d62..14c2d442eb 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -33,7 +33,6 @@ from sagemaker.workflow.pipeline import ( Pipeline, PipelineGraph, - _replace_pipeline_name_in_json_get_s3_uri, ) from sagemaker.workflow.pipeline_context import _PipelineConfig from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig @@ -1009,19 +1008,3 @@ def func(): (parameter, False), (delayed_return, True), ] - - -@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("sagemaker.remote_function.job._JobSettings", Mock()) -@pytest.mark.parametrize("obj, is_replaced", _generate_parameters_for_replace_pipeline_name_test()) -def test_replace_pipeline_name_in_json_get_s3_uri(obj, is_replaced): - updated_obj = _replace_pipeline_name_in_json_get_s3_uri( - obj=obj, - pipeline_name=_PIPELINE_NAME, - ) - if is_replaced: - assert updated_obj != obj - assert "Execution.PipelineName" not in str(updated_obj.expr) - assert _PIPELINE_NAME in str(updated_obj.expr) - else: - assert updated_obj == obj