Skip to content

Commit 75d1f2c

Browse files
feature: Support role as PipelineParameter in Processor class (#3605)
1 parent 22c1ca7 commit 75d1f2c

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

src/sagemaker/processing.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class Processor(object):
5959

6060
def __init__(
6161
self,
62-
role: str,
62+
role: Union[str, PipelineVariable],
6363
image_uri: Union[str, PipelineVariable],
6464
instance_count: Union[int, PipelineVariable],
6565
instance_type: Union[str, PipelineVariable],
@@ -79,7 +79,7 @@ def __init__(
7979
The ``Processor`` handles Amazon SageMaker Processing tasks.
8080
8181
Args:
82-
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
82+
role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker Processing
8383
uses this role to access AWS resources, such as
8484
data stored in Amazon S3.
8585
image_uri (str or PipelineVariable): The URI of the Docker image to use for the
@@ -438,7 +438,7 @@ class ScriptProcessor(Processor):
438438

439439
def __init__(
440440
self,
441-
role: str,
441+
role: Union[str, PipelineVariable],
442442
image_uri: Union[str, PipelineVariable],
443443
command: List[str],
444444
instance_count: Union[int, PipelineVariable],
@@ -460,7 +460,7 @@ def __init__(
460460
run as part of the Processing Job.
461461
462462
Args:
463-
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
463+
role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker Processing
464464
uses this role to access AWS resources, such as
465465
data stored in Amazon S3.
466466
image_uri (str or PipelineVariable): The URI of the Docker image to use for the
@@ -931,7 +931,11 @@ def _get_process_args(cls, processor, inputs, outputs, experiment_config):
931931
else:
932932
process_request_args["network_config"] = None
933933

934-
process_request_args["role_arn"] = processor.sagemaker_session.expand_role(processor.role)
934+
process_request_args["role_arn"] = (
935+
processor.role
936+
if is_pipeline_variable(processor.role)
937+
else processor.sagemaker_session.expand_role(processor.role)
938+
)
935939

936940
process_request_args["tags"] = processor.tags
937941

tests/unit/test_processing.py

+57
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from sagemaker.sklearn.processing import SKLearnProcessor
3535
from sagemaker.pytorch.processing import PyTorchProcessor
3636
from sagemaker.tensorflow.processing import TensorFlowProcessor
37+
from sagemaker.workflow import ParameterString
3738
from sagemaker.xgboost.processing import XGBoostProcessor
3839
from sagemaker.mxnet.processing import MXNetProcessor
3940
from sagemaker.network import NetworkConfig
@@ -737,6 +738,62 @@ def test_processor_with_required_parameters(sagemaker_session):
737738
sagemaker_session.process.assert_called_with(**expected_args)
738739

739740

741+
def test_processor_with_role_as_pipeline_parameter(sagemaker_session):
    """Check that a ``ParameterString`` role reaches the processing request as-is.

    A ``Processor`` is built with a pipeline parameter as its role and run
    against the ``sagemaker_session`` fixture. The test then inspects the
    keyword arguments recorded on ``sagemaker_session.process``.

    Args:
        sagemaker_session: Session fixture. Assumed to be a mock whose
            ``process`` attribute records calls, as the neighbouring tests'
            ``sagemaker_session.process.assert_called_with(...)`` suggests.
    """
    role = ParameterString(name="Role", default_value=ROLE)

    processor = Processor(
        role=role,
        image_uri=CUSTOM_IMAGE_URI,
        instance_count=1,
        instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
    )

    processor.run()

    # Assert on what the session was actually asked to run. Comparing a locally
    # built expectation dict with ``role.default_value`` would pass regardless
    # of what ``Processor`` did with the role.
    sagemaker_session.process.assert_called_once()
    process_kwargs = sagemaker_session.process.call_args[1]
    # Identity check: the pipeline variable must be forwarded untouched, not
    # expanded or replaced by its default value.
    assert process_kwargs["role_arn"] is role
757+
758+
759+
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_script_processor_with_role_as_pipeline_parameter(
    isfile_mock, exists_mock, sagemaker_session
):
    """Check that ``ScriptProcessor`` forwards a ``ParameterString`` role as-is.

    ``os.path.exists`` and ``os.path.isfile`` are patched to return ``True``
    so the local code path passed to ``get_run_args`` / ``run`` does not have
    to exist on disk.

    Args:
        isfile_mock: Mock for ``os.path.isfile``. Stacked ``patch`` decorators
            inject their mocks bottom-up, so the innermost decorator supplies
            the first argument.
        exists_mock: Mock for ``os.path.exists`` (outermost decorator).
        sagemaker_session: Session fixture. Assumed to be a mock whose
            ``process`` attribute records calls, as the neighbouring tests'
            ``sagemaker_session.process.assert_called_with(...)`` suggests.
    """
    role = ParameterString(name="Role", default_value=ROLE)

    script_processor = ScriptProcessor(
        role=role,
        image_uri=CUSTOM_IMAGE_URI,
        instance_count=1,
        instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
        command=["python3"],
    )

    run_args = script_processor.get_run_args(
        code="/local/path/to/processing_code.py",
        inputs=_get_data_inputs_all_parameters(),
        outputs=_get_data_outputs_all_parameters(),
        arguments=["--drop-columns", "'SelfEmployed'"],
    )

    script_processor.run(
        code=run_args.code,
        inputs=run_args.inputs,
        outputs=run_args.outputs,
        arguments=run_args.arguments,
        wait=True,
        logs=False,
        job_name="my_job_name",
        experiment_config={"ExperimentName": "AnExperiment"},
    )

    # Assert on what the session was actually asked to run. Comparing a locally
    # built expectation dict with ``role.default_value`` would pass regardless
    # of what ``ScriptProcessor`` did with the role.
    sagemaker_session.process.assert_called_once()
    process_kwargs = sagemaker_session.process.call_args[1]
    # Identity check: the pipeline variable must be forwarded untouched, not
    # expanded or replaced by its default value.
    assert process_kwargs["role_arn"] is role
795+
796+
740797
def test_processor_with_missing_network_config_parameters(sagemaker_session):
741798
processor = Processor(
742799
role=ROLE,

0 commit comments

Comments
 (0)