@@ -271,13 +271,61 @@ def test_spark_processor_base_extend_processing_args(
     serialized_configuration = BytesIO("test".encode("utf-8"))
 
 
+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration
 
-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})
 
     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
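
Note: both refactored tests now take a sagemaker_session fixture and build the
processor inline from the parametrized config, instead of relying on the removed
py_spark_processor fixture. A minimal sketch of what such a fixture could look
like (hypothetical; the real conftest may differ, but default_bucket() returning
"bucket" would explain the "s3://bucket/..." default URIs in the expected
values above):

    import pytest
    from unittest.mock import MagicMock

    @pytest.fixture
    def sagemaker_session():
        # Mocked Session: when configuration_location is None, the processor
        # presumably derives its staging prefix from default_bucket().
        session = MagicMock()
        session.default_bucket.return_value = "bucket"
        return session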
@@ -290,23 +338,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, s
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName"
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName"
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -320,21 +466,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi
 
     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
 
         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )
 
         if expected[0] is None:
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]
 
 
 @pytest.mark.parametrize(
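
Note on the new parametrization: expected is now either ValueError or a
(source, spark_opt) pair, so the assertions compare against expected[0] and
expected[1] instead of hard-coded strings. Two cases taken directly from the
table above illustrate the convention:

    # Remote "s3" deps are passed through as a spark option; no input channel
    # is created, so the source slot is None.
    expected = (None, "s3://bucket,s3://bucket")
    # Local "jar" deps are staged to S3 and mounted into the container at the
    # processing input path.
    expected = ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName")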