
Commit e8f0753

brockwade633 (Brock Wade) authored and mufaddal-rohawala committed
fix: support idempotency for framework and spark processors (aws#3460)
Co-authored-by: Brock Wade <[email protected]>
Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent f39b427 commit e8f0753

File tree

10 files changed (+661 -96 lines)


src/sagemaker/processing.py (+6 -2)
@@ -23,6 +23,7 @@
 import logging
 from textwrap import dedent
 from typing import Dict, List, Optional, Union
+from copy import copy
 
 import attr
 
@@ -1830,14 +1831,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
         # a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426
         if inputs is None:
             inputs = []
-        inputs.append(
+
+        # make a shallow copy of user inputs
+        patched_inputs = copy(inputs)
+        patched_inputs.append(
             ProcessingInput(
                 input_name="code",
                 source=s3_payload,
                 destination="/opt/ml/processing/input/code/",
             )
         )
-        return inputs
+        return patched_inputs
 
     def _set_entrypoint(self, command, user_script_name):
         """Framework processor override for setting processing job entrypoint.

src/sagemaker/spark/processing.py (+30 -7)
@@ -30,6 +30,7 @@
 from enum import Enum
 from io import BytesIO
 from urllib.parse import urlparse
+from copy import copy
 
 from typing import Union, List, Dict, Optional
 
@@ -279,6 +280,10 @@ def run(
     def _extend_processing_args(self, inputs, outputs, **kwargs):
         """Extends processing job args such as inputs."""
 
+        # make a shallow copy of user outputs
+        outputs = outputs or []
+        extended_outputs = copy(outputs)
+
         if kwargs.get("spark_event_logs_s3_uri"):
             spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
             self._validate_s3_uri(spark_event_logs_s3_uri)
@@ -297,16 +302,21 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
                 s3_upload_mode="Continuous",
             )
 
-            outputs = outputs or []
-            outputs.append(output)
+            extended_outputs.append(output)
+
+        # make a shallow copy of user inputs
+        inputs = inputs or []
+        extended_inputs = copy(inputs)
 
         if kwargs.get("configuration"):
             configuration = kwargs.get("configuration")
             self._validate_configuration(configuration)
-            inputs = inputs or []
-            inputs.append(self._stage_configuration(configuration))
+            extended_inputs.append(self._stage_configuration(configuration))
 
-        return inputs, outputs
+        return (
+            extended_inputs if extended_inputs else None,
+            extended_outputs if extended_outputs else None,
+        )
 
     def start_history_server(self, spark_event_logs_s3_uri=None):
         """Starts a Spark history server.
@@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
             outputs: Processing outputs.
             kwargs: Additional keyword arguments passed to `super()`.
         """
+
+        if inputs is None:
+            inputs = []
+
+        # make a shallow copy of user inputs
+        extended_inputs = copy(inputs)
+
         self.command = [_SparkProcessorBase._default_command]
         extended_inputs = self._handle_script_dependencies(
-            inputs, kwargs.get("submit_py_files"), FileType.PYTHON
+            extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON
         )
         extended_inputs = self._handle_script_dependencies(
             extended_inputs, kwargs.get("submit_jars"), FileType.JAR
@@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
         else:
             raise ValueError("submit_class is required")
 
+        if inputs is None:
+            inputs = []
+
+        # make a shallow copy of user inputs
+        extended_inputs = copy(inputs)
+
         extended_inputs = self._handle_script_dependencies(
-            inputs, kwargs.get("submit_jars"), FileType.JAR
+            extended_inputs, kwargs.get("submit_jars"), FileType.JAR
        )
         extended_inputs = self._handle_script_dependencies(
             extended_inputs, kwargs.get("submit_files"), FileType.FILE
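The Spark processors get the same treatment, with one extra wrinkle: their `_extend_processing_args` contract historically allowed `inputs` and `outputs` to be `None`, so after extending copies the method maps empty lists back to `None` to keep the return shape callers expect. A rough sketch of that contract (an invented function, not the SDK signature):

from copy import copy

def extend_args(inputs, outputs, staged_config=None, event_log_output=None):
    # Work on shallow copies so the user's lists are never mutated.
    extended_outputs = copy(outputs or [])
    if event_log_output is not None:
        extended_outputs.append(event_log_output)

    extended_inputs = copy(inputs or [])
    if staged_config is not None:
        extended_inputs.append(staged_config)

    # Preserve the old contract: "nothing" is still represented as None,
    # not as an empty list.
    return (
        extended_inputs if extended_inputs else None,
        extended_outputs if extended_outputs else None,
    )

print(extend_args(None, None))                           # (None, None)
print(extend_args(["in1"], None, staged_config="conf"))  # (['in1', 'conf'], None)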

src/sagemaker/workflow/utilities.py (+4 -3)
@@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str:
     if isinstance(step, ProcessingStep) and step.step_args:
         kwargs = step.step_args.func_kwargs
         source_dir = kwargs.get("source_dir")
+        submit_class = kwargs.get("submit_class")
         dependencies = get_processing_dependencies(
             [
                 kwargs.get("dependencies"),
                 kwargs.get("submit_py_files"),
-                kwargs.get("submit_class"),
+                [submit_class] if submit_class else None,
                 kwargs.get("submit_jars"),
                 kwargs.get("submit_files"),
             ]
@@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
         str: A hash string representing the unique code artifact(s) for the step
     """
 
-    # FrameworkProcessor
+    # If FrameworkProcessor contains source_dir
     if source_dir:
         source_dir_url = urlparse(source_dir)
         if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
@@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments):
     """
 
     chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs)
-    if chained_args:
+    if isinstance(chained_args, _StepArguments):
         execute_job_functions(chained_args)
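Two smaller fixes land in utilities.py. First, `submit_class` is a single string, while the other values fed to `get_processing_dependencies` are lists, so it is now wrapped as `[submit_class]` before being combined; a bare string would otherwise be consumed as a sequence of one-character entries. Second, `execute_job_functions` now recurses only when the chained return value really is a `_StepArguments`, rather than on any truthy result. A sketch of the string-versus-list hazard, using a stand-in flattening helper (not the SDK's implementation):

def flatten(groups):
    # Stand-in for combining dependency lists into one flat list.
    deps = []
    for group in groups:
        if group:
            deps.extend(group)  # extend() iterates whatever it is given
    return deps

submit_class = "com.example.SparkApp"

# A bare string is iterated character by character:
print(flatten([["lib.jar"], submit_class]))
# ['lib.jar', 'c', 'o', 'm', '.', 'e', 'x', ...]

# Wrapped in a list, it stays a single entry:
print(flatten([["lib.jar"], [submit_class] if submit_class else None]))
# ['lib.jar', 'com.example.SparkApp']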
2 binary files changed (1.67 KB; contents not shown)

tests/unit/sagemaker/workflow/test_pipeline.py (+3 -5)
@@ -17,7 +17,7 @@
 
 import pytest
 
-from mock import Mock
+from mock import Mock, patch
 
 from sagemaker import s3
 from sagemaker.workflow.condition_step import ConditionStep
@@ -78,6 +78,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
     )
 
 
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -87,8 +88,6 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)
 
     assert s3.S3Uploader.upload_string_as_file_body.called_with(
@@ -151,6 +150,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
     )
 
 
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_update(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -160,8 +160,6 @@ def test_large_pipeline_update(sagemaker_session_mock, role_ar
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)
 
     assert s3.S3Uploader.upload_string_as_file_body.called_with(
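The test changes are about isolation rather than behavior: assigning `s3.S3Uploader.upload_string_as_file_body = Mock()` replaced the real method for the rest of the test session, leaking into any test that ran afterwards, whereas `@patch` installs the mock only for the decorated test, hands it to the test as an argument, and restores the original on exit. The general pattern, in a self-contained sketch (run as a script, hence the `__main__` target):

from unittest.mock import Mock, patch

class Uploader:
    @staticmethod
    def upload(body):
        raise RuntimeError("real network call")

def leaky_test():
    # Old style: the Mock stays installed after the test returns.
    Uploader.upload = Mock()
    Uploader.upload("payload")

@patch("__main__.Uploader.upload")
def tidy_test(mock_upload):
    # New style: mocked only inside this function.
    Uploader.upload("payload")
    mock_upload.assert_called_once_with("payload")

tidy_test()
print(Uploader.upload)  # the original method, automatically restored

leaky_test()
print(Uploader.upload)  # still a Mock -- leaked global state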
Comments (0)