Skip to content

Commit 394a329

Browse files
qidewenwhenJoseJuan98
authored andcommitted
fix: Make test_processor_with_role_as_pipeline_parameter more concrete (aws#3618)
1 parent c8f89a5 commit 394a329

File tree

4 files changed

+63
-63
lines changed

4 files changed

+63
-63
lines changed

src/sagemaker/processing.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ def __init__(
13611361
self,
13621362
estimator_cls: type,
13631363
framework_version: str,
1364-
role: str,
1364+
role: Union[str, PipelineVariable],
13651365
instance_count: Union[int, PipelineVariable],
13661366
instance_type: Union[str, PipelineVariable],
13671367
py_version: str = "py3",
@@ -1389,8 +1389,9 @@ def __init__(
13891389
estimator
13901390
framework_version (str): The version of the framework. Value is ignored when
13911391
``image_uri`` is provided.
1392-
role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing uses
1393-
this role to access AWS resources, such as data stored in Amazon S3.
1392+
role (str or PipelineVariable): An AWS IAM role name or ARN. Amazon SageMaker
1393+
Processing uses this role to access AWS resources, such as data stored
1394+
in Amazon S3.
13941395
instance_count (int or PipelineVariable): The number of instances to run a
13951396
processing job with.
13961397
instance_type (str or PipelineVariable): The type of EC2 instance to use for

tests/integ/sagemaker/workflow/test_processing_steps.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,10 @@ def test_multi_step_framework_processing_pipeline_same_source_dir(
385385

386386
SOURCE_DIR = "/pipeline/test_source_dir"
387387

388+
role_param = ParameterString(name="Role", default_value=role)
389+
388390
framework_processor_tf = FrameworkProcessor(
389-
role=role,
391+
role=role_param,
390392
instance_type="ml.m5.xlarge",
391393
instance_count=1,
392394
estimator_cls=TensorFlow,
@@ -400,7 +402,7 @@ def test_multi_step_framework_processing_pipeline_same_source_dir(
400402
instance_type="ml.m5.xlarge",
401403
instance_count=1,
402404
base_job_name="my-job",
403-
role=role,
405+
role=role_param,
404406
estimator_cls=SKLearn,
405407
sagemaker_session=pipeline_session,
406408
)
@@ -431,7 +433,10 @@ def test_multi_step_framework_processing_pipeline_same_source_dir(
431433
)
432434

433435
pipeline = Pipeline(
434-
name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session
436+
name=pipeline_name,
437+
steps=[step_1, step_2],
438+
sagemaker_session=pipeline_session,
439+
parameters=[role_param],
435440
)
436441
try:
437442
pipeline.create(role)

tests/unit/sagemaker/workflow/test_processing_step.py

+51
Original file line numberDiff line numberDiff line change
@@ -1064,3 +1064,54 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_
10641064
step_def = json.loads(pipeline.definition())["Steps"][0]
10651065
step_def2 = json.loads(pipeline.definition())["Steps"][0]
10661066
assert step_def == step_def2
1067+
1068+
1069+
_PARAM_ROLE_NAME = "Role"
1070+
1071+
1072+
@pytest.mark.parametrize(
1073+
"processor_args",
1074+
[
1075+
(
1076+
ScriptProcessor(
1077+
role=ParameterString(name=_PARAM_ROLE_NAME, default_value=ROLE),
1078+
image_uri=IMAGE_URI,
1079+
instance_count=1,
1080+
instance_type="ml.m4.xlarge",
1081+
command=["python3"],
1082+
),
1083+
{"code": DUMMY_S3_SCRIPT_PATH},
1084+
),
1085+
(
1086+
Processor(
1087+
role=ParameterString(name=_PARAM_ROLE_NAME, default_value=ROLE),
1088+
image_uri=IMAGE_URI,
1089+
instance_count=1,
1090+
instance_type="ml.m4.xlarge",
1091+
),
1092+
{},
1093+
),
1094+
],
1095+
)
1096+
@patch("os.path.exists", return_value=True)
1097+
@patch("os.path.isfile", return_value=True)
1098+
def test_processor_with_role_as_pipeline_parameter(
1099+
exists_mock, isfile_mock, processor_args, pipeline_session
1100+
):
1101+
processor, run_inputs = processor_args
1102+
processor.sagemaker_session = pipeline_session
1103+
processor.run(**run_inputs)
1104+
1105+
step_args = processor.run(**run_inputs)
1106+
step = ProcessingStep(
1107+
name="MyProcessingStep",
1108+
step_args=step_args,
1109+
)
1110+
pipeline = Pipeline(
1111+
name="MyPipeline",
1112+
steps=[step],
1113+
sagemaker_session=pipeline_session,
1114+
)
1115+
1116+
step_def = json.loads(pipeline.definition())["Steps"][0]
1117+
assert step_def["Arguments"]["RoleArn"] == {"Get": f"Parameters.{_PARAM_ROLE_NAME}"}

tests/unit/test_processing.py

-57
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from sagemaker.sklearn.processing import SKLearnProcessor
3535
from sagemaker.pytorch.processing import PyTorchProcessor
3636
from sagemaker.tensorflow.processing import TensorFlowProcessor
37-
from sagemaker.workflow import ParameterString
3837
from sagemaker.xgboost.processing import XGBoostProcessor
3938
from sagemaker.mxnet.processing import MXNetProcessor
4039
from sagemaker.network import NetworkConfig
@@ -738,62 +737,6 @@ def test_processor_with_required_parameters(sagemaker_session):
738737
sagemaker_session.process.assert_called_with(**expected_args)
739738

740739

741-
def test_processor_with_role_as_pipeline_parameter(sagemaker_session):
742-
743-
role = ParameterString(name="Role", default_value=ROLE)
744-
745-
processor = Processor(
746-
role=role,
747-
image_uri=CUSTOM_IMAGE_URI,
748-
instance_count=1,
749-
instance_type="ml.m4.xlarge",
750-
sagemaker_session=sagemaker_session,
751-
)
752-
753-
processor.run()
754-
755-
expected_args = _get_expected_args(processor._current_job_name)
756-
assert expected_args["role_arn"] == role.default_value
757-
758-
759-
@patch("os.path.exists", return_value=True)
760-
@patch("os.path.isfile", return_value=True)
761-
def test_script_processor_with_role_as_pipeline_parameter(
762-
exists_mock, isfile_mock, sagemaker_session
763-
):
764-
role = ParameterString(name="Role", default_value=ROLE)
765-
766-
script_processor = ScriptProcessor(
767-
role=role,
768-
image_uri=CUSTOM_IMAGE_URI,
769-
instance_count=1,
770-
instance_type="ml.m4.xlarge",
771-
sagemaker_session=sagemaker_session,
772-
command=["python3"],
773-
)
774-
775-
run_args = script_processor.get_run_args(
776-
code="/local/path/to/processing_code.py",
777-
inputs=_get_data_inputs_all_parameters(),
778-
outputs=_get_data_outputs_all_parameters(),
779-
arguments=["--drop-columns", "'SelfEmployed'"],
780-
)
781-
782-
script_processor.run(
783-
code=run_args.code,
784-
inputs=run_args.inputs,
785-
outputs=run_args.outputs,
786-
arguments=run_args.arguments,
787-
wait=True,
788-
logs=False,
789-
job_name="my_job_name",
790-
experiment_config={"ExperimentName": "AnExperiment"},
791-
)
792-
793-
expected_args = _get_expected_args(script_processor._current_job_name)
794-
assert expected_args["role_arn"] == role.default_value
795-
796-
797740
def test_processor_with_missing_network_config_parameters(sagemaker_session):
798741
processor = Processor(
799742
role=ROLE,

0 commit comments

Comments
 (0)