Commit 4fb7c01
feature: locations for EMR configuration and Spark dependencies (#3486)
* feat: locations for EMR configuration and Spark dependencies
* Fix style issues
1 parent e22ccb4 commit 4fb7c01
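
The two new constructor arguments let callers redirect the SDK's staging uploads away from the SageMaker default bucket. A minimal usage sketch, assuming an existing execution role and buckets (the role ARN, bucket names, and script path below are hypothetical placeholders, not values from this commit):

from sagemaker.spark.processing import PySparkProcessor

spark_processor = PySparkProcessor(
    base_job_name="sm-spark",
    framework_version="3.1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role ARN
    instance_count=2,
    instance_type="ml.c5.xlarge",
    # New in this commit: override the staging locations.
    configuration_location="s3://my-config-bucket/prefix",  # EMR configuration uploads
    dependency_location="s3://my-deps-bucket/prefix",  # local Spark dependency uploads
)

spark_processor.run(
    submit_app="./code/preprocess.py",  # hypothetical local script
    configuration=[
        {
            "Classification": "spark-defaults",
            "Properties": {"spark.executor.memory": "2g"},
        }
    ],
)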

File tree

2 files changed (+217 −28 lines)

src/sagemaker/spark/processing.py

+55 −11
@@ -104,6 +104,8 @@ def __init__(
         volume_size_in_gb=30,
         volume_kms_key=None,
         output_kms_key=None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds=None,
         base_job_name=None,
         sagemaker_session=None,
@@ -134,6 +136,12 @@ def __init__(
             volume_kms_key (str): A KMS key for the processing
                 volume.
             output_kms_key (str): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -150,6 +158,8 @@ def __init__(
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
         """
+        self.configuration_location = configuration_location
+        self.dependency_location = dependency_location
         self.history_server = None
         self._spark_event_logs_s3_uri = None
@@ -413,19 +423,27 @@ def _stage_configuration(self, configuration):
         """
         from sagemaker.workflow.utilities import _pipeline_config

+        if self.configuration_location:
+            if self.configuration_location.endswith("/"):
+                s3_prefix_uri = self.configuration_location[:-1]
+            else:
+                s3_prefix_uri = self.configuration_location
+        else:
+            s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
         serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8"))

         if _pipeline_config and _pipeline_config.config_hash:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
-                f"{_pipeline_config.step_name}/input/"
-                f"{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
+                f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/{_pipeline_config.step_name}/"
+                f"input/{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
                 f"{self._conf_file_name}"
             )
         else:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{self._current_job_name}/"
-                f"input/{self._conf_container_input_name}/{self._conf_file_name}"
+                f"{s3_prefix_uri}/{self._current_job_name}/"
+                f"input/{self._conf_container_input_name}/"
+                f"{self._conf_file_name}"
             )

         S3Uploader.upload_string_as_file_body(
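
The block added above only normalizes the user-supplied prefix before the URI is composed. A standalone sketch of the same normalization and the URI it yields (the helper name, bucket, and job name are illustrative, not part of the SDK):

def normalize_prefix(configuration_location, default_bucket):
    # Same behavior as the diff: drop a single trailing "/" from a user-supplied
    # prefix; fall back to the session's default bucket when none is given.
    if configuration_location:
        if configuration_location.endswith("/"):
            return configuration_location[:-1]
        return configuration_location
    return f"s3://{default_bucket}"

# Both spellings of the prefix yield the same configuration URI:
for loc in ("s3://configbucket/someprefix", "s3://configbucket/someprefix/"):
    prefix = normalize_prefix(loc, "sagemaker-default")
    print(f"{prefix}/my-job/input/conf/configuration.json")
# -> s3://configbucket/someprefix/my-job/input/conf/configuration.json (twice)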
@@ -447,7 +465,7 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         This prepared list of paths is provided as `spark-submit` options.
         The submit_deps list may include a combination of S3 URIs and local paths.
         Any S3 URIs are appended to the `spark-submit` option value without modification.
-        Any local file paths are copied to a temp directory, uploaded to a default S3 URI,
+        Any local file paths are copied to a temp directory, uploaded to ``dependency location``,
         and included as a ProcessingInput channel to provide as local files to the SageMaker
         Spark container.
@@ -500,16 +518,22 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         if os.listdir(tmpdir):
             from sagemaker.workflow.utilities import _pipeline_config

+            if self.dependency_location:
+                if self.dependency_location.endswith("/"):
+                    s3_prefix_uri = self.dependency_location[:-1]
+                else:
+                    s3_prefix_uri = self.dependency_location
+            else:
+                s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
             if _pipeline_config and _pipeline_config.code_hash:
                 input_channel_s3_uri = (
-                    f"s3://{self.sagemaker_session.default_bucket()}"
-                    f"/{_pipeline_config.pipeline_name}/code/{_pipeline_config.code_hash}"
-                    f"/{input_channel_name}"
+                    f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/"
+                    f"code/{_pipeline_config.code_hash}/{input_channel_name}"
                 )
             else:
                 input_channel_s3_uri = (
-                    f"s3://{self.sagemaker_session.default_bucket()}"
-                    f"/{self._current_job_name}/input/{input_channel_name}"
+                    f"{s3_prefix_uri}/{self._current_job_name}/input/{input_channel_name}"
                 )
             logger.info(
                 "Uploading dependencies from tmpdir %s to S3 %s", tmpdir, input_channel_s3_uri
@@ -719,6 +743,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -749,6 +775,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing
                 volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -779,6 +811,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,
@@ -986,6 +1020,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -1016,6 +1052,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing
                 volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -1046,6 +1088,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,
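
Both PySparkProcessor and SparkJarProcessor simply forward the new arguments to the base class constructor, so the behavior is identical for Python and Java/Scala jobs. A sketch with SparkJarProcessor, assuming the existing run() signature (the jar path, main class, role ARN, and bucket are hypothetical):

from sagemaker.spark.processing import SparkJarProcessor

jar_processor = SparkJarProcessor(
    base_job_name="sm-spark-java",
    framework_version="3.1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role ARN
    instance_count=2,
    instance_type="ml.c5.xlarge",
    dependency_location="s3://codebucket/someprefix",  # local jars staged here
)

jar_processor.run(
    submit_app="./target/my-spark-job.jar",  # hypothetical application jar
    submit_class="com.example.MySparkJob",  # hypothetical main class
    submit_jars=["./target/lib/helper.jar"],  # uploaded under dependency_location
)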

tests/unit/sagemaker/spark/test_processing.py

+162 −17
@@ -273,13 +273,61 @@ def test_spark_processor_base_extend_processing_args(
 serialized_configuration = BytesIO("test".encode("utf-8"))


+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration

-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})

     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
@@ -292,23 +340,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -322,21 +468,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):

     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])

         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
            submit_deps, config["input_channel_name"]
        )

        if expected[0] is None:
            assert input_channel is None
            assert spark_opt == expected[1]
        else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]


 @pytest.mark.parametrize(