Commit bfd2ae0

add input parameterization tests for workflow job steps

1 parent 4acbdb0 · commit bfd2ae0
3 files changed: +326 -229 lines changed

tests/unit/sagemaker/workflow/test_processing_step.py (+172, -115)
@@ -24,7 +24,13 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession

-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import (
+    Processor,
+    ScriptProcessor,
+    FrameworkProcessor,
+    ProcessingOutput,
+    ProcessingInput,
+)
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.pytorch.processing import PyTorchProcessor
 from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ -34,11 +40,12 @@
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

-from sagemaker.processing import ProcessingInput

 from sagemaker.workflow.steps import CacheConfig, ProcessingStep
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.parameters import ParameterString
+from sagemaker.workflow.functions import Join

 from sagemaker.network import NetworkConfig
 from sagemaker.pytorch.estimator import PyTorch
@@ -62,6 +69,146 @@
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 INSTANCE_TYPE = "ml.m4.xlarge"

+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
+
+PROCESSING_INPUT = [
+    ProcessingInput(
+        source="s3://my-bucket/processing_manifest", destination="processing_manifest"
+    ),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(
+            name="my-processing-input", default_value="s3://my-bucket/my-processing"
+        ),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+

 @pytest.fixture
 def client():
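The new PROCESSING_INPUT and PROCESSING_OUTPUT fixtures deliberately mix plain S3 URI strings with ParameterString and Join entities. At pipeline definition time those entities are placeholders rather than concrete strings; a minimal standalone sketch of the distinction (URIs and names are illustrative):

    from sagemaker.workflow.parameters import ParameterString
    from sagemaker.workflow.functions import Join

    # A pipeline parameter: its value is resolved only when the pipeline executes.
    input_uri = ParameterString(
        name="my-processing-input", default_value="s3://my-bucket/my-processing"
    )
    # In the serialized pipeline definition this renders as a reference,
    # roughly {"Get": "Parameters.my-processing-input"}, not as a string.

    # A Join composes a URI from parts at execution time; it renders as
    # something like {"Std:Join": {"On": "/", "Values": ["s3://my-bucket", "my-output"]}}.
    output_uri = Join(on="/", values=["s3://my-bucket", "my-output"])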
@@ -253,117 +400,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
     }


-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
 def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
 ):

     processor, run_inputs = framework_processor
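Moving the processor matrix into module-level constants also changes the test matrix itself: pytest combines stacked parametrize decorators into a Cartesian product, so this test now runs once per (processor, input, output) combination (10 × 4 × 4 = 160 cases) instead of once per processor. A toy standalone illustration of that stacking behavior:

    import pytest

    # Stacked parametrize decorators multiply: this toy test runs
    # 2 * 3 = 6 times, once for each (a, b) combination.
    @pytest.mark.parametrize("a", [1, 2])
    @pytest.mark.parametrize("b", ["x", "y", "z"])
    def test_cartesian_product(a, b):
        assert a in (1, 2) and b in ("x", "y", "z")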
@@ -373,7 +414,8 @@ def test_processing_step_with_framework_processor(
     processor.volume_kms_key = "volume-kms-key"
     processor.network_config = network_config

-    run_inputs["inputs"] = processing_input
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]

     step_args = processor.run(**run_inputs)

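Worth noting for context: because these processors are constructed with a PipelineSession, `processor.run(**run_inputs)` does not launch a processing job here; it captures the request and returns step arguments to hand to a ProcessingStep. A minimal sketch of that pattern (image URI and role are illustrative):

    from sagemaker.processing import Processor
    from sagemaker.workflow.pipeline_context import PipelineSession
    from sagemaker.workflow.steps import ProcessingStep

    pipeline_session = PipelineSession()

    processor = Processor(
        image_uri="my-image-uri",                       # illustrative
        role="arn:aws:iam::111122223333:role/my-role",  # illustrative
        instance_type="ml.m4.xlarge",
        instance_count=1,
        sagemaker_session=pipeline_session,             # defers execution
    )

    # Under PipelineSession, run() returns deferred step arguments
    # instead of starting a job.
    step_args = processor.run(inputs=[], outputs=[])
    step = ProcessingStep(name="MyProcessingStep", step_args=step_args)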

@@ -387,10 +429,25 @@
         sagemaker_session=pipeline_session,
     )

-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert (
+        step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+        == processing_output.destination
+    )
+
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
     }
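The compare-then-strip sequence at the end exists because the two sides represent parameterized URIs differently: the raw step_args still hold the ParameterString or Join object that was passed in, while the serialized pipeline definition holds its rendered reference (e.g. {"Get": "Parameters.my-output"}). A wholesale dict equality would fail on exactly the S3Uri fields, so the test asserts those fields individually and deletes them from both sides before comparing the rest. A toy version of the same pattern (values are made up):

    # The same logical field, in its two representations.
    rendered = {"S3Uri": {"Get": "Parameters.my-output"}, "S3UploadMode": "EndOfJob"}
    raw = {"S3Uri": "<unrendered ParameterString>", "S3UploadMode": "EndOfJob"}

    # Assert the divergent field on its own terms...
    assert rendered["S3Uri"] == {"Get": "Parameters.my-output"}

    # ...then strip it so the remaining structure can be compared directly.
    del rendered["S3Uri"]
    del raw["S3Uri"]
    assert rendered == raw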
