Skip to content

Commit f6963c6

Browse files
author
Dewen Qi
committed
fix: Prevent passing pipeline variables as code arg in ProcessingStep in compile time
1 parent bc5082e commit f6963c6

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/sagemaker/processing.py

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.local import LocalSession
3535
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
3636
from sagemaker.session import Session
37+
from sagemaker.workflow import is_pipeline_variable
3738
from sagemaker.workflow.properties import Properties
3839
from sagemaker.workflow.parameters import Parameter
3940
from sagemaker.workflow.entities import Expression
@@ -233,6 +234,12 @@ def _normalize_args(
233234
kms_key (str): The ARN of the KMS key that is used to encrypt the
234235
user code file (default: None).
235236
"""
237+
if code and is_pipeline_variable(code):
238+
raise ValueError(
239+
f"code argument {code} has to be a valid S3 URI or local file path "
240+
+ "rather than a pipeline variable"
241+
)
242+
236243
self._current_job_name = self._generate_current_job_name(job_name=job_name)
237244

238245
inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)

tests/unit/sagemaker/workflow/test_steps.py

+38
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,44 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc
666666
)
667667

668668

669+
def test_processing_step_normalizes_args_with_param_str_local_code(
670+
sagemaker_session, script_processor
671+
):
672+
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
673+
code_param = ParameterString(name="Script", default_value="S3://my-bucket/file_name.py")
674+
inputs = [
675+
ProcessingInput(
676+
source=f"s3://{BUCKET}/processing_manifest",
677+
destination="processing_manifest",
678+
)
679+
]
680+
outputs = [
681+
ProcessingOutput(
682+
source=f"s3://{BUCKET}/processing_manifest",
683+
destination="processing_manifest",
684+
)
685+
]
686+
step = ProcessingStep(
687+
name="MyProcessingStep",
688+
processor=script_processor,
689+
code=code_param,
690+
inputs=inputs,
691+
outputs=outputs,
692+
job_arguments=["arg1", "arg2"],
693+
cache_config=cache_config,
694+
)
695+
pipeline = Pipeline(
696+
name="MyPipeline",
697+
parameters=[code_param],
698+
steps=[step],
699+
sagemaker_session=sagemaker_session,
700+
)
701+
with pytest.raises(ValueError) as error:
702+
pipeline.definition()
703+
704+
assert "has to be a valid S3 URI or local file path" in str(error.value)
705+
706+
669707
@patch("sagemaker.processing.ScriptProcessor._normalize_args")
670708
def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, script_processor):
671709
cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")

0 commit comments

Comments
 (0)