 from sagemaker.transformer import Transformer
 from sagemaker.tuner import HyperparameterTuner
 from sagemaker.workflow.pipeline_context import PipelineSession
-from sagemaker.workflow.functions import Join
-from sagemaker.workflow.parameters import ParameterString
 
 from sagemaker.processing import (
     Processor,
@@ -603,8 +601,10 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
 )
 
 
-@pytest.mark.parametrize("spark_processor", [
-    (
+@pytest.mark.parametrize(
+    "spark_processor",
+    [
+        (
             SparkJarProcessor(
                 role=ROLE,
                 framework_version="2.4",
@@ -614,14 +614,28 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
             {
                 "submit_app": "s3://my-jar",
                 "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
-                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
-                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files",
-                                 ParameterString("OurFiles")],
-                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri")
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
+                "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
             },
-    ),
-    (
+        ),
+        (
             PySparkProcessor(
                 role=ROLE,
                 framework_version="2.4",
@@ -630,16 +644,35 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session):
             ),
             {
                 "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", ParameterString("MyArgOutput")],
-                "submit_py_files": ["s3://my-py-files", ParameterString("MyPyFiles"), "s3://her-pyfiles",
-                                    ParameterString("OurPyFiles")],
-                "submit_jars": ["s3://my-jar", ParameterString("MyJars"), "s3://her-jar", ParameterString("OurJar")],
-                "submit_files": ["s3://my-files", ParameterString("MyFiles"), "s3://her-files",
-                                 ParameterString("OurFiles")],
+                "arguments": [
+                    "--input",
+                    "input-data-uri",
+                    "--output",
+                    ParameterString("MyArgOutput"),
+                ],
+                "submit_py_files": [
+                    "s3://my-py-files",
+                    ParameterString("MyPyFiles"),
+                    "s3://her-pyfiles",
+                    ParameterString("OurPyFiles"),
+                ],
+                "submit_jars": [
+                    "s3://my-jar",
+                    ParameterString("MyJars"),
+                    "s3://her-jar",
+                    ParameterString("OurJar"),
+                ],
+                "submit_files": [
+                    "s3://my-files",
+                    ParameterString("MyFiles"),
+                    "s3://her-files",
+                    ParameterString("OurFiles"),
+                ],
                 "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
             },
-    ),
-])
+        ),
+    ],
+)
 def test_spark_processor(spark_processor, processing_input, pipeline_session):
 
     processor, run_inputs = spark_processor
@@ -656,52 +689,108 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
 
     step_args = step_args.args
 
-    assert step_args['AppSpecification']['ContainerArguments'] == run_inputs['arguments']
+    assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
 
-    entry_points = step_args['AppSpecification']['ContainerEntrypoint']
+    entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
     entry_points_expr = []
     for entry_point in entry_points:
         if is_pipeline_variable(entry_point):
             entry_points_expr.append(entry_point.expr)
         else:
             entry_points_expr.append(entry_point)
 
-    if 'submit_py_files' in run_inputs:
+    if "submit_py_files" in run_inputs:
         expected = [
-            'smspark-submit',
-            '--py-files', {'Std:Join': {'On': ',', 'Values': ['s3://my-py-files', {'Get': 'Parameters.MyPyFiles'},
-                                                              's3://her-pyfiles', {'Get': 'Parameters.OurPyFiles'}]}},
-            '--jars', {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'}, 's3://her-jar', {'Get': 'Parameters.OurJar'}]}},
-            '--files', {'Std:Join': {'On': ',', 'Values': ['s3://my-files', {'Get': 'Parameters.MyFiles'}, 's3://her-files', {'Get': 'Parameters.OurFiles'}]}},
-            '--local-spark-event-logs-dir', '/opt/ml/processing/spark-events/', '/opt/ml/processing/input/code'
+            "smspark-submit",
+            "--py-files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-py-files",
+                        {"Get": "Parameters.MyPyFiles"},
+                        "s3://her-pyfiles",
+                        {"Get": "Parameters.OurPyFiles"},
+                    ],
+                }
+            },
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
         ]
     # py spark
     else:
         expected = [
-            'smspark-submit',
-            '--class',
-            'com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp',
-            '--jars', {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'},
-                                                          's3://her-jar', {'Get': 'Parameters.OurJar'}]}},
-            '--files', {'Std:Join': {'On': ',', 'Values': ['s3://my-files', {'Get': 'Parameters.MyFiles'},
-                                                           's3://her-files', {'Get': 'Parameters.OurFiles'}]}},
-            '--local-spark-event-logs-dir', '/opt/ml/processing/spark-events/', '/opt/ml/processing/input/code'
+            "smspark-submit",
+            "--class",
+            "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "--jars",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-jar",
+                        {"Get": "Parameters.MyJars"},
+                        "s3://her-jar",
+                        {"Get": "Parameters.OurJar"},
+                    ],
+                }
+            },
+            "--files",
+            {
+                "Std:Join": {
+                    "On": ",",
+                    "Values": [
+                        "s3://my-files",
+                        {"Get": "Parameters.MyFiles"},
+                        "s3://her-files",
+                        {"Get": "Parameters.OurFiles"},
+                    ],
+                }
+            },
+            "--local-spark-event-logs-dir",
+            "/opt/ml/processing/spark-events/",
+            "/opt/ml/processing/input/code",
         ]
 
     assert entry_points_expr == expected
-    for output in step_args['ProcessingOutputConfig']['Outputs']:
-        if is_pipeline_variable(output['S3Output']['S3Uri']):
-            output['S3Output']['S3Uri'] = output['S3Output']['S3Uri'].expr
+    for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+        if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+            output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr
 
-    assert step_args['ProcessingOutputConfig']['Outputs'] == [
+    assert step_args["ProcessingOutputConfig"]["Outputs"] == [
         {
-            'OutputName': 'output-1',
-            'AppManaged': False,
-            'S3Output': {
-                'S3Uri': {'Get': 'Parameters.MySparkEventLogS3Uri'},
-                'LocalPath': '/opt/ml/processing/spark-events/',
-                'S3UploadMode': 'Continuous'
-            }
+            "OutputName": "output-1",
+            "AppManaged": False,
+            "S3Output": {
+                "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+                "LocalPath": "/opt/ml/processing/spark-events/",
+                "S3UploadMode": "Continuous",
+            },
         }
     ]
 
@@ -711,4 +800,3 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
         sagemaker_session=pipeline_session,
     )
     pipeline.definition()
-
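For context on the expected values asserted in this test: pipeline variables such as ParameterString and Join are placeholders that serialize to {"Get": ...} and {"Std:Join": ...} dicts in the pipeline definition, which is why the test compares against those structures via .expr. A minimal sketch of that behavior (assuming a recent sagemaker SDK; the parameter names here are illustrative, not part of this change):

    from sagemaker.workflow.functions import Join
    from sagemaker.workflow.parameters import ParameterString

    # A pipeline parameter renders as a {"Get": ...} reference, not a string.
    output_param = ParameterString(name="MyArgOutput")
    print(output_param.expr)  # {'Get': 'Parameters.MyArgOutput'}

    # Join defers concatenation until execution: literals pass through and
    # nested pipeline variables render as their own {"Get": ...} exprs.
    jars = Join(on=",", values=["s3://my-jar", ParameterString(name="MyJars")])
    print(jars.expr)
    # {'Std:Join': {'On': ',', 'Values': ['s3://my-jar', {'Get': 'Parameters.MyJars'}]}}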