Skip to content

Commit 641021d

Browse files
committed
fix: allow kms_key to be passed for processing step
1 parent 948debc commit 641021d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/sagemaker/workflow/steps.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def __init__(
475475
cache_config: CacheConfig = None,
476476
depends_on: Union[List[str], List[Step]] = None,
477477
retry_policies: List[RetryPolicy] = None,
478+
kms_key=None,
478479
):
479480
"""Construct a ProcessingStep, given a `Processor` instance.
480481
@@ -500,6 +501,8 @@ def __init__(
500501
depends_on (List[str] or List[Step]): A list of step names or step instance
501502
this `sagemaker.workflow.steps.ProcessingStep` depends on
502503
retry_policies (List[RetryPolicy]): A list of retry policy
504+
kms_key (str): The ARN of the KMS key that is used to encrypt the
505+
user code file. Defaults to `None`.
503506
"""
504507
super(ProcessingStep, self).__init__(
505508
name, StepTypeEnum.PROCESSING, display_name, description, depends_on, retry_policies
@@ -511,6 +514,7 @@ def __init__(
511514
self.code = code
512515
self.property_files = property_files
513516
self.job_name = None
517+
self.kms_key = kms_key
514518

515519
# Examine why run method in sagemaker.processing.Processor mutates the processor instance
516520
# by setting the instance's arguments attribute. Refactor Processor.run, if possible.
@@ -545,8 +549,8 @@ def arguments(self) -> RequestType:
545549
inputs=self.inputs,
546550
outputs=self.outputs,
547551
code=self.code,
552+
kms_key=self.kms_key,
548553
)
549-
550554
process_args = ProcessingJob._get_process_args(
551555
self.processor, normalized_inputs, normalized_outputs, experiment_config=dict()
552556
)

tests/unit/sagemaker/workflow/test_steps.py

+4
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ def test_processing_step_normalizes_args_with_local_code(mock_normalize_args, sc
598598
inputs=step.inputs,
599599
outputs=step.outputs,
600600
code=step.code,
601+
kms_key=None,
601602
)
602603

603604

@@ -624,6 +625,7 @@ def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, scrip
624625
outputs=outputs,
625626
job_arguments=["arg1", "arg2"],
626627
cache_config=cache_config,
628+
kms_key="arn:aws:kms:us-west-2:012345678901:key/s3-kms-key",
627629
)
628630
mock_normalize_args.return_value = [step.inputs, step.outputs]
629631
step.to_request()
@@ -633,6 +635,7 @@ def test_processing_step_normalizes_args_with_s3_code(mock_normalize_args, scrip
633635
inputs=step.inputs,
634636
outputs=step.outputs,
635637
code=step.code,
638+
kms_key=step.kms_key,
636639
)
637640

638641

@@ -667,6 +670,7 @@ def test_processing_step_normalizes_args_with_no_code(mock_normalize_args, scrip
667670
inputs=step.inputs,
668671
outputs=step.outputs,
669672
code=None,
673+
kms_key=None,
670674
)
671675

672676

0 commit comments

Comments
 (0)