
Commit e918454

jerrypeng7773 authored and JoseJuan98 committed
fix: support pipeline variables for spark processors run arguments (aws#3167)
1 parent d94cf5e commit e918454
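
In practice, the change lets spark processor run arguments carry SageMaker Pipelines variables that only resolve at execution time. A minimal sketch of the newly supported usage (the bucket name, role ARN, and framework version below are illustrative, not taken from the commit):

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

bucket = ParameterString(name="ArtifactBucket", default_value="my-bucket")

processor = PySparkProcessor(
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    framework_version="3.1",
    instance_type="ml.m5.xlarge",
    instance_count=2,
)

# Dependency URIs and the event-log destination may now be pipeline variables;
# they resolve to concrete S3 URIs only when the pipeline step executes.
run_args = processor.get_run_args(
    submit_app="s3://my-bucket/code/app.py",  # must stay a concrete URI (see below)
    submit_jars=[Join(on="/", values=["s3:/", bucket, "jars/deps.jar"])],
    spark_event_logs_s3_uri=Join(on="/", values=["s3:/", bucket, "spark-events"]),
)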

File tree

2 files changed: +293 -87 lines changed

src/sagemaker/spark/processing.py

Lines changed: 91 additions & 62 deletions
@@ -31,13 +31,20 @@
 from io import BytesIO
 from urllib.parse import urlparse
 
+from typing import Union, List, Dict, Optional
+
 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults
 
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.functions import Join
+
 logger = logging.getLogger(__name__)
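For orientation: is_pipeline_variable is, in effect, an isinstance check against PipelineVariable, the shared base class of pipeline parameters, Join expressions, execution variables, and step properties. A quick illustration:

from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString

print(isinstance(ParameterString(name="X"), PipelineVariable))  # True
print(isinstance("s3://bucket/key", PipelineVariable))          # False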
@@ -249,6 +256,12 @@ def run(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
+        if is_pipeline_variable(submit_app):
+            raise ValueError(
+                "submit_app argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         return super().run(
             submit_app,
             inputs,
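
The new guard turns what used to be an obscure failure during code staging into an immediate, explicit error. A short example (reusing the hypothetical processor from the sketch above):

from sagemaker.workflow.parameters import ParameterString

try:
    processor.run(submit_app=ParameterString(name="AppUri"), wait=False)
except ValueError as err:
    # submit_app argument has to be a valid S3 URI or local file path
    # rather than a pipeline variable
    print(err)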
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
 
         use_input_channel = False
         spark_opt_s3_uris = []
+        spark_opt_s3_uris_has_pipeline_var = False
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for dep_path in submit_deps:
+                if is_pipeline_variable(dep_path):
+                    spark_opt_s3_uris.append(dep_path)
+                    spark_opt_s3_uris_has_pipeline_var = True
+                    continue
                 dep_url = urlparse(dep_path)
                 # S3 URIs are included as-is in the spark-submit argument
                 if dep_url.scheme in ["s3", "s3a"]:
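
The early continue exists because a pipeline variable has no concrete value at pipeline definition time: urlparse cannot classify it the way it classifies the literal paths handled below, so the variable is assumed to resolve to an S3 URI at execution and is passed through to the spark-submit option unchanged. For contrast, a sketch:

from urllib.parse import urlparse
from sagemaker.workflow.parameters import ParameterString

print(urlparse("s3://bucket/deps.jar").scheme)  # "s3": included as-is
print(urlparse("/tmp/deps.jar").scheme)         # "": local file, uploaded to S3
# urlparse(ParameterString(name="DepsJar")) would fail, since a pipeline
# variable is not a str; hence the loop skips parsing it entirely.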
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                 destination=f"{self._conf_container_base_path}{input_channel_name}",
                 input_name=input_channel_name,
             )
-            spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris + [input_channel.destination])
+            )
         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
         else:
             input_channel = None
-            spark_opt = ",".join(spark_opt_s3_uris)
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris)
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris)
+            )
 
         return input_channel, spark_opt
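This branch is the core of the fix: str.join requires every element to already be a string, while sagemaker.workflow.functions.Join builds a Std:Join expression that the Pipelines backend concatenates at execution time. Roughly (the printed shape is indicative):

from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

extra_jar = ParameterString(name="ExtraJar")
uris = ["s3://bucket/a.jar", extra_jar]

# ",".join(uris) would raise TypeError: sequence item 1: expected str instance
spark_opt = Join(on=",", values=uris)
print(spark_opt.expr)
# {'Std:Join': {'On': ',', 'Values': ['s3://bucket/a.jar',
#                                     {'Get': 'Parameters.ExtraJar'}]}}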
@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
         Args:
             spark_output_s3_path (str): The URI of the Spark output S3 Path.
         """
+        if is_pipeline_variable(spark_output_s3_path):
+            return
+
         if urlparse(spark_output_s3_path).scheme != "s3":
             raise ValueError(
                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``PySparkProcessor`` instance.
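
With the annotations in place, the instance settings themselves can be pipeline parameters, which the SDK resolves when the pipeline executes. A sketch (names illustrative):

from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterInteger, ParameterString

processor = PySparkProcessor(
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    framework_version="3.1",
    instance_type=ParameterString(name="InstanceType", default_value="ml.m5.xlarge"),
    instance_count=ParameterInteger(name="InstanceCount", default_value=2),
)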
@@ -795,20 +824,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
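
Per the commit title, entries of arguments may now be pipeline variables too; they are resolved when the step executes and handed to the Spark application as command-line arguments. For example (hypothetical parameter and paths):

from sagemaker.workflow.parameters import ParameterString

input_uri = ParameterString(name="InputUri", default_value="s3://my-bucket/input")

run_args = processor.get_run_args(
    submit_app="s3://my-bucket/code/app.py",
    arguments=["--input", input_uri, "--mode", "batch"],
)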
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``SparkJarProcessor`` instance.
@@ -1052,20 +1081,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_class: Union[str, PipelineVariable],
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
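
The SparkJarProcessor counterpart mirrors the PySpark changes; note that submit_class is annotated without a default, making it effectively required. A closing sketch (class name and paths illustrative, input_uri from the previous example):

from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    framework_version="3.1",
    instance_type="ml.m5.xlarge",
    instance_count=2,
)

run_args = jar_processor.get_run_args(
    submit_app="s3://my-bucket/code/spark-app.jar",
    submit_class="com.example.SparkApp",
    arguments=["--input", input_uri],
)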
