diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index b87d89e811..65fa3c7dbc 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1361,7 +1361,7 @@ def __init__( self, estimator_cls: type, framework_version: str, - role: str, + role: Union[str, PipelineVariable], instance_count: Union[int, PipelineVariable], instance_type: Union[str, PipelineVariable], py_version: str = "py3", @@ -1389,8 +1389,9 @@ def __init__( estimator framework_version (str): The version of the framework. Value is ignored when ``image_uri`` is provided. - role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing uses - this role to access AWS resources, such as data stored in Amazon S3. + role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker + Processing uses this role to access AWS resources, such as data stored + in Amazon S3. instance_count (int or PipelineVariable): The number of instances to run a processing job with. instance_type (str or PipelineVariable): The type of EC2 instance to use for diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py index e18e1f7627..09e8bac2a2 100644 --- a/tests/integ/sagemaker/workflow/test_processing_steps.py +++ b/tests/integ/sagemaker/workflow/test_processing_steps.py @@ -385,8 +385,10 @@ def test_multi_step_framework_processing_pipeline_same_source_dir( SOURCE_DIR = "/pipeline/test_source_dir" + role_param = ParameterString(name="Role", default_value=role) + framework_processor_tf = FrameworkProcessor( - role=role, + role=role_param, instance_type="ml.m5.xlarge", instance_count=1, estimator_cls=TensorFlow, @@ -400,7 +402,7 @@ def test_multi_step_framework_processing_pipeline_same_source_dir( instance_type="ml.m5.xlarge", instance_count=1, base_job_name="my-job", - role=role, + role=role_param, estimator_cls=SKLearn, sagemaker_session=pipeline_session, ) @@ -431,7 +433,10 @@ def test_multi_step_framework_processing_pipeline_same_source_dir( ) pipeline = Pipeline( - name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + name=pipeline_name, + steps=[step_1, step_2], + sagemaker_session=pipeline_session, + parameters=[role_param], ) try: pipeline.create(role) diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 862e9fc8f6..4ce6e5302c 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -1064,3 +1064,54 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_ step_def = json.loads(pipeline.definition())["Steps"][0] step_def2 = json.loads(pipeline.definition())["Steps"][0] assert step_def == step_def2 + + +_PARAM_ROLE_NAME = "Role" + + +@pytest.mark.parametrize( + "processor_args", + [ + ( + ScriptProcessor( + role=ParameterString(name=_PARAM_ROLE_NAME, default_value=ROLE), + image_uri=IMAGE_URI, + instance_count=1, + instance_type="ml.m4.xlarge", + command=["python3"], + ), + {"code": DUMMY_S3_SCRIPT_PATH}, + ), + ( + Processor( + role=ParameterString(name=_PARAM_ROLE_NAME, default_value=ROLE), + image_uri=IMAGE_URI, + instance_count=1, + instance_type="ml.m4.xlarge", + ), + {}, + ), + ], +) +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +def test_processor_with_role_as_pipeline_parameter( + exists_mock, isfile_mock, processor_args, pipeline_session +): + processor, run_inputs = processor_args + processor.sagemaker_session = pipeline_session + processor.run(**run_inputs) + + step_args = processor.run(**run_inputs) + step = ProcessingStep( + name="MyProcessingStep", + step_args=step_args, + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + + step_def = json.loads(pipeline.definition())["Steps"][0] + assert step_def["Arguments"]["RoleArn"] == {"Get": f"Parameters.{_PARAM_ROLE_NAME}"} diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 5eb91747b3..34c530747d 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -34,7 +34,6 @@ from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.pytorch.processing import PyTorchProcessor from sagemaker.tensorflow.processing import TensorFlowProcessor -from sagemaker.workflow import ParameterString from sagemaker.xgboost.processing import XGBoostProcessor from sagemaker.mxnet.processing import MXNetProcessor from sagemaker.network import NetworkConfig @@ -738,62 +737,6 @@ def test_processor_with_required_parameters(sagemaker_session): sagemaker_session.process.assert_called_with(**expected_args) -def test_processor_with_role_as_pipeline_parameter(sagemaker_session): - - role = ParameterString(name="Role", default_value=ROLE) - - processor = Processor( - role=role, - image_uri=CUSTOM_IMAGE_URI, - instance_count=1, - instance_type="ml.m4.xlarge", - sagemaker_session=sagemaker_session, - ) - - processor.run() - - expected_args = _get_expected_args(processor._current_job_name) - assert expected_args["role_arn"] == role.default_value - - -@patch("os.path.exists", return_value=True) -@patch("os.path.isfile", return_value=True) -def test_script_processor_with_role_as_pipeline_parameter( - exists_mock, isfile_mock, sagemaker_session -): - role = ParameterString(name="Role", default_value=ROLE) - - script_processor = ScriptProcessor( - role=role, - image_uri=CUSTOM_IMAGE_URI, - instance_count=1, - instance_type="ml.m4.xlarge", - sagemaker_session=sagemaker_session, - command=["python3"], - ) - - run_args = script_processor.get_run_args( - code="/local/path/to/processing_code.py", - inputs=_get_data_inputs_all_parameters(), - outputs=_get_data_outputs_all_parameters(), - arguments=["--drop-columns", "'SelfEmployed'"], - ) - - script_processor.run( - code=run_args.code, - inputs=run_args.inputs, - outputs=run_args.outputs, - arguments=run_args.arguments, - wait=True, - logs=False, - job_name="my_job_name", - experiment_config={"ExperimentName": "AnExperiment"}, - ) - - expected_args = _get_expected_args(script_processor._current_job_name) - assert expected_args["role_arn"] == role.default_value - - def test_processor_with_missing_network_config_parameters(sagemaker_session): processor = Processor( role=ROLE,