from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

- from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+ from sagemaker.processing import (
+     Processor,
+     ScriptProcessor,
+     FrameworkProcessor,
+     ProcessingOutput,
+     ProcessingInput,
+ )
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

- from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
+ from sagemaker.workflow.parameters import ParameterString
+ from sagemaker.workflow.functions import Join

from sagemaker.network import NetworkConfig
from sagemaker.pytorch.estimator import PyTorch
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

+ FRAMEWORK_PROCESSOR = [
+     (
+         FrameworkProcessor(
+             framework_version="1.8",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+             estimator_cls=PyTorch,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         SKLearnProcessor(
+             framework_version="0.23-1",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         PyTorchProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="1.8.0",
+             py_version="py3",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         TensorFlowProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="2.0",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         HuggingFaceProcessor(
+             transformers_version="4.6",
+             pytorch_version="1.7",
+             role=ROLE,
+             instance_count=1,
+             instance_type="ml.p3.2xlarge",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         XGBoostProcessor(
+             framework_version="1.3-1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-xgboost",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         MXNetProcessor(
+             framework_version="1.4.1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-mxnet",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         DataWranglerProcessor(
+             role=ROLE,
+             data_wrangler_flow_source="s3://my-bucket/dw.flow",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {},
+     ),
+     (
+         SparkJarProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+     (
+         PySparkProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+ ]
+
+ PROCESSING_INPUT = [
+     ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+     ProcessingInput(
+         source=ParameterString(name="my-processing-input"),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=ParameterString(
+             name="my-processing-input", default_value="s3://my-bucket/my-processing"
+         ),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+         destination="processing-input",
+     ),
+ ]
+
+ PROCESSING_OUTPUT = [
+     ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+     ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+     ),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+     ),
+ ]
+

@pytest.fixture
def client():
@@ -253,117 +398,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
    }


- @pytest.mark.parametrize(
-     "framework_processor",
-     [
-         (
-             FrameworkProcessor(
-                 framework_version="1.8",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-                 estimator_cls=PyTorch,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             SKLearnProcessor(
-                 framework_version="0.23-1",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             PyTorchProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="1.8.0",
-                 py_version="py3",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             TensorFlowProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="2.0",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             HuggingFaceProcessor(
-                 transformers_version="4.6",
-                 pytorch_version="1.7",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type="ml.p3.2xlarge",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             XGBoostProcessor(
-                 framework_version="1.3-1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-xgboost",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             MXNetProcessor(
-                 framework_version="1.4.1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-mxnet",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             DataWranglerProcessor(
-                 role=ROLE,
-                 data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {},
-         ),
-         (
-             SparkJarProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-         (
-             PySparkProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-     ],
- )
+ @pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+ @pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+ @pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
-     framework_processor, pipeline_session, processing_input, network_config
+     framework_processor, pipeline_session, processing_input, processing_output, network_config
):

    processor, run_inputs = framework_processor
@@ -373,7 +412,8 @@ def test_processing_step_with_framework_processor(
    processor.volume_kms_key = "volume-kms-key"
    processor.network_config = network_config

-     run_inputs["inputs"] = processing_input
+     run_inputs["inputs"] = [processing_input]
+     run_inputs["outputs"] = [processing_output]

    step_args = processor.run(**run_inputs)

@@ -387,10 +427,25 @@ def test_processing_step_with_framework_processor(
        sagemaker_session=pipeline_session,
    )

-     assert json.loads(pipeline.definition())["Steps"][0] == {
+     step_args = step_args.args
+     step_def = json.loads(pipeline.definition())["Steps"][0]
+
+     assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+     assert (
+         step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+         == processing_output.destination
+     )
+
+     del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+     del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+     assert step_def == {
        "Name": "MyProcessingStep",
        "Type": "Processing",
-         "Arguments": step_args.args,
+         "Arguments": step_args,
    }

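# A minimal sketch, not part of the diff, of how one parametrized combination from
# FRAMEWORK_PROCESSOR x PROCESSING_INPUT x PROCESSING_OUTPUT is expected to be wired up.
# Assumptions: an AWS region is configured for the boto session, and `role` is a
# placeholder IAM role ARN. Under a PipelineSession, processor.run() compiles step
# arguments instead of starting a job, which is what the test passes to ProcessingStep.
import json

from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.workflow.functions import Join
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import ProcessingStep

pipeline_session = PipelineSession()
role = "arn:aws:iam::123456789012:role/dummy-role"  # placeholder, not from the diff

processor = SKLearnProcessor(
    framework_version="0.23-1",
    instance_type="ml.m4.xlarge",
    instance_count=1,
    role=role,
    sagemaker_session=pipeline_session,
)

# Pipeline variables (ParameterString, Join) are only resolved at execution time,
# which is why the test strips the S3Uri fields before comparing the step definition.
my_input = ParameterString(
    name="my-processing-input", default_value="s3://my-bucket/my-processing"
)
step_args = processor.run(
    code="s3://dummy-s3/dummy_script.py",
    inputs=[ProcessingInput(source=my_input, destination="processing-input")],
    outputs=[
        ProcessingOutput(
            source="/opt/ml/output",
            destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
        )
    ],
)

step = ProcessingStep(name="MyProcessingStep", step_args=step_args)
pipeline = Pipeline(
    name="MyPipeline",
    steps=[step],
    parameters=[my_input],
    sagemaker_session=pipeline_session,
)
print(json.loads(pipeline.definition())["Steps"][0]["Type"])  # "Processing"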