from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import (
+    Processor,
+    ScriptProcessor,
+    FrameworkProcessor,
+    ProcessingOutput,
+    ProcessingInput,
+)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

-from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline

DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

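+# (processor, run-method kwargs) pairs consumed by
+# test_processing_step_with_framework_processor below.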
+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
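+# DataWranglerProcessor passes no run kwargs (its flow is set at construction);
+# the Spark processors take submit_app/arguments in place of "code".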
+
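+# Input sources cover a static S3 URI, a required pipeline parameter,
+# a parameter with a default, and a Join expression resolved at execution time.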
+PROCESSING_INPUT = [
+    ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(
+            name="my-processing-input", default_value="s3://my-bucket/my-processing"
+        ),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
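+# Output destinations mirror the same four shapes.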
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+

@pytest.fixture
def client():
@@ -265,117 +408,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
    }


-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
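+# The stacked parametrize decorators below run the full cross-product:
+# 10 processors x 4 inputs x 4 outputs = 160 cases.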
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
):

    processor, run_inputs = framework_processor
@@ -385,7 +422,8 @@ def test_processing_step_with_framework_processor(
    processor.volume_kms_key = "volume-kms-key"
    processor.network_config = network_config

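+    # run() expects lists of inputs/outputs, so wrap the single parametrized objects.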
-    run_inputs["inputs"] = processing_input
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]

    step_args = processor.run(**run_inputs)

@@ -399,10 +437,25 @@ def test_processing_step_with_framework_processor(
        sagemaker_session=pipeline_session,
    )

-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert (
+        step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+        == processing_output.destination
+    )
+
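+    # The parametrized URIs may be pipeline variables (ParameterString/Join),
+    # which serialize differently in the definition JSON, so they are asserted
+    # above and removed before comparing the remaining arguments.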
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
        "Name": "MyProcessingStep",
        "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
    }