Commit 7f15e19

Author: Roja Reddy Sareddy (committed)

fix: Added handler for pipeline variable while creating process job

1 parent 10d4c4f · commit 7f15e19

File tree: 3 files changed, +93 −120 lines


src/sagemaker/processing.py (+7 −39)
@@ -17,7 +17,7 @@
 and interpretation on Amazon SageMaker.
 """
 from __future__ import absolute_import
-
+import json
 import logging
 import os
 import pathlib
@@ -60,10 +60,9 @@
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
+from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.functions import Join
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
-from sagemaker.workflow.parameters import Parameter
 
 logger = logging.getLogger(__name__)
 
@@ -316,14 +315,14 @@ def _normalize_args(
                 + "rather than a pipeline variable"
             )
         if arguments is not None:
-            normalized_arguments = []
+            processed_arguments = []
             for arg in arguments:
                 if isinstance(arg, PipelineVariable):
-                    normalized_value = self._normalize_pipeline_variable(arg)
-                    normalized_arguments.append(normalized_value)
+                    processed_value = json.dumps(arg.expr)
+                    processed_arguments.append(processed_value)
                 else:
-                    normalized_arguments.append(str(arg))
-            arguments = normalized_arguments
+                    processed_arguments.append(str(arg))
+            arguments = processed_arguments
 
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
@@ -509,37 +508,6 @@ def _normalize_outputs(self, outputs=None):
             normalized_outputs.append(output)
         return normalized_outputs
 
-    def _normalize_pipeline_variable(self, value):
-        """Helper function to normalize PipelineVariable objects"""
-        try:
-            if isinstance(value, Parameter):
-                return str(value.default_value) if value.default_value is not None else None
-
-            elif isinstance(value, ExecutionVariable):
-                return f"{value.name}"
-
-            elif isinstance(value, Join):
-                normalized_values = [
-                    normalize_pipeline_variable(v) if isinstance(v, PipelineVariable) else str(v)
-                    for v in value.values
-                ]
-                return value.on.join(normalized_values)
-
-            elif isinstance(value, PipelineVariable):
-                if hasattr(value, 'default_value'):
-                    return str(value.default_value)
-                elif hasattr(value, 'expr'):
-                    return str(value.expr)
-
-            return str(value)
-
-        except AttributeError as e:
-            raise ValueError(f"Missing required attribute while normalizing {type(value).__name__}: {e}")
-        except TypeError as e:
-            raise ValueError(f"Type error while normalizing {type(value).__name__}: {e}")
-        except Exception as e:
-            raise ValueError(f"Error normalizing {type(value).__name__}: {e}")
-
 
 class ScriptProcessor(Processor):
     """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""

tests/unit/sagemaker/workflow/test_processing_step.py (+14 −3)
@@ -824,7 +824,12 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
     processor, run_inputs = spark_processor
     processor.sagemaker_session = pipeline_session
     processor.role = ROLE
-
+    arguments_output = [
+        "--input",
+        "input-data-uri",
+        "--output",
+        '{"Get": "Parameters.MyArgOutput"}',
+    ]
     run_inputs["inputs"] = processing_input
 
     step_args = processor.run(**run_inputs)
@@ -835,7 +840,7 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
 
     step_args = get_step_args_helper(step_args, "Processing")
 
-    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+    assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
 
     entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
@@ -1019,6 +1024,12 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_session):
     processor, run_inputs = spark_processor
     processor.sagemaker_session = pipeline_session
     processor.role = ROLE
+    arguments_output = [
+        "--input",
+        "input-data-uri",
+        "--output",
+        '{"Get": "Parameters.MyArgOutput"}',
+    ]
 
     run_inputs["inputs"] = processing_input
 
@@ -1030,7 +1041,7 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_session):
 
     step_args = get_step_args_helper(step_args, "Processing")
 
-    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+    assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output
 
     entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
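As a sanity check on the expected values above, the following sketch shows where the string '{"Get": "Parameters.MyArgOutput"}' comes from. It assumes the test's run_inputs["arguments"] contains a ParameterString named "MyArgOutput", which is defined elsewhere in the test file and not shown in this diff.

    import json

    from sagemaker.workflow.parameters import ParameterString

    arg = ParameterString(name="MyArgOutput")
    # .expr is the pipeline-definition JSON expression for the variable.
    assert arg.expr == {"Get": "Parameters.MyArgOutput"}
    # json.dumps(arg.expr) is exactly the string the updated assertions pin down.
    assert json.dumps(arg.expr) == '{"Get": "Parameters.MyArgOutput"}'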
