diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py
index 912bc90d80..afb3e04599 100644
--- a/src/sagemaker/spark/processing.py
+++ b/src/sagemaker/spark/processing.py
@@ -104,6 +104,8 @@ def __init__(
         volume_size_in_gb=30,
         volume_kms_key=None,
         output_kms_key=None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds=None,
         base_job_name=None,
         sagemaker_session=None,
@@ -134,6 +136,12 @@ def __init__(
             volume_kms_key (str): A KMS key for the processing volume.
             output_kms_key (str): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -150,6 +158,8 @@ def __init__(
                 object that configures network isolation, encryption of inter-container
                 traffic, security group IDs, and subnets.
         """
+        self.configuration_location = configuration_location
+        self.dependency_location = dependency_location
         self.history_server = None
         self._spark_event_logs_s3_uri = None
@@ -413,19 +423,27 @@ def _stage_configuration(self, configuration):
         """
         from sagemaker.workflow.utilities import _pipeline_config

+        if self.configuration_location:
+            if self.configuration_location.endswith("/"):
+                s3_prefix_uri = self.configuration_location[:-1]
+            else:
+                s3_prefix_uri = self.configuration_location
+        else:
+            s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
         serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8"))

         if _pipeline_config and _pipeline_config.config_hash:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
-                f"{_pipeline_config.step_name}/input/"
-                f"{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
+                f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/{_pipeline_config.step_name}/"
+                f"input/{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
                 f"{self._conf_file_name}"
             )
         else:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{self._current_job_name}/"
-                f"input/{self._conf_container_input_name}/{self._conf_file_name}"
+                f"{s3_prefix_uri}/{self._current_job_name}/"
+                f"input/{self._conf_container_input_name}/"
+                f"{self._conf_file_name}"
             )

         S3Uploader.upload_string_as_file_body(
@@ -447,7 +465,7 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         This prepared list of paths is provided as `spark-submit` options.
         The submit_deps list may include a combination of S3 URIs and local paths.
         Any S3 URIs are appended to the `spark-submit` option value without modification.
-        Any local file paths are copied to a temp directory, uploaded to a default S3 URI,
+        Any local file paths are copied to a temp directory, uploaded to ``dependency location``,
         and included as a ProcessingInput channel to provide as local files to the SageMaker
         Spark container.

@@ -500,16 +518,22 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
             if os.listdir(tmpdir):
                 from sagemaker.workflow.utilities import _pipeline_config

+                if self.dependency_location:
+                    if self.dependency_location.endswith("/"):
+                        s3_prefix_uri = self.dependency_location[:-1]
+                    else:
+                        s3_prefix_uri = self.dependency_location
+                else:
+                    s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
                 if _pipeline_config and _pipeline_config.code_hash:
                     input_channel_s3_uri = (
-                        f"s3://{self.sagemaker_session.default_bucket()}"
-                        f"/{_pipeline_config.pipeline_name}/code/{_pipeline_config.code_hash}"
-                        f"/{input_channel_name}"
+                        f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/"
+                        f"code/{_pipeline_config.code_hash}/{input_channel_name}"
                     )
                 else:
                     input_channel_s3_uri = (
-                        f"s3://{self.sagemaker_session.default_bucket()}"
-                        f"/{self._current_job_name}/input/{input_channel_name}"
+                        f"{s3_prefix_uri}/{self._current_job_name}/input/{input_channel_name}"
                     )
                 logger.info(
                     "Uploading dependencies from tmpdir %s to S3 %s", tmpdir, input_channel_s3_uri
@@ -719,6 +743,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -749,6 +775,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -779,6 +811,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,
@@ -986,6 +1020,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -1016,6 +1052,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -1046,6 +1088,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,
diff --git a/tests/unit/sagemaker/spark/test_processing.py b/tests/unit/sagemaker/spark/test_processing.py
index ba08f82fad..a079f477b4 100644
--- a/tests/unit/sagemaker/spark/test_processing.py
+++ b/tests/unit/sagemaker/spark/test_processing.py
@@ -273,13 +273,61 @@ def test_spark_processor_base_extend_processing_args(

 serialized_configuration = BytesIO("test".encode("utf-8"))


+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration
-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})

     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
@@ -292,23 +340,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, s
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -322,11 +468,11 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi

     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])

         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )

@@ -334,9 +480,8 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]


 @pytest.mark.parametrize(
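
Example usage (illustrative sketch, not part of the patch): with the two new constructor
arguments, the staged EMR application configuration and any locally sourced Spark
dependencies can be pinned to caller-owned S3 prefixes instead of the session's default
bucket. The constructor arguments mirror the unit tests above; the bucket names, the
script paths, and the run() call details below are assumptions for illustration only.

    from sagemaker.spark.processing import PySparkProcessor

    spark_processor = PySparkProcessor(
        base_job_name="sm-spark",
        role="AmazonSageMaker-ExecutionRole",  # assumed IAM role name, as in the tests above
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        # New in this patch: override the default 's3://{sagemaker-default-bucket}' prefixes.
        # A single trailing slash, if present, is stripped before the URIs are composed.
        configuration_location="s3://configbucket/someprefix",
        dependency_location="s3://codebucket/someprefix",
    )

    spark_processor.run(
        submit_app="./code/preprocess.py",  # hypothetical local PySpark script
        # Local dependencies are copied to a temp dir and staged under dependency_location.
        submit_py_files=["./code/helpers.py"],
        # EMR-style application configuration is serialized and staged under
        # configuration_location before the job starts.
        configuration=[
            {"Classification": "spark-defaults", "Properties": {"spark.executor.memory": "2g"}}
        ],
    )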