diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py
index aab9279f78..90f6a3d8ae 100644
--- a/src/sagemaker/spark/processing.py
+++ b/src/sagemaker/spark/processing.py
@@ -31,13 +31,20 @@
 from io import BytesIO
 from urllib.parse import urlparse
 
+from typing import Union, List, Dict, Optional
+
 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.functions import Join
+
 
 logger = logging.getLogger(__name__)
@@ -249,6 +256,12 @@ def run(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
+        if is_pipeline_variable(submit_app):
+            raise ValueError(
+                "submit_app argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         return super().run(
             submit_app,
             inputs,
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         use_input_channel = False
         spark_opt_s3_uris = []
+        spark_opt_s3_uris_has_pipeline_var = False
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for dep_path in submit_deps:
+                if is_pipeline_variable(dep_path):
+                    spark_opt_s3_uris.append(dep_path)
+                    spark_opt_s3_uris_has_pipeline_var = True
+                    continue
                 dep_url = urlparse(dep_path)
                 # S3 URIs are included as-is in the spark-submit argument
                 if dep_url.scheme in ["s3", "s3a"]:
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                 destination=f"{self._conf_container_base_path}{input_channel_name}",
                 input_name=input_channel_name,
             )
-            spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris + [input_channel.destination])
+            )
         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
         else:
             input_channel = None
-            spark_opt = ",".join(spark_opt_s3_uris)
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris)
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris)
+            )
 
         return input_channel, spark_opt
@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
         Args:
            spark_output_s3_path (str): The URI of the Spark output S3 Path.
         """
+        if is_pipeline_variable(spark_output_s3_path):
+            return
+
         if urlparse(spark_output_s3_path).scheme != "s3":
             raise ValueError(
                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase):
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize an ``PySparkProcessor`` instance.
@@ -795,20 +824,20 @@ def get_run_args(
     def run(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase):
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``SparkJarProcessor`` instance.
@@ -1052,20 +1081,20 @@ def get_run_args(
     def run(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_class: Union[str, PipelineVariable],
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
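Note (editor's illustration, not part of the patch): the processing.py changes above allow list-type submit_* arguments and spark_event_logs_s3_uri to carry pipeline variables, which _stage_submit_deps now folds into a workflow Join instead of a plain ",".join. A minimal usage sketch follows; the role ARN, bucket names, and parameter names are placeholders, and it assumes an SDK version that ships sagemaker.workflow.pipeline_context.PipelineSession.

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep

pipeline_session = PipelineSession()

# Pipeline parameters that will flow into the spark-submit options.
extra_jars = ParameterString(name="ExtraJars")
event_logs_uri = ParameterString(name="SparkEventLogsUri")

processor = PySparkProcessor(
    role="arn:aws:iam::111111111111:role/SageMakerRole",  # placeholder role ARN
    framework_version="2.4",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=pipeline_session,
)

# submit_app must remain a concrete S3 URI or local path (run() now raises on a
# pipeline variable); submit_jars and spark_event_logs_s3_uri may mix literal
# URIs with ParameterString values.
step_args = processor.run(
    submit_app="s3://my-bucket/code/app.py",
    submit_jars=["s3://my-bucket/jars/dep.jar", extra_jars],
    arguments=["--input", "s3://my-bucket/input"],
    spark_event_logs_s3_uri=event_logs_uri,
)

step = ProcessingStep(name="SparkStep", step_args=step_args)
pipeline = Pipeline(
    name="MySparkPipeline",
    parameters=[extra_jars, event_logs_uri],
    steps=[step],
    sagemaker_session=pipeline_session,
)
print(pipeline.definition())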
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index f2347bbf11..262d0eb558 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -46,6 +46,7 @@
 from sagemaker.workflow.properties import PropertyFile
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.functions import Join
+from sagemaker.workflow import is_pipeline_variable
 from sagemaker.network import NetworkConfig
 from sagemaker.pytorch.estimator import PyTorch
@@ -149,31 +150,6 @@
         ),
         {},
     ),
-    (
-        SparkJarProcessor(
-            role=ROLE,
-            framework_version="2.4",
-            instance_count=1,
-            instance_type=INSTANCE_TYPE,
-        ),
-        {
-            "submit_app": "s3://my-jar",
-            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-        },
-    ),
-    (
-        PySparkProcessor(
-            role=ROLE,
-            framework_version="2.4",
-            instance_count=1,
-            instance_type=INSTANCE_TYPE,
-        ),
-        {
-            "submit_app": "s3://my-jar",
-            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-        },
-    ),
 ]
 
 PROCESSING_INPUT = [
@@ -641,3 +617,204 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
     assert "The step_args of ProcessingStep must be obtained from processor.run()" in str(
         error.value
     )
+
+
+@pytest.mark.parametrize(
+    "spark_processor",
+    [
+        (
+            SparkJarProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": "s3://my-jar",
+                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+            },
+        ),
+        (
+            PySparkProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": "s3://my-jar",
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_py_files": [
+                    "s3://my-py-files",
+                    ParameterString("MyPyFiles"),
+                    "s3://her-pyfiles",
+                    ParameterString("OurPyFiles"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+            },
+        ),
+    ],
+)
+def test_spark_processor(spark_processor, processing_input, pipeline_session):
+
+    processor, run_inputs = spark_processor
+    processor.sagemaker_session = pipeline_session
+    processor.role = ROLE
+
+    run_inputs["inputs"] = processing_input
+
+    step_args = processor.run(**run_inputs)
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        step_args=step_args,
+    )
+
+    step_args = step_args.args
+
+    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+
+    entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
+    entry_points_expr = []
+    for entry_point in entry_points:
+        if is_pipeline_variable(entry_point):
+            entry_points_expr.append(entry_point.expr)
+        else:
+            entry_points_expr.append(entry_point)
+
+    if "submit_py_files" in run_inputs:
+        expected = [
+            "smspark-submit",
+            "--py-files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-py-files",
+                        {"Get": "Parameters.MyPyFiles"},
+                        "s3://her-pyfiles",
+                        {"Get": "Parameters.OurPyFiles"},
+                    ],
+                }
+            },
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
+        ]
+    # SparkJarProcessor case (no submit_py_files)
+    else:
+        expected = [
+            "smspark-submit",
+            "--class",
+            "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
+        ]
+
+    assert entry_points_expr == expected
+    for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+        if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+            output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr
+
+    assert step_args["ProcessingOutputConfig"]["Outputs"] == [
+        {
+            "OutputName": "output-1",
+            "AppManaged": False,
+            "S3Output": {
+                "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+                "LocalPath": "/opt/ml/processing/spark-events/",
+                "S3UploadMode": "Continuous",
+            },
+        }
+    ]
+
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+    pipeline.definition()
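Note (editor's illustration, not part of the patch): the expected values in the new test are the serialized forms of pipeline variables. A ParameterString renders into the pipeline definition through its .expr property as a {"Get": ...} reference, and the Join that _stage_submit_deps now builds renders as a {"Std:Join": ...} structure, which is why the test swaps each pipeline variable for entry_point.expr before comparing. A small sketch:

from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

my_jars = ParameterString(name="MyJars")
print(my_jars.expr)
# {'Get': 'Parameters.MyJars'}

spark_opt = Join(on=",", values=["s3://my-jar", my_jars, "s3://her-jar"])
print(spark_opt.expr)
# {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'}, 's3://her-jar']}}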