diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index d8674f269d..7beef2e5bd 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -17,7 +17,7 @@ and interpretation on Amazon SageMaker.
 """
 from __future__ import absolute_import
-
+import json
 import logging
 import os
 import pathlib
@@ -314,6 +314,15 @@ def _normalize_args(
                 "code argument has to be a valid S3 URI or local file path "
                 + "rather than a pipeline variable"
             )
+        if arguments is not None:
+            processed_arguments = []
+            for arg in arguments:
+                if isinstance(arg, PipelineVariable):
+                    processed_value = json.dumps(arg.expr)
+                    processed_arguments.append(processed_value)
+                else:
+                    processed_arguments.append(str(arg))
+            arguments = processed_arguments
 
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index 0dcd7c2495..f94e0791cb 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -824,7 +824,12 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
     processor, run_inputs = spark_processor
     processor.sagemaker_session = pipeline_session
     processor.role = ROLE
-
+    arguments_output = [
+        "--input",
+        "input-data-uri",
+        "--output",
+        '{"Get": "Parameters.MyArgOutput"}',
+    ]
     run_inputs["inputs"] = processing_input
 
     step_args = processor.run(**run_inputs)
@@ -835,7 +840,7 @@
     step_args = get_step_args_helper(step_args, "Processing")
 
-    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+    assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
 
     entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
@@ -1019,6 +1024,12 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_
     processor, run_inputs = spark_processor
     processor.sagemaker_session = pipeline_session
     processor.role = ROLE
+    arguments_output = [
+        "--input",
+        "input-data-uri",
+        "--output",
+        '{"Get": "Parameters.MyArgOutput"}',
+    ]
 
     run_inputs["inputs"] = processing_input
 
@@ -1030,7 +1041,7 @@
     step_args = get_step_args_helper(step_args, "Processing")
 
-    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+    assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
 
     entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
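For reference, the serialization added to _normalize_args above hinges on PipelineVariable.expr, the request-syntax dict every pipeline variable exposes. A minimal sketch of the behavior, using only the public sagemaker.workflow classes (illustration only, not part of the diff):

import json

from sagemaker.workflow.parameters import ParameterString

arg = ParameterString(name="MyArgOutput")
# .expr is the variable's request-syntax dict: {"Get": "Parameters.MyArgOutput"}
serialized = json.dumps(arg.expr)
assert serialized == '{"Get": "Parameters.MyArgOutput"}'
# Non-variables fall through to str(arg), which is what the tests above assert.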
diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py
index 06d2cde02e..7b020c61bf 100644
--- a/tests/unit/test_processing.py
+++ b/tests/unit/test_processing.py
@@ -46,8 +46,9 @@
 from sagemaker.fw_utils import UploadedCode
 from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
 from sagemaker.workflow.functions import Join
-from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
 from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB
+from sagemaker.workflow.parameters import ParameterString
 
 BUCKET_NAME = "mybucket"
 REGION = "us-west-2"
@@ -1717,3 +1718,249 @@ def _get_describe_response_inputs_and_ouputs():
         "ProcessingInputs": _get_expected_args_all_parameters(None)["inputs"],
         "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"],
     }
+
+
+# Parameters
+def _get_data_inputs_with_parameters():
+    return [
+        ProcessingInput(
+            source=ParameterString(name="input_data", default_value="s3://dummy-bucket/input"),
+            destination="/opt/ml/processing/input",
+            input_name="input-1",
+        )
+    ]
+
+
+def _get_data_outputs_with_parameters():
+    return [
+        ProcessingOutput(
+            source="/opt/ml/processing/output",
+            destination=ParameterString(
+                name="output_data", default_value="s3://dummy-bucket/output"
+            ),
+            output_name="output-1",
+        )
+    ]
+
+
+def _get_expected_args_with_parameters(job_name):
+    return {
+        "inputs": [
+            {
+                "InputName": "input-1",
+                "S3Input": {
+                    "S3Uri": "s3://dummy-bucket/input",
+                    "LocalPath": "/opt/ml/processing/input",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None",
+                },
+            }
+        ],
+        "output_config": {
+            "Outputs": [
+                {
+                    "OutputName": "output-1",
+                    "S3Output": {
+                        "S3Uri": "s3://dummy-bucket/output",
+                        "LocalPath": "/opt/ml/processing/output",
+                        "S3UploadMode": "EndOfJob",
+                    },
+                }
+            ]
+        },
+        "job_name": job_name,
+        "resources": {
+            "ClusterConfig": {
+                "InstanceType": "ml.m4.xlarge",
+                "InstanceCount": 1,
+                "VolumeSizeInGB": 100,
+                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
+            }
+        },
+        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
+        "app_specification": {
+            "ImageUri": "custom-image-uri",
+            "ContainerArguments": [
+                "--input-data",
+                "s3://dummy-bucket/input-param",
+                "--output-path",
+                "s3://dummy-bucket/output-param",
+            ],
+            "ContainerEntrypoint": ["python3"],
+        },
+        "environment": {"my_env_variable": "my_env_variable_value"},
+        "network_config": {
+            "EnableNetworkIsolation": True,
+            "EnableInterContainerTrafficEncryption": True,
+            "VpcConfig": {
+                "Subnets": ["my_subnet_id"],
+                "SecurityGroupIds": ["my_security_group_id"],
+            },
+        },
+        "role_arn": "dummy/role",
+        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
+        "experiment_config": {"ExperimentName": "AnExperiment"},
+    }
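The expected ContainerArguments in the test that follows are simply json.dumps over each variable's .expr. A quick reference for the three variable types the test exercises (illustration only, not part of the diff; the classes are the real sagemaker.workflow ones):

from sagemaker.workflow.execution_variables import ExecutionVariable
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

print(ParameterString(name="input_param").expr)
# {'Get': 'Parameters.input_param'}
print(ExecutionVariable(name="ExecutionTest").expr)
# {'Get': 'Execution.ExecutionTest'}
print(Join(on="/", values=["s3://bucket", "prefix", "file.txt"]).expr)
# {'Std:Join': {'On': '/', 'Values': ['s3://bucket', 'prefix', 'file.txt']}}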
+
+
+@patch("os.path.exists", return_value=True)
+@patch("os.path.isfile", return_value=True)
+@patch("sagemaker.utils.repack_model")
+@patch("sagemaker.utils.create_tar_file")
+@patch("sagemaker.session.Session.upload_data")
+def test_script_processor_with_parameter_string(
+    upload_data_mock,
+    create_tar_file_mock,
+    repack_model_mock,
+    isfile_mock,
+    exists_mock,
+    sagemaker_session,
+):
+    """Test ScriptProcessor with ParameterString arguments."""
+    upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data"
+
+    # Set up the processor
+    processor = ScriptProcessor(
+        role="arn:aws:iam::012345678901:role/SageMakerRole",
+        image_uri="custom-image-uri",
+        command=["python3"],
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        volume_size_in_gb=100,
+        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
+        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
+        max_runtime_in_seconds=3600,
+        base_job_name="test_processor",
+        env={"my_env_variable": "my_env_variable_value"},
+        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
+        network_config=NetworkConfig(
+            subnets=["my_subnet_id"],
+            security_group_ids=["my_security_group_id"],
+            enable_network_isolation=True,
+            encrypt_inter_container_traffic=True,
+        ),
+        sagemaker_session=sagemaker_session,
+    )
+
+    input_param = ParameterString(
+        name="input_param", default_value="s3://dummy-bucket/input-param"
+    )
+    output_param = ParameterString(
+        name="output_param", default_value="s3://dummy-bucket/output-param"
+    )
+    exec_var = ExecutionVariable(name="ExecutionTest")
+    join_var = Join(on="/", values=["s3://bucket", "prefix", "file.txt"])
+    dummy_str_var = "test-variable"
+
+    # Define expected arguments
+    expected_args = {
+        "inputs": [
+            {
+                "InputName": "input-1",
+                "AppManaged": False,
+                "S3Input": {
+                    "S3Uri": ParameterString(
+                        name="input_data", default_value="s3://dummy-bucket/input"
+                    ),
+                    "LocalPath": "/opt/ml/processing/input",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None",
+                },
+            },
+            {
+                "InputName": "code",
+                "AppManaged": False,
+                "S3Input": {
+                    "S3Uri": "s3://mocked_s3_uri_from_upload_data",
+                    "LocalPath": "/opt/ml/processing/input/code",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None",
+                },
+            },
+        ],
+        "output_config": {
+            "Outputs": [
+                {
+                    "OutputName": "output-1",
+                    "AppManaged": False,
+                    "S3Output": {
+                        "S3Uri": ParameterString(
+                            name="output_data", default_value="s3://dummy-bucket/output"
+                        ),
+                        "LocalPath": "/opt/ml/processing/output",
+                        "S3UploadMode": "EndOfJob",
+                    },
+                }
+            ],
+            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
+        },
+        "job_name": "test_job",
+        "resources": {
+            "ClusterConfig": {
+                "InstanceType": "ml.m4.xlarge",
+                "InstanceCount": 1,
+                "VolumeSizeInGB": 100,
+                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
+            }
+        },
+        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
+        "app_specification": {
+            "ImageUri": "custom-image-uri",
+            "ContainerArguments": [
+                "--input-data",
+                '{"Get": "Parameters.input_param"}',
+                "--output-path",
+                '{"Get": "Parameters.output_param"}',
+                "--exec-arg",
+                '{"Get": "Execution.ExecutionTest"}',
+                "--join-arg",
+                '{"Std:Join": {"On": "/", "Values": ["s3://bucket", "prefix", "file.txt"]}}',
+                "--string-param",
+                "test-variable",
+            ],
+            "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"],
+        },
+        "environment": {"my_env_variable": "my_env_variable_value"},
+        "network_config": {
+            "EnableNetworkIsolation": True,
+            "EnableInterContainerTrafficEncryption": True,
+            "VpcConfig": {
+                "SecurityGroupIds": ["my_security_group_id"],
+                "Subnets": ["my_subnet_id"],
+            },
+        },
+        "role_arn": "arn:aws:iam::012345678901:role/SageMakerRole",
+        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
+        "experiment_config": {"ExperimentName": "AnExperiment"},
+    }
+
+    # Run the processor, passing pipeline variables alongside a plain string
+    processor.run(
+        code="/local/path/to/processing_code.py",
+        inputs=_get_data_inputs_with_parameters(),
+        outputs=_get_data_outputs_with_parameters(),
+        arguments=[
+            "--input-data",
+            input_param,
+            "--output-path",
+            output_param,
+            "--exec-arg",
+            exec_var,
+            "--join-arg",
+            join_var,
+            "--string-param",
+            dummy_str_var,
+        ],
+        wait=True,
+        logs=False,
+        job_name="test_job",
+        experiment_config={"ExperimentName": "AnExperiment"},
+    )
+
+    # Assert that the serialized pipeline variables reached the session call
+    sagemaker_session.process.assert_called_with(**expected_args)
+    assert "test_job" in processor._current_job_name
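Taken together, the change lets callers pass pipeline variables directly in arguments when a processor runs under a pipeline session, with each variable serialized to its request-syntax JSON string and plain values coerced with str(). A minimal end-to-end sketch (the role ARN, image URI, and script path are placeholders, not values from this PR):

from sagemaker.processing import ScriptProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession

output_param = ParameterString(name="MyArgOutput", default_value="s3://bucket/out")
processor = ScriptProcessor(
    role="arn:aws:iam::111111111111:role/SageMakerRole",  # placeholder role
    image_uri="custom-image-uri",  # placeholder image
    command=["python3"],
    instance_type="ml.m4.xlarge",
    instance_count=1,
    sagemaker_session=PipelineSession(),
)
# Under a PipelineSession, run() returns deferred step arguments instead of
# starting a job; ContainerArguments in those step arguments now carry
# '{"Get": "Parameters.MyArgOutput"}' for the pipeline variable.
step_args = processor.run(
    code="process.py",  # placeholder script path
    arguments=["--input", "input-data-uri", "--output", output_param],
)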