
Commit 528c9de

Merge branch 'master' of https://github.com/aws/sagemaker-python-sdk into fix-processing-image-uri-param
2 parents 8880953 + 2b5b4da

3 files changed: +322 -258 lines changed

tests/unit/sagemaker/workflow/test_processing_step.py (+168 -115)
@@ -24,7 +24,13 @@
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession
 
-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import (
+    Processor,
+    ScriptProcessor,
+    FrameworkProcessor,
+    ProcessingOutput,
+    ProcessingInput,
+)
 from sagemaker.sklearn.processing import SKLearnProcessor
 from sagemaker.pytorch.processing import PyTorchProcessor
 from sagemaker.tensorflow.processing import TensorFlowProcessor
@@ -34,7 +40,6 @@
 from sagemaker.wrangler.processing import DataWranglerProcessor
 from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor
 
-from sagemaker.processing import ProcessingInput
 
 from sagemaker.workflow.steps import CacheConfig, ProcessingStep
 from sagemaker.workflow.pipeline import Pipeline
@@ -64,6 +69,144 @@
 DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
 INSTANCE_TYPE = "ml.m4.xlarge"
 
+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
+
+PROCESSING_INPUT = [
+    ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(
+            name="my-processing-input", default_value="s3://my-bucket/my-processing"
+        ),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+
 
 @pytest.fixture
 def client():
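The new PROCESSING_INPUT and PROCESSING_OUTPUT lists cover four flavors of S3 location: a literal URI, a pipeline parameter without a default, a parameter with a default, and a Join expression. A short sketch of why the distinction matters for the assertions later in this diff; the imports match those the test file uses, but the serialized forms in the comments are my reading of SageMaker pipeline definitions, not output captured from this test:

from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.functions import Join

# A literal URI lands in the serialized pipeline definition unchanged.
literal = "s3://my-bucket/my-input"

# A parameter serializes as a reference resolved at execution time,
# presumably of the form {"Get": "Parameters.my-processing-input"}.
param = ParameterString(name="my-processing-input", default_value="s3://my-bucket/my-processing")

# A Join serializes as a string-building expression, presumably of the form
# {"Std:Join": {"On": "/", "Values": ["s3://my-bucket", "my-input"]}}.
expr = Join(on="/", values=["s3://my-bucket", "my-input"])

# Only the literal compares equal to a plain string in the serialized
# definition, which is why the reworked assertions check the S3Uri fields
# separately instead of through one dict-equality check.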
@@ -265,117 +408,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
     }
 
 
-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
 def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
 ):
 
     processor, run_inputs = framework_processor
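Stacking parametrize decorators this way makes pytest take the Cartesian product of the three lists, so the ten processor configurations, four inputs, and four outputs above generate 10 × 4 × 4 = 160 cases from a single test function. A minimal, self-contained illustration of that behavior (the parameter values here are stand-ins, not taken from the diff):

import pytest

PROCESSORS = ["sklearn", "pytorch"]    # stand-in for FRAMEWORK_PROCESSOR
INPUTS = ["literal-uri", "parameter"]  # stand-in for PROCESSING_INPUT

@pytest.mark.parametrize("proc", PROCESSORS)
@pytest.mark.parametrize("inp", INPUTS)
def test_cross_product(proc, inp):
    # pytest runs this once per (proc, inp) combination: 2 x 2 = 4 cases.
    assert proc and inp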
@@ -385,7 +422,8 @@ def test_processing_step_with_framework_processor(
     processor.volume_kms_key = "volume-kms-key"
     processor.network_config = network_config
 
-    run_inputs["inputs"] = processing_input
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]
 
     step_args = processor.run(**run_inputs)
 
@@ -399,10 +437,25 @@ def test_processing_step_with_framework_processor(
         sagemaker_session=pipeline_session,
     )
 
-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert (
+        step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+        == processing_output.destination
+    )
+
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
         "Name": "MyProcessingStep",
         "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
     }
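The assertion rework above follows a common pattern for parametrized serialization tests: when one field's value varies by case (here S3Uri, which may be a literal URI, a ParameterString, or a Join expression), assert that field against its own expectation, delete it from both dicts, and then require exact equality on everything that remains. A generic sketch of the pattern, assuming nothing SageMaker-specific; the helper name and flat-dict shape are illustrative, while the test above applies the same idea inline to nested keys:

def assert_equal_except(actual: dict, expected: dict, volatile_key: str, expected_value):
    # Check the case-dependent field against its own expectation first.
    assert actual[volatile_key] == expected_value
    # Drop it from both sides; the remainder must then match exactly.
    trimmed_actual = {k: v for k, v in actual.items() if k != volatile_key}
    trimmed_expected = {k: v for k, v in expected.items() if k != volatile_key}
    assert trimmed_actual == trimmed_expected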