diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index af52da6288..2605fb9e27 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -59,7 +59,7 @@ class Processor(object):
 
     def __init__(
         self,
-        role: str,
+        role: Union[str, PipelineVariable],
         image_uri: Union[str, PipelineVariable],
         instance_count: Union[int, PipelineVariable],
         instance_type: Union[str, PipelineVariable],
@@ -79,7 +79,7 @@ def __init__(
         The ``Processor`` handles Amazon SageMaker Processing tasks.
 
         Args:
-            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+            role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker Processing
                 uses this role to access AWS resources, such as
                 data stored in Amazon S3.
             image_uri (str or PipelineVariable): The URI of the Docker image to use for the
@@ -438,7 +438,7 @@ class ScriptProcessor(Processor):
 
     def __init__(
         self,
-        role: str,
+        role: Union[str, PipelineVariable],
         image_uri: Union[str, PipelineVariable],
         command: List[str],
         instance_count: Union[int, PipelineVariable],
@@ -460,7 +460,7 @@ def __init__(
         run as part of the Processing Job.
 
         Args:
-            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
+            role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker Processing
                 uses this role to access AWS resources, such as
                 data stored in Amazon S3.
             image_uri (str or PipelineVariable): The URI of the Docker image to use for the
@@ -931,7 +931,11 @@ def _get_process_args(cls, processor, inputs, outputs, experiment_config):
         else:
             process_request_args["network_config"] = None
 
-        process_request_args["role_arn"] = processor.sagemaker_session.expand_role(processor.role)
+        process_request_args["role_arn"] = (
+            processor.role
+            if is_pipeline_variable(processor.role)
+            else processor.sagemaker_session.expand_role(processor.role)
+        )
 
         process_request_args["tags"] = processor.tags
 
diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py
index 34c530747d..5eb91747b3 100644
--- a/tests/unit/test_processing.py
+++ b/tests/unit/test_processing.py
@@ -34,6 +34,7 @@
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.pytorch.processing import PyTorchProcessor
 from sagemaker.tensorflow.processing import TensorFlowProcessor
+from sagemaker.workflow import ParameterString
 from sagemaker.xgboost.processing import XGBoostProcessor
 from sagemaker.mxnet.processing import MXNetProcessor
 from sagemaker.network import NetworkConfig
@@ -737,6 +738,66 @@ def test_processor_with_required_parameters(sagemaker_session):
     sagemaker_session.process.assert_called_with(**expected_args)
 
 
+def test_processor_with_role_as_pipeline_parameter(sagemaker_session):
+    """Check that a ``ParameterString`` role is handed to ``process`` unchanged."""
+    role = ParameterString(name="Role", default_value=ROLE)
+
+    processor = Processor(
+        role=role,
+        image_uri=CUSTOM_IMAGE_URI,
+        instance_count=1,
+        instance_type="ml.m4.xlarge",
+        sagemaker_session=sagemaker_session,
+    )
+
+    processor.run()
+
+    # Inspect the request actually sent to the session; comparing a locally
+    # built expected dict with ``role.default_value`` would never exercise it.
+    _, process_kwargs = sagemaker_session.process.call_args
+    assert process_kwargs["role_arn"] == role
+
+
+@patch("os.path.exists", return_value=True)
+@patch("os.path.isfile", return_value=True)
+def test_script_processor_with_role_as_pipeline_parameter(
+    isfile_mock, exists_mock, sagemaker_session
+):
+    """Check that ``ScriptProcessor`` forwards a ``ParameterString`` role unchanged."""
+    role = ParameterString(name="Role", default_value=ROLE)
+
+    script_processor = ScriptProcessor(
+        role=role,
+        image_uri=CUSTOM_IMAGE_URI,
+        instance_count=1,
+        instance_type="ml.m4.xlarge",
+        sagemaker_session=sagemaker_session,
+        command=["python3"],
+    )
+
+    run_args = script_processor.get_run_args(
+        code="/local/path/to/processing_code.py",
+        inputs=_get_data_inputs_all_parameters(),
+        outputs=_get_data_outputs_all_parameters(),
+        arguments=["--drop-columns", "'SelfEmployed'"],
+    )
+
+    script_processor.run(
+        code=run_args.code,
+        inputs=run_args.inputs,
+        outputs=run_args.outputs,
+        arguments=run_args.arguments,
+        wait=True,
+        logs=False,
+        job_name="my_job_name",
+        experiment_config={"ExperimentName": "AnExperiment"},
+    )
+
+    # Assert on the call recorded by the session mock, not on a locally built dict.
+    _, process_kwargs = sagemaker_session.process.call_args
+    assert process_kwargs["role_arn"] == role
+
+
 def test_processor_with_missing_network_config_parameters(sagemaker_session):
     processor = Processor(
         role=ROLE,