Commit 549a719

feat: locations for EMR configuration and Spark dependencies

1 parent a35a093

2 files changed: +218 −28 lines
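
For orientation, a minimal usage sketch of the two parameters this commit introduces. The role ARN and bucket names below are placeholders, not values taken from the diff; only `configuration_location` and `dependency_location` are the new keyword arguments:

    from sagemaker.spark.processing import PySparkProcessor

    spark_processor = PySparkProcessor(
        base_job_name="sm-spark",
        role="arn:aws:iam::111122223333:role/AmazonSageMaker-ExecutionRole",  # placeholder
        framework_version="2.4",
        instance_count=1,
        instance_type="ml.c5.xlarge",
        # New in this commit: override where EMR configuration and Spark
        # dependencies are staged, instead of the SageMaker default bucket.
        configuration_location="s3://configbucket/someprefix",
        dependency_location="s3://codebucket/someprefix",
    )

If either location is omitted, staging falls back to s3://{sagemaker-default-bucket}, as the docstrings in the diff below document.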

src/sagemaker/spark/processing.py (+56 −11)
@@ -103,6 +103,8 @@ def __init__(
         volume_size_in_gb=30,
         volume_kms_key=None,
         output_kms_key=None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds=None,
         base_job_name=None,
         sagemaker_session=None,
@@ -133,6 +135,12 @@ def __init__(
             volume_kms_key (str): A KMS key for the processing
                 volume.
             output_kms_key (str): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -149,6 +157,8 @@ def __init__(
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
         """
+        self.configuration_location = configuration_location
+        self.dependency_location = dependency_location
         self.history_server = None
         self._spark_event_logs_s3_uri = None
 
@@ -403,19 +413,27 @@ def _stage_configuration(self, configuration):
         """
         from sagemaker.workflow.utilities import _pipeline_config
 
+        if self.configuration_location:
+            if self.configuration_location.endswith("/"):
+                s3_prefix_uri = self.configuration_location[:-1]
+            else:
+                s3_prefix_uri = self.configuration_location
+        else:
+            s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
         serialized_configuration = BytesIO(json.dumps(configuration).encode("utf-8"))
 
         if _pipeline_config and _pipeline_config.config_hash:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
-                f"{_pipeline_config.step_name}/input/"
-                f"{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
+                f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/{_pipeline_config.step_name}/"
+                f"input/{self._conf_container_input_name}/{_pipeline_config.config_hash}/"
                 f"{self._conf_file_name}"
             )
         else:
             s3_uri = (
-                f"s3://{self.sagemaker_session.default_bucket()}/{self._current_job_name}/"
-                f"input/{self._conf_container_input_name}/{self._conf_file_name}"
+                f"{s3_prefix_uri}/{self._current_job_name}/"
+                f"input/{self._conf_container_input_name}/"
+                f"{self._conf_file_name}"
            )
 
         S3Uploader.upload_string_as_file_body(
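
As a standalone illustration of the prefix handling above (the helper name and bucket values are hypothetical; the committed code inlines this logic rather than extracting a function):

    def resolve_prefix(location, default_bucket):
        # Mirrors the committed logic: trim exactly one trailing "/" from a
        # user-supplied prefix, else fall back to the session default bucket.
        if location:
            if location.endswith("/"):
                return location[:-1]
            return location
        return f"s3://{default_bucket}"

    assert resolve_prefix("s3://configbucket/someprefix/", "bucket") == "s3://configbucket/someprefix"
    assert resolve_prefix("s3://configbucket/someprefix", "bucket") == "s3://configbucket/someprefix"
    assert resolve_prefix(None, "bucket") == "s3://bucket"

Note that only a single trailing slash is trimmed: `[:-1]` after an `endswith("/")` check is not the same as `rstrip("/")`.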
@@ -437,7 +455,7 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
         This prepared list of paths is provided as `spark-submit` options.
         The submit_deps list may include a combination of S3 URIs and local paths.
         Any S3 URIs are appended to the `spark-submit` option value without modification.
-        Any local file paths are copied to a temp directory, uploaded to a default S3 URI,
+        Any local file paths are copied to a temp directory, uploaded to ``dependency location``,
         and included as a ProcessingInput channel to provide as local files to the SageMaker
         Spark container.
 
@@ -490,16 +508,23 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
             if os.listdir(tmpdir):
                 from sagemaker.workflow.utilities import _pipeline_config
 
+                if self.dependency_location:
+                    if self.dependency_location.endswith("/"):
+                        s3_prefix_uri = self.dependency_location[:-1]
+                    else:
+                        s3_prefix_uri = self.dependency_location
+                else:
+                    s3_prefix_uri = f"s3://{self.sagemaker_session.default_bucket()}"
+
                 if _pipeline_config and _pipeline_config.code_hash:
                     input_channel_s3_uri = (
-                        f"s3://{self.sagemaker_session.default_bucket()}"
-                        f"/{_pipeline_config.pipeline_name}/code/{_pipeline_config.code_hash}"
-                        f"/{input_channel_name}"
+                        f"{s3_prefix_uri}/{_pipeline_config.pipeline_name}/"
+                        f"code/{_pipeline_config.code_hash}/{input_channel_name}"
                     )
                 else:
                     input_channel_s3_uri = (
-                        f"s3://{self.sagemaker_session.default_bucket()}"
-                        f"/{self._current_job_name}/input/{input_channel_name}"
+                        f"{s3_prefix_uri}/{self._current_job_name}/"
+                        f"input/{input_channel_name}"
                     )
                 logger.info(
                     "Uploading dependencies from tmpdir %s to S3 %s", tmpdir, input_channel_s3_uri
@@ -709,6 +734,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -739,6 +766,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing
                 volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -769,6 +802,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,
@@ -969,6 +1004,8 @@ def __init__(
         volume_size_in_gb: Union[int, PipelineVariable] = 30,
         volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
         output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        configuration_location: Optional[str] = None,
+        dependency_location: Optional[str] = None,
         max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
         base_job_name: Optional[str] = None,
         sagemaker_session: Optional[Session] = None,
@@ -999,6 +1036,12 @@ def __init__(
             volume_kms_key (str or PipelineVariable): A KMS key for the processing
                 volume.
             output_kms_key (str or PipelineVariable): The KMS key id for all ProcessingOutputs.
+            configuration_location (str): The S3 prefix URI where the user-provided EMR
+                application configuration will be uploaded (default: None). If not specified,
+                the default ``configuration location`` is 's3://{sagemaker-default-bucket}'.
+            dependency_location (str): The S3 prefix URI where Spark dependencies will be
+                uploaded (default: None). If not specified, the default ``dependency location``
+                is 's3://{sagemaker-default-bucket}'.
             max_runtime_in_seconds (int or PipelineVariable): Timeout in seconds.
                 After this amount of time Amazon SageMaker terminates the job
                 regardless of its current status.
@@ -1029,6 +1072,8 @@ def __init__(
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             output_kms_key=output_kms_key,
+            configuration_location=configuration_location,
+            dependency_location=dependency_location,
             max_runtime_in_seconds=max_runtime_in_seconds,
             base_job_name=base_job_name,
             sagemaker_session=sagemaker_session,

tests/unit/sagemaker/spark/test_processing.py (+162 −17)
@@ -271,13 +271,61 @@ def test_spark_processor_base_extend_processing_args(
 serialized_configuration = BytesIO("test".encode("utf-8"))
 
 
+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration
 
-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})
 
     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
@@ -290,23 +338,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, s
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName"
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName"
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -320,21 +466,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi
 
     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
 
         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )
 
         if expected[0] is None:
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]
 
 
 @pytest.mark.parametrize(
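
Assuming the repository's standard pytest setup, the tests touched by this commit can be run selectively with ordinary pytest keyword selection:

    python -m pytest tests/unit/sagemaker/spark/test_processing.py -k "test_stage_configuration or test_stage_submit_deps"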
