from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

- from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+ from sagemaker.processing import (
+     Processor,
+     ScriptProcessor,
+     FrameworkProcessor,
+     ProcessingOutput,
+     ProcessingInput,
+ )
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

- from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
+ from sagemaker.workflow.parameters import ParameterString
+ from sagemaker.workflow.functions import Join

from sagemaker.network import NetworkConfig
from sagemaker.pytorch.estimator import PyTorch
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

+ FRAMEWORK_PROCESSOR = [
+     (
+         FrameworkProcessor(
+             framework_version="1.8",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+             estimator_cls=PyTorch,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         SKLearnProcessor(
+             framework_version="0.23-1",
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             role=ROLE,
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         PyTorchProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="1.8.0",
+             py_version="py3",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         TensorFlowProcessor(
+             role=ROLE,
+             instance_type=INSTANCE_TYPE,
+             instance_count=1,
+             framework_version="2.0",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         HuggingFaceProcessor(
+             transformers_version="4.6",
+             pytorch_version="1.7",
+             role=ROLE,
+             instance_count=1,
+             instance_type="ml.p3.2xlarge",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         XGBoostProcessor(
+             framework_version="1.3-1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-xgboost",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         MXNetProcessor(
+             framework_version="1.4.1",
+             py_version="py3",
+             role=ROLE,
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+             base_job_name="test-mxnet",
+         ),
+         {"code": DUMMY_S3_SCRIPT_PATH},
+     ),
+     (
+         DataWranglerProcessor(
+             role=ROLE,
+             data_wrangler_flow_source="s3://my-bucket/dw.flow",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {},
+     ),
+     (
+         SparkJarProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+     (
+         PySparkProcessor(
+             role=ROLE,
+             framework_version="2.4",
+             instance_count=1,
+             instance_type=INSTANCE_TYPE,
+         ),
+         {
+             "submit_app": "s3://my-jar",
+             "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+         },
+     ),
+ ]
+
+ PROCESSING_INPUT = [
+     ProcessingInput(
+         source=f"s3://my-bucket/processing_manifest", destination="processing_manifest"
+     ),
+     ProcessingInput(
+         source=ParameterString(name="my-processing-input"),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=ParameterString(
+             name="my-processing-input", default_value="s3://my-bucket/my-processing"
+         ),
+         destination="processing-input",
+     ),
+     ProcessingInput(
+         source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+         destination="processing-input",
+     ),
+ ]
+
+ PROCESSING_OUTPUT = [
+     ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+     ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+     ),
+     ProcessingOutput(
+         source="/opt/ml/output",
+         destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+     ),
+ ]
+

@pytest.fixture
def client():
@@ -253,117 +400,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
    }


- @pytest.mark.parametrize(
-     "framework_processor",
-     [
-         (
-             FrameworkProcessor(
-                 framework_version="1.8",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-                 estimator_cls=PyTorch,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             SKLearnProcessor(
-                 framework_version="0.23-1",
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 role=ROLE,
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             PyTorchProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="1.8.0",
-                 py_version="py3",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             TensorFlowProcessor(
-                 role=ROLE,
-                 instance_type=INSTANCE_TYPE,
-                 instance_count=1,
-                 framework_version="2.0",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             HuggingFaceProcessor(
-                 transformers_version="4.6",
-                 pytorch_version="1.7",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type="ml.p3.2xlarge",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             XGBoostProcessor(
-                 framework_version="1.3-1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-xgboost",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             MXNetProcessor(
-                 framework_version="1.4.1",
-                 py_version="py3",
-                 role=ROLE,
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-                 base_job_name="test-mxnet",
-             ),
-             {"code": DUMMY_S3_SCRIPT_PATH},
-         ),
-         (
-             DataWranglerProcessor(
-                 role=ROLE,
-                 data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {},
-         ),
-         (
-             SparkJarProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-         (
-             PySparkProcessor(
-                 role=ROLE,
-                 framework_version="2.4",
-                 instance_count=1,
-                 instance_type=INSTANCE_TYPE,
-             ),
-             {
-                 "submit_app": "s3://my-jar",
-                 "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-             },
-         ),
-     ],
- )
+ @pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+ @pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+ @pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
-     framework_processor, pipeline_session, processing_input, network_config
+     framework_processor, pipeline_session, processing_input, processing_output, network_config
):

    processor, run_inputs = framework_processor
@@ -373,7 +414,8 @@ def test_processing_step_with_framework_processor(
    processor.volume_kms_key = "volume-kms-key"
    processor.network_config = network_config

-     run_inputs["inputs"] = processing_input
+     run_inputs["inputs"] = [processing_input]
+     run_inputs["outputs"] = [processing_output]

    step_args = processor.run(**run_inputs)
@@ -387,10 +429,25 @@ def test_processing_step_with_framework_processor(
        sagemaker_session=pipeline_session,
    )

-     assert json.loads(pipeline.definition())["Steps"][0] == {
+     step_args = step_args.args
+     step_def = json.loads(pipeline.definition())["Steps"][0]
+
+     assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+     assert (
+         step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+         == processing_output.destination
+     )
+
+     del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+     del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+     del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+     assert step_def == {
        "Name": "MyProcessingStep",
        "Type": "Processing",
-         "Arguments": step_args.args,
+         "Arguments": step_args,
    }

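For context on the mechanism this change relies on: stacking several pytest.mark.parametrize decorators makes pytest run the test once for every combination in the cartesian product of the parameter lists, which is how the rewritten test exercises each processor in FRAMEWORK_PROCESSOR against every entry in PROCESSING_INPUT and PROCESSING_OUTPUT. Below is a minimal, self-contained sketch of that behavior; the names SIZES, COLORS, and test_combinations are illustrative and are not part of the commit.

import pytest

SIZES = ["small", "large"]          # stands in for FRAMEWORK_PROCESSOR
COLORS = ["red", "green", "blue"]   # stands in for PROCESSING_INPUT / PROCESSING_OUTPUT


# Stacked parametrize decorators multiply: 2 sizes x 3 colors = 6 generated test cases.
@pytest.mark.parametrize("size", SIZES)
@pytest.mark.parametrize("color", COLORS)
def test_combinations(size, color):
    assert size in SIZES
    assert color in COLORS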