|
34 | 34 | from sagemaker.sklearn.processing import SKLearnProcessor
|
35 | 35 | from sagemaker.pytorch.processing import PyTorchProcessor
|
36 | 36 | from sagemaker.tensorflow.processing import TensorFlowProcessor
|
| 37 | +from sagemaker.workflow import ParameterString |
37 | 38 | from sagemaker.xgboost.processing import XGBoostProcessor
|
38 | 39 | from sagemaker.mxnet.processing import MXNetProcessor
|
39 | 40 | from sagemaker.network import NetworkConfig
|
@@ -737,6 +738,62 @@ def test_processor_with_required_parameters(sagemaker_session):
|
737 | 738 | sagemaker_session.process.assert_called_with(**expected_args)
|
738 | 739 |
|
739 | 740 |
|
| 741 | +def test_processor_with_role_as_pipeline_parameter(sagemaker_session): |
| 742 | + |
| 743 | + role = ParameterString(name="Role", default_value=ROLE) |
| 744 | + |
| 745 | + processor = Processor( |
| 746 | + role=role, |
| 747 | + image_uri=CUSTOM_IMAGE_URI, |
| 748 | + instance_count=1, |
| 749 | + instance_type="ml.m4.xlarge", |
| 750 | + sagemaker_session=sagemaker_session, |
| 751 | + ) |
| 752 | + |
| 753 | + processor.run() |
| 754 | + |
| 755 | + expected_args = _get_expected_args(processor._current_job_name) |
| 756 | + assert expected_args["role_arn"] == role.default_value |
| 757 | + |
| 758 | + |
| 759 | +@patch("os.path.exists", return_value=True) |
| 760 | +@patch("os.path.isfile", return_value=True) |
| 761 | +def test_script_processor_with_role_as_pipeline_parameter( |
| 762 | + exists_mock, isfile_mock, sagemaker_session |
| 763 | +): |
| 764 | + role = ParameterString(name="Role", default_value=ROLE) |
| 765 | + |
| 766 | + script_processor = ScriptProcessor( |
| 767 | + role=role, |
| 768 | + image_uri=CUSTOM_IMAGE_URI, |
| 769 | + instance_count=1, |
| 770 | + instance_type="ml.m4.xlarge", |
| 771 | + sagemaker_session=sagemaker_session, |
| 772 | + command=["python3"], |
| 773 | + ) |
| 774 | + |
| 775 | + run_args = script_processor.get_run_args( |
| 776 | + code="/local/path/to/processing_code.py", |
| 777 | + inputs=_get_data_inputs_all_parameters(), |
| 778 | + outputs=_get_data_outputs_all_parameters(), |
| 779 | + arguments=["--drop-columns", "'SelfEmployed'"], |
| 780 | + ) |
| 781 | + |
| 782 | + script_processor.run( |
| 783 | + code=run_args.code, |
| 784 | + inputs=run_args.inputs, |
| 785 | + outputs=run_args.outputs, |
| 786 | + arguments=run_args.arguments, |
| 787 | + wait=True, |
| 788 | + logs=False, |
| 789 | + job_name="my_job_name", |
| 790 | + experiment_config={"ExperimentName": "AnExperiment"}, |
| 791 | + ) |
| 792 | + |
| 793 | + expected_args = _get_expected_args(script_processor._current_job_name) |
| 794 | + assert expected_args["role_arn"] == role.default_value |
| 795 | + |
| 796 | + |
740 | 797 | def test_processor_with_missing_network_config_parameters(sagemaker_session):
|
741 | 798 | processor = Processor(
|
742 | 799 | role=ROLE,
|
|
0 commit comments