fix: support pipeline variables for spark processors run arguments #3167

Merged
8 commits merged on Jul 12, 2022
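
This PR lets the Spark processors accept SageMaker Pipelines variables (pipeline parameters, step properties, `Join` expressions) for most `run`/`get_run_args` arguments; `submit_app` is the exception and must remain a concrete path. A minimal sketch of the usage this enables, with hypothetical bucket, role, and parameter names:

```python
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString

# Pipeline parameters resolve to concrete strings only at pipeline execution time.
extra_py_files = ParameterString(name="ExtraPyFilesS3Uri")
event_logs_uri = ParameterString(name="SparkEventLogsS3Uri")

processor = PySparkProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role ARN
    instance_type="ml.m5.xlarge",
    instance_count=2,
    framework_version="3.1",
)

run_args = processor.get_run_args(
    # submit_app cannot be a pipeline variable; run() raises ValueError for one.
    submit_app="s3://amzn-s3-demo-bucket/code/app.py",
    submit_py_files=[extra_py_files],        # pipeline variable: now accepted
    spark_event_logs_s3_uri=event_logs_uri,  # pipeline variable: now accepted
)
```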
153 changes: 91 additions & 62 deletions
src/sagemaker/spark/processing.py
@@ -31,13 +31,20 @@
from io import BytesIO
from urllib.parse import urlparse

from typing import Union, List, Dict, Optional

from sagemaker import image_uris
from sagemaker.local.image import _ecr_login_if_needed, _pull_image
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
from sagemaker.s3 import S3Uploader
from sagemaker.session import Session
from sagemaker.network import NetworkConfig
from sagemaker.spark import defaults

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.functions import Join

logger = logging.getLogger(__name__)


@@ -249,6 +256,12 @@ def run(
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

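# submit_app is parsed and staged when the step is defined, so a pipeline
# variable (which resolves only at execution time) cannot be accepted here.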
if is_pipeline_variable(submit_app):
raise ValueError(
"submit_app argument has to be a valid S3 URI or local file path "
+ "rather than a pipeline variable"
)

return super().run(
submit_app,
inputs,
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):

use_input_channel = False
spark_opt_s3_uris = []
spark_opt_s3_uris_has_pipeline_var = False

with tempfile.TemporaryDirectory() as tmpdir:
for dep_path in submit_deps:
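# A pipeline variable has no concrete value at definition time, so it
# cannot be staged locally; collect it for Join to resolve at execution.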
if is_pipeline_variable(dep_path):
spark_opt_s3_uris.append(dep_path)
spark_opt_s3_uris_has_pipeline_var = True
continue
dep_url = urlparse(dep_path)
# S3 URIs are included as-is in the spark-submit argument
if dep_url.scheme in ["s3", "s3a"]:
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
destination=f"{self._conf_container_base_path}{input_channel_name}",
input_name=input_channel_name,
)
spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
spark_opt = (
Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
if spark_opt_s3_uris_has_pipeline_var
else ",".join(spark_opt_s3_uris + [input_channel.destination])
)
# If no local files were uploaded, form the spark-submit option from a list of S3 URIs
else:
input_channel = None
spark_opt = ",".join(spark_opt_s3_uris)
spark_opt = (
Join(on=",", values=spark_opt_s3_uris)
if spark_opt_s3_uris_has_pipeline_var
else ",".join(spark_opt_s3_uris)
)

return input_channel, spark_opt

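Because a pipeline variable carries no string value until the pipeline executes, the spark-submit option can no longer be assembled with a plain `",".join(...)` whenever a dependency is a pipeline variable; the diff falls back to the `Join` pipeline function in that case. An illustrative sketch of the difference, with a hypothetical parameter name:

```python
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

extra_jar = ParameterString(name="ExtraJarS3Uri")  # hypothetical pipeline parameter

# A plain string join fails at definition time because the parameter is not
# a str yet (it is a placeholder object):
#   ",".join([extra_jar, "s3://amzn-s3-demo-bucket/jars/dep.jar"])  # TypeError

# Join is serialized into the pipeline definition instead and evaluated by
# SageMaker Pipelines when the step actually runs:
spark_opt = Join(on=",", values=[extra_jar, "s3://amzn-s3-demo-bucket/jars/dep.jar"])
```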
@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
Args:
spark_output_s3_path (str): The URI of the Spark output S3 path.
"""
if is_pipeline_variable(spark_output_s3_path):
return

if urlparse(spark_output_s3_path).scheme != "s3":
raise ValueError(
f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase):

def __init__(
self,
role,
instance_type,
instance_count,
framework_version=None,
py_version=None,
container_version=None,
image_uri=None,
volume_size_in_gb=30,
volume_kms_key=None,
output_kms_key=None,
max_runtime_in_seconds=None,
base_job_name=None,
sagemaker_session=None,
env=None,
tags=None,
network_config=None,
role: str,
instance_type: Union[str, PipelineVariable],
instance_count: Union[int, PipelineVariable],
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
container_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
volume_size_in_gb: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
network_config: Optional[NetworkConfig] = None,
):
"""Initialize an ``PySparkProcessor`` instance.

@@ -795,20 +824,20 @@ def get_run_args(

def run(
self,
submit_app,
submit_py_files=None,
submit_jars=None,
submit_files=None,
inputs=None,
outputs=None,
arguments=None,
wait=True,
logs=True,
job_name=None,
experiment_config=None,
configuration=None,
spark_event_logs_s3_uri=None,
kms_key=None,
submit_app: str,
submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
inputs: Optional[List[ProcessingInput]] = None,
outputs: Optional[List[ProcessingOutput]] = None,
arguments: Optional[List[Union[str, PipelineVariable]]] = None,
wait: bool = True,
logs: bool = True,
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
configuration: Optional[Union[List[Dict], Dict]] = None,
spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
kms_key: Optional[str] = None,
):
"""Runs a processing job.

@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase):

def __init__(
self,
role,
instance_type,
instance_count,
framework_version=None,
py_version=None,
container_version=None,
image_uri=None,
volume_size_in_gb=30,
volume_kms_key=None,
output_kms_key=None,
max_runtime_in_seconds=None,
base_job_name=None,
sagemaker_session=None,
env=None,
tags=None,
network_config=None,
role: str,
instance_type: Union[str, PipelineVariable],
instance_count: Union[int, PipelineVariable],
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
container_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
volume_size_in_gb: Union[int, PipelineVariable] = 30,
volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
network_config: Optional[NetworkConfig] = None,
):
"""Initialize a ``SparkJarProcessor`` instance.

@@ -1052,20 +1081,20 @@ def get_run_args(

def run(
self,
submit_app,
submit_class=None,
submit_jars=None,
submit_files=None,
inputs=None,
outputs=None,
arguments=None,
wait=True,
logs=True,
job_name=None,
experiment_config=None,
configuration=None,
spark_event_logs_s3_uri=None,
kms_key=None,
submit_app: str,
submit_class: Union[str, PipelineVariable],
submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
inputs: Optional[List[ProcessingInput]] = None,
outputs: Optional[List[ProcessingOutput]] = None,
arguments: Optional[List[Union[str, PipelineVariable]]] = None,
wait: bool = True,
logs: bool = True,
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
configuration: Optional[Union[List[Dict], Dict]] = None,
spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
kms_key: Optional[str] = None,
):
"""Runs a processing job.
