Skip to content

Commit 353530d

Browse files
committed
fix: allow kms_key to be passed for processing step
1 parent 99f023e commit 353530d

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/sagemaker/workflow/steps.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ def __init__(
461461
cache_config: CacheConfig = None,
462462
depends_on: Union[List[str], List[Step]] = None,
463463
retry_policies: List[RetryPolicy] = None,
464+
kms_key=None,
464465
):
465466
"""Construct a ProcessingStep, given a `Processor` instance.
466467
@@ -486,6 +487,8 @@ def __init__(
486487
depends_on (List[str] or List[Step]): A list of step names or step instance
487488
this `sagemaker.workflow.steps.ProcessingStep` depends on
488489
retry_policies (List[RetryPolicy]): A list of retry policy
490+
kms_key (str): The ARN of the KMS key that is used to encrypt the
491+
user code file. Defaults to `None`.
489492
"""
490493
super(ProcessingStep, self).__init__(
491494
name, StepTypeEnum.PROCESSING, display_name, description, depends_on, retry_policies
@@ -496,6 +499,7 @@ def __init__(
496499
self.job_arguments = job_arguments
497500
self.code = code
498501
self.property_files = property_files
502+
self.kms_key = kms_key
499503

500504
# Examine why run method in sagemaker.processing.Processor mutates the processor instance
501505
# by setting the instance's arguments attribute. Refactor Processor.run, if possible.
@@ -518,8 +522,8 @@ def arguments(self) -> RequestType:
518522
inputs=self.inputs,
519523
outputs=self.outputs,
520524
code=self.code,
525+
kms_key=self.kms_key,
521526
)
522-
523527
process_args = ProcessingJob._get_process_args(
524528
self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
525529
)

tests/unit/sagemaker/workflow/test_steps.py

+2
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
516516
outputs=outputs,
517517
job_arguments=["arg1", "arg2"],
518518
cache_config=cache_config,
519+
kms_key="arn:aws:kms:us-west-2:012345678901:key/s3-kms-key",
519520
)
520521
mock_normalize_args.return_value = [step.inputs, step.outputs]
521522
step.to_request()
@@ -524,6 +525,7 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
524525
inputs=step.inputs,
525526
outputs=step.outputs,
526527
code=step.code,
528+
kms_key=step.kms_key,
527529
)
528530

529531

0 commit comments

Comments
 (0)