Commit 2b5b4da

add input parameterization tests for workflow job steps (#3150)

* fix pipeline doc code example where process.run only accepts argument
* remove unused imports

1 parent e6c210a commit 2b5b4da
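
The new tests exercise pipeline variables (ParameterString, Join) as processing input sources and output destinations. A minimal sketch of the pattern under test, built only from the SDK classes this diff imports (bucket, parameter, and path names are illustrative):

    from sagemaker.processing import ProcessingInput, ProcessingOutput
    from sagemaker.workflow.functions import Join
    from sagemaker.workflow.parameters import ParameterString

    # A pipeline parameter resolves at execution time, so the S3 URI can vary per run.
    input_uri = ParameterString(
        name="my-processing-input", default_value="s3://my-bucket/my-processing"
    )

    inputs = [ProcessingInput(source=input_uri, destination="processing-input")]
    outputs = [
        ProcessingOutput(
            source="/opt/ml/output",
            # Join composes an S3 URI from parts at execution time.
            destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
        )
    ]

These lists mirror the PROCESSING_INPUT and PROCESSING_OUTPUT cases added in the diff below.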

File tree

3 files changed: +324 -258 lines changed

tests/unit/sagemaker/workflow/test_processing_step.py

+170 -115
@@ -24,7 +24,13 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession
 
-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import (
+    Processor,
+    ScriptProcessor,
+    FrameworkProcessor,
+    ProcessingOutput,
+    ProcessingInput,
+)
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.pytorch.processing import PyTorchProcessor
 from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ -34,11 +40,12 @@
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor
 
-from sagemaker.processing import ProcessingInput
 
 from sagemaker.workflow.steps import CacheConfig, ProcessingStep
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.parameters import ParameterString
+from sagemaker.workflow.functions import Join
 
 from sagemaker.network import NetworkConfig
 from sagemaker.pytorch.estimator import PyTorch
@@ -62,6 +69,144 @@
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 INSTANCE_TYPE = "ml.m4.xlarge"
 
+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
+
+PROCESSING_INPUT = [
+    ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(
+            name="my-processing-input", default_value="s3://my-bucket/my-processing"
+        ),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+
 
 @pytest.fixture
 def client():
@@ -253,117 +398,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
     }
 
 
-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
 def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
 ):
 
     processor, run_inputs = framework_processor
@@ -373,7 +412,8 @@ def test_processing_step_with_framework_processor(
     processor.volume_kms_key = "volume-kms-key"
     processor.network_config = network_config
 
-    run_inputs["inputs"] = processing_input
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]
 
     step_args = processor.run(**run_inputs)
 
@@ -387,10 +427,25 @@
         sagemaker_session=pipeline_session,
     )
 
-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert (
+        step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+        == processing_output.destination
+    )
+
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
     }
 
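
A note on the rewritten test: stacking @pytest.mark.parametrize decorators makes pytest collect the cross product of the parameter sets, so the test now runs all 10 FRAMEWORK_PROCESSOR cases against all 4 PROCESSING_INPUT and 4 PROCESSING_OUTPUT cases (160 variants). The S3Uri fields are asserted on their own and then deleted before the full-dictionary comparison because pipeline variables such as ParameterString and Join serialize to expression dictionaries in the pipeline definition JSON while step_args still holds the raw objects, so direct equality between the two would fail for the parameterized cases. A minimal sketch of the cross-product behavior, with illustrative stand-in values:

    import pytest

    INPUTS = ["in-a", "in-b"]     # stand-ins for the PROCESSING_INPUT cases
    OUTPUTS = ["out-x", "out-y"]  # stand-ins for the PROCESSING_OUTPUT cases

    @pytest.mark.parametrize("inp", INPUTS)
    @pytest.mark.parametrize("out", OUTPUTS)
    def test_cross_product(inp, out):
        # pytest collects len(INPUTS) * len(OUTPUTS) == 4 cases for this test.
        assert inp and out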