@@ -31,13 +31,20 @@
 from io import BytesIO
 from urllib.parse import urlparse

+from typing import Union, List, Dict, Optional
+
 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults

+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.functions import Join
+
 logger = logging.getLogger(__name__)


@@ -249,6 +256,12 @@ def run(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)

+        if is_pipeline_variable(submit_app):
+            raise ValueError(
+                "submit_app argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         return super().run(
             submit_app,
             inputs,
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):

         use_input_channel = False
         spark_opt_s3_uris = []
+        spark_opt_s3_uris_has_pipeline_var = False

         with tempfile.TemporaryDirectory() as tmpdir:
             for dep_path in submit_deps:
+                if is_pipeline_variable(dep_path):
+                    spark_opt_s3_uris.append(dep_path)
+                    spark_opt_s3_uris_has_pipeline_var = True
+                    continue
                 dep_url = urlparse(dep_path)
                 # S3 URIs are included as-is in the spark-submit argument
                 if dep_url.scheme in ["s3", "s3a"]:
@@ -482,11 +500,13 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                     destination=f"{self._conf_container_base_path}{input_channel_name}",
                     input_name=input_channel_name,
                 )
-                spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
+                spark_opt = Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) \
+                    if spark_opt_s3_uris_has_pipeline_var else ",".join(spark_opt_s3_uris + [input_channel.destination])
             # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
             else:
                 input_channel = None
-                spark_opt = ",".join(spark_opt_s3_uris)
+                spark_opt = Join(on=",", values=spark_opt_s3_uris) if spark_opt_s3_uris_has_pipeline_var \
+                    else ",".join(spark_opt_s3_uris)

         return input_channel, spark_opt

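As context for the hunk above, here is a minimal sketch (not part of this commit; the parameter name and S3 URIs are invented) of why the comma-joining is deferred with `Join` whenever a pipeline variable is in the list: a `PipelineVariable` has no concrete string value at pipeline definition time, so `",".join()` cannot produce the correct spark-submit option.

```python
# Hypothetical illustration of Join vs. ",".join() for spark-submit URI lists.
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString

# A pipeline parameter standing in for one of the submit dependencies.
extra_jar = ParameterString(name="ExtraJarUri", default_value="s3://my-bucket/libs/extra.jar")

# ",".join([extra_jar, ...]) would either fail or bake in the parameter's default
# at definition time; Join builds an expression that SageMaker resolves at
# pipeline execution time instead.
spark_opt = Join(on=",", values=[extra_jar, "s3://my-bucket/conf/input-channel"])
```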
@@ -592,6 +612,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
         Args:
             spark_output_s3_path (str): The URI of the Spark output S3 Path.
         """
+        if is_pipeline_variable(spark_output_s3_path):
+            return
+
         if urlparse(spark_output_s3_path).scheme != "s3":
             raise ValueError(
                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +673,22 @@ class PySparkProcessor(_SparkProcessorBase):

     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize an ``PySparkProcessor`` instance.

@@ -795,20 +818,20 @@ def get_run_args(

     def run(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.

@@ -907,22 +930,22 @@ class SparkJarProcessor(_SparkProcessorBase):

     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``SparkJarProcessor`` instance.

@@ -1052,20 +1075,20 @@ def get_run_args(

     def run(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_class: Union[str, PipelineVariable],
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.

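Taken together, the hunks above let SageMaker Pipelines parameters flow into the Spark processors. Below is a rough usage sketch, not part of the commit: the bucket names, role ARN, parameter names, and script path are hypothetical, and it assumes the usual `PipelineSession`/`ProcessingStep` workflow APIs.

```python
# Hedged sketch: pipeline parameters feeding a PySparkProcessor.
# submit_app itself must stay a plain S3 URI or local path, per the new
# validation added in run() above.
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep

pipeline_session = PipelineSession()

# Parameters resolved at pipeline execution time; the typed signatures above
# now accept them (Union[str, PipelineVariable], lists of them, etc.).
extra_py_files = ParameterString(
    name="ExtraPyFiles", default_value="s3://my-bucket/deps/helpers.py"
)
event_logs_uri = ParameterString(
    name="SparkEventLogsUri", default_value="s3://my-bucket/spark-events/"
)

processor = PySparkProcessor(
    role="arn:aws:iam::123456789012:role/MySageMakerRole",  # hypothetical role
    instance_type="ml.m5.xlarge",
    instance_count=2,
    framework_version="3.1",
    base_job_name="sm-spark",
    sagemaker_session=pipeline_session,
)

# With a PipelineSession, run() returns step arguments instead of starting a job;
# _stage_submit_deps() keeps the pipeline variable as-is and joins it with Join.
step_args = processor.run(
    submit_app="s3://my-bucket/code/preprocess.py",  # must not be a pipeline variable
    submit_py_files=[extra_py_files],
    spark_event_logs_s3_uri=event_logs_uri,
)

step = ProcessingStep(name="SparkPreprocess", step_args=step_args)
pipeline = Pipeline(
    name="SparkPipeline",
    parameters=[extra_py_files, event_logs_uri],
    steps=[step],
    sagemaker_session=pipeline_session,
)
```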