Skip to content

Commit 85d5be2

Browse files
author
Brock Wade
committed
bug-fix: support idempotency for framework and spark processors
1 parent 885423c commit 85d5be2

File tree

9 files changed

+638
-85
lines changed

9 files changed

+638
-85
lines changed

src/sagemaker/processing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,14 +1827,19 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
18271827
# a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426
18281828
if inputs is None:
18291829
inputs = []
1830-
inputs.append(
1830+
1831+
# make a shallow copy of user inputs
1832+
patched_inputs = []
1833+
for user_input in inputs:
1834+
patched_inputs.append(user_input)
1835+
patched_inputs.append(
18311836
ProcessingInput(
18321837
input_name="code",
18331838
source=s3_payload,
18341839
destination="/opt/ml/processing/input/code/",
18351840
)
18361841
)
1837-
return inputs
1842+
return patched_inputs
18381843

18391844
def _set_entrypoint(self, command, user_script_name):
18401845
"""Framework processor override for setting processing job entrypoint.

src/sagemaker/spark/processing.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,18 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
940940
outputs: Processing outputs.
941941
kwargs: Additional keyword arguments passed to `super()`.
942942
"""
943+
944+
if inputs is None:
945+
inputs = []
946+
947+
# make a shallow copy of user inputs
948+
extended_inputs = []
949+
for user_input in inputs:
950+
extended_inputs.append(user_input)
951+
943952
self.command = [_SparkProcessorBase._default_command]
944953
extended_inputs = self._handle_script_dependencies(
945-
inputs, kwargs.get("submit_py_files"), FileType.PYTHON
954+
extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON
946955
)
947956
extended_inputs = self._handle_script_dependencies(
948957
extended_inputs, kwargs.get("submit_jars"), FileType.JAR
@@ -1199,8 +1208,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
11991208
else:
12001209
raise ValueError("submit_class is required")
12011210

1211+
if inputs is None:
1212+
inputs = []
1213+
1214+
# make a shallow copy of user inputs
1215+
extended_inputs = []
1216+
for user_input in inputs:
1217+
extended_inputs.append(user_input)
1218+
12021219
extended_inputs = self._handle_script_dependencies(
1203-
inputs, kwargs.get("submit_jars"), FileType.JAR
1220+
extended_inputs, kwargs.get("submit_jars"), FileType.JAR
12041221
)
12051222
extended_inputs = self._handle_script_dependencies(
12061223
extended_inputs, kwargs.get("submit_files"), FileType.FILE

src/sagemaker/workflow/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
168168
str: A hash string representing the unique code artifact(s) for the step
169169
"""
170170

171-
# FrameworkProcessor
171+
# If FrameworkProcessor contains source_dir
172172
if source_dir:
173173
source_dir_url = urlparse(source_dir)
174174
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
@@ -400,5 +400,5 @@ def execute_job_functions(step_args: _StepArguments):
400400
"""
401401

402402
chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs)
403-
if chained_args:
403+
if isinstance(chained_args, _StepArguments):
404404
execute_job_functions(chained_args)
1.67 KB
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)