@@ -273,13 +273,61 @@ def test_spark_processor_base_extend_processing_args(
 serialized_configuration = BytesIO("test".encode("utf-8"))
 
 
+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration
 
-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})
 
     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
@@ -292,23 +340,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, s
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -322,21 +468,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi
 
     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
 
         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )
 
         if expected[0] is None:
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]
 
 
 @pytest.mark.parametrize(