Skip to content

Commit 718b8da

Browse files
jayatalrshreyapandit
authored andcommitted
fix: allow kms_key to be passed for processing step (#2779)
1 parent c4d3b9e commit 718b8da

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/sagemaker/workflow/steps.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ def __init__(
486486
cache_config: CacheConfig = None,
487487
depends_on: Union[List[str], List[Step]] = None,
488488
retry_policies: List[RetryPolicy] = None,
489+
kms_key=None,
489490
):
490491
"""Construct a ProcessingStep, given a `Processor` instance.
491492
@@ -511,6 +512,8 @@ def __init__(
511512
depends_on (List[str] or List[Step]): A list of step names or step instance
512513
this `sagemaker.workflow.steps.ProcessingStep` depends on
513514
retry_policies (List[RetryPolicy]): A list of retry policy
515+
kms_key (str): The ARN of the KMS key that is used to encrypt the
516+
user code file. Defaults to `None`.
514517
"""
515518
super(ProcessingStep, self).__init__(
516519
name, StepTypeEnum.PROCESSING, display_name, description, depends_on, retry_policies
@@ -522,6 +525,7 @@ def __init__(
522525
self.code = code
523526
self.property_files = property_files
524527
self.job_name = None
528+
self.kms_key = kms_key
525529

526530
# Examine why run method in sagemaker.processing.Processor mutates the processor instance
527531
# by setting the instance's arguments attribute. Refactor Processor.run, if possible.
@@ -556,8 +560,8 @@ def arguments(self) -> RequestType:
556560
inputs=self.inputs,
557561
outputs=self.outputs,
558562
code=self.code,
563+
kms_key=self.kms_key,
559564
)
560-
561565
process_args = ProcessingJob._get_process_args(
562566
self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
563567
)

tests/unit/sagemaker/workflow/test_steps.py

+4
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc
617617
inputs=step.inputs,
618618
outputs=step.outputs,
619619
code=step.code,
620+
kms_key=None,
620621
)
621622

622623

@@ -643,6 +644,7 @@ def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, scrip
643644
outputs=outputs,
644645
job_arguments=["arg1", "arg2"],
645646
cache_config=cache_config,
647+
kms_key="arn:aws:kms:us-west-2:012345678901:key/s3-kms-key",
646648
)
647649
mock_normalize_args.return_value = [step.inputs, step.outputs]
648650
step.to_request()
@@ -652,6 +654,7 @@ def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, scrip
652654
inputs=step.inputs,
653655
outputs=step.outputs,
654656
code=step.code,
657+
kms_key=step.kms_key,
655658
)
656659

657660

@@ -686,6 +689,7 @@ def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, scrip
686689
inputs=step.inputs,
687690
outputs=step.outputs,
688691
code=None,
692+
kms_key=None,
689693
)
690694

691695

0 commit comments

Comments
 (0)