
Commit b84a818

support pipeline variables for spark processors run arguments

1 parent: 2b5b4da

File tree

2 files changed: +199 -87 lines

src/sagemaker/spark/processing.py
tests/unit/sagemaker/workflow/test_processing_step.py
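
This change lets pipeline variables (e.g. `ParameterString`) flow through the Spark processors' constructor and `run()` arguments. A minimal usage sketch; the role ARN, bucket names, and parameter names are illustrative:

```python
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession

processor = PySparkProcessor(
    role="arn:aws:iam::111111111111:role/MyRole",  # illustrative role ARN
    framework_version="2.4",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=PipelineSession(),
)

# Pipeline parameters may now appear in the submit_* lists, arguments, and
# spark_event_logs_s3_uri; they stay unresolved until pipeline execution.
step_args = processor.run(
    submit_app="s3://my-bucket/code/app.py",  # must remain a concrete URI/path
    submit_jars=["s3://my-jar", ParameterString("MyJars")],
    arguments=["--output", ParameterString("MyArgOutput")],
    spark_event_logs_s3_uri=ParameterString("MySparkEventLogS3Uri"),
)
```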

src/sagemaker/spark/processing.py

Lines changed: 85 additions & 62 deletions
```diff
@@ -31,13 +31,20 @@
 from io import BytesIO
 from urllib.parse import urlparse
 
+from typing import Union, List, Dict, Optional
+
 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults
 
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.functions import Join
+
 logger = logging.getLogger(__name__)
 
 
```

```diff
@@ -249,6 +256,12 @@ def run(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
+        if is_pipeline_variable(submit_app):
+            raise ValueError(
+                "submit_app argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         return super().run(
             submit_app,
             inputs,
```
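
The guard above means `submit_app` itself must stay a concrete S3 URI or local path even inside a pipeline. Continuing the illustrative sketch from the top of this page:

```python
# submit_app may not be a pipeline variable; run() now fails fast:
try:
    processor.run(submit_app=ParameterString("MyApp"))  # "MyApp" is illustrative
except ValueError as err:
    print(err)
    # submit_app argument has to be a valid S3 URI or local file path
    # rather than a pipeline variable
```
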
```diff
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
 
         use_input_channel = False
         spark_opt_s3_uris = []
+        spark_opt_s3_uris_has_pipeline_var = False
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for dep_path in submit_deps:
+                if is_pipeline_variable(dep_path):
+                    spark_opt_s3_uris.append(dep_path)
+                    spark_opt_s3_uris_has_pipeline_var = True
+                    continue
                 dep_url = urlparse(dep_path)
                 # S3 URIs are included as-is in the spark-submit argument
                 if dep_url.scheme in ["s3", "s3a"]:
```

```diff
@@ -482,11 +500,13 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                 destination=f"{self._conf_container_base_path}{input_channel_name}",
                 input_name=input_channel_name,
             )
-            spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
+            spark_opt = Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) \
+                if spark_opt_s3_uris_has_pipeline_var else ",".join(spark_opt_s3_uris + [input_channel.destination])
         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
         else:
             input_channel = None
-            spark_opt = ",".join(spark_opt_s3_uris)
+            spark_opt = Join(on=",", values=spark_opt_s3_uris) if spark_opt_s3_uris_has_pipeline_var \
+                else ",".join(spark_opt_s3_uris)
 
         return input_channel, spark_opt
 
```
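
When any dependency is a pipeline variable, the comma-separated spark-submit option is built with `sagemaker.workflow.functions.Join` instead of `",".join(...)`, so it is resolved only at pipeline execution. A small sketch of how such a `Join` renders in a pipeline definition (parameter name illustrative):

```python
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

opt = Join(on=",", values=["s3://my-jar", ParameterString("MyJars")])
print(opt.expr)
# {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'}]}}
```
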
```diff
@@ -592,6 +612,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
         Args:
             spark_output_s3_path (str): The URI of the Spark output S3 Path.
         """
+        if is_pipeline_variable(spark_output_s3_path):
+            return
+
         if urlparse(spark_output_s3_path).scheme != "s3":
             raise ValueError(
                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
```

```diff
@@ -650,22 +673,22 @@ class PySparkProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``PySparkProcessor`` instance.
 
```

```diff
@@ -795,20 +818,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
 
```

```diff
@@ -907,22 +930,22 @@ class SparkJarProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``SparkJarProcessor`` instance.
 
```

```diff
@@ -1052,20 +1075,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_class: Union[str, PipelineVariable],
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
 
```
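
The analogous sketch for `SparkJarProcessor`, mirroring the new unit test below (class name, role ARN, and parameter names illustrative):

```python
from sagemaker.spark.processing import SparkJarProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession

jar_processor = SparkJarProcessor(
    role="arn:aws:iam::111111111111:role/MyRole",  # illustrative role ARN
    framework_version="2.4",
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=PipelineSession(),
)

# Pipeline variables can now appear in submit_jars/submit_files and arguments.
step_args = jar_processor.run(
    submit_app="s3://my-bucket/code/app.jar",  # must remain concrete
    submit_class="com.example.MyApp",          # illustrative class name
    submit_jars=["s3://my-jar", ParameterString("MyJars")],
    arguments=["--output", ParameterString("MyArgOutput")],
)
```
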
tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 114 additions & 25 deletions
```diff
@@ -23,6 +23,8 @@
 from sagemaker.transformer import Transformer
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession
+from sagemaker.workflow.functions import Join
+from sagemaker.workflow.parameters import ParameterString
 
 from sagemaker.processing import (
     Processor,
```

```diff
@@ -46,6 +48,7 @@
 from sagemaker.workflow.properties import PropertyFile
 from sagemaker.workflow.parameters import ParameterString
 from sagemaker.workflow.functions import Join
+from sagemaker.workflow import is_pipeline_variable
 
 from sagemaker.network import NetworkConfig
 from sagemaker.pytorch.estimator import PyTorch
```

```diff
@@ -149,31 +152,6 @@
         ),
         {},
     ),
-    (
-        SparkJarProcessor(
-            role=ROLE,
-            framework_version="2.4",
-            instance_count=1,
-            instance_type=INSTANCE_TYPE,
-        ),
-        {
-            "submit_app": "s3://my-jar",
-            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-        },
-    ),
-    (
-        PySparkProcessor(
-            role=ROLE,
-            framework_version="2.4",
-            instance_count=1,
-            instance_type=INSTANCE_TYPE,
-        ),
-        {
-            "submit_app": "s3://my-jar",
-            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-        },
-    ),
 ]
 
 PROCESSING_INPUT = [
```

```diff
@@ -623,3 +601,114 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
     assert "The step_args of ProcessingStep must be obtained from processor.run()" in str(
         error.value
     )
+
+
+@pytest.mark.parametrize(
+    "spark_processor",
+    [
+        (
+            SparkJarProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": "s3://my-jar",
+                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
+                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
+                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files", ParameterString("OurFiles")],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+            },
+        ),
+        (
+            PySparkProcessor(
+                role=ROLE,
+                framework_version="2.4",
+                instance_count=1,
+                instance_type=INSTANCE_TYPE,
+            ),
+            {
+                "submit_app": "s3://my-jar",
+                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
+                "submit_py_files": ["s3://my-py-files", ParameterString("MyPyFiles"), "s3://her-pyfiles", ParameterString("OurPyFiles")],
+                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
+                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files", ParameterString("OurFiles")],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+            },
+        ),
+    ],
+)
+def test_spark_processor(spark_processor, processing_input, pipeline_session):
+    processor, run_inputs = spark_processor
+    processor.sagemaker_session = pipeline_session
+    processor.role = ROLE
+
+    run_inputs["inputs"] = processing_input
+
+    step_args = processor.run(**run_inputs)
+    step = ProcessingStep(
+        name="MyProcessingStep",
+        step_args=step_args,
+    )
+
+    step_args = step_args.args
+
+    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+
+    entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
+    entry_points_expr = []
+    for entry_point in entry_points:
+        if is_pipeline_variable(entry_point):
+            entry_points_expr.append(entry_point.expr)
+        else:
+            entry_points_expr.append(entry_point)
+
+    if "submit_py_files" in run_inputs:
+        # PySparkProcessor case
+        expected = [
+            "smspark-submit",
+            "--py-files",
+            {"Std:Join": {"On": ",", "Values": ["s3://my-py-files", {"Get": "Parameters.MyPyFiles"}, "s3://her-pyfiles", {"Get": "Parameters.OurPyFiles"}]}},
+            "--jars",
+            {"Std:Join": {"On": ",", "Values": ["s3://my-jar", {"Get": "Parameters.MyJars"}, "s3://her-jar", {"Get": "Parameters.OurJar"}]}},
+            "--files",
+            {"Std:Join": {"On": ",", "Values": ["s3://my-files", {"Get": "Parameters.MyFiles"}, "s3://her-files", {"Get": "Parameters.OurFiles"}]}},
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
+        ]
+    else:
+        # SparkJarProcessor case
+        expected = [
+            "smspark-submit",
+            "--class",
+            "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "--jars",
+            {"Std:Join": {"On": ",", "Values": ["s3://my-jar", {"Get": "Parameters.MyJars"}, "s3://her-jar", {"Get": "Parameters.OurJar"}]}},
+            "--files",
+            {"Std:Join": {"On": ",", "Values": ["s3://my-files", {"Get": "Parameters.MyFiles"}, "s3://her-files", {"Get": "Parameters.OurFiles"}]}},
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
+        ]
+
+    assert entry_points_expr == expected
+
+    for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+        if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+            output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr
+
+    assert step_args["ProcessingOutputConfig"]["Outputs"] == [
+        {
+            "OutputName": "output-1",
+            "AppManaged": False,
+            "S3Output": {
+                "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+                "LocalPath": "/opt/ml/processing/spark-events/",
+                "S3UploadMode": "Continuous",
+            },
+        }
+    ]
+
+    pipeline = Pipeline(
+        name="MyPipeline",
+        steps=[step],
+        sagemaker_session=pipeline_session,
+    )
+    pipeline.definition()
```
