@@ -855,6 +855,113 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
855
855
pass
856
856
857
857
858
def test_steps_with_map_params_pipeline(
    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
):
    """Build a processing -> training pipeline and verify its rendered definition.

    Exercises parameter/property "map params": the processing output destination
    is a ``Join`` over a pipeline parameter and an execution variable, and the
    training input references the processing step's output property. Asserts the
    JSON definition resolves these correctly, then round-trips a create/delete
    of the pipeline in the target region.
    """
    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
    framework_version = "0.20.0"
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"

    sklearn_processor = SKLearnProcessor(
        framework_version=framework_version,
        instance_type=instance_type,
        instance_count=instance_count,
        base_job_name="test-sklearn",
        sagemaker_session=sagemaker_session,
        role=role,
    )

    step_process = ProcessingStep(
        name="my-process",
        display_name="ProcessingStep",
        description="description for Processing step",
        processor=sklearn_processor,
        inputs=[
            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
            ProcessingInput(dataset_definition=athena_dataset_definition),
        ],
        outputs=[
            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
            ProcessingOutput(
                output_name="test_data",
                source="/opt/ml/processing/test",
                # Destination is assembled at execution time from a pipeline
                # parameter and the execution id ("s3:/" is joined with the
                # other parts on "/", producing a full s3:// URI).
                destination=Join(
                    on="/",
                    values=[
                        "s3:/",
                        sagemaker_session.default_bucket(),
                        "test-sklearn",
                        output_prefix,
                        ExecutionVariables.PIPELINE_EXECUTION_ID,
                    ],
                ),
            ),
        ],
        code=os.path.join(script_dir, "preprocessing.py"),
    )

    sklearn_train = SKLearn(
        framework_version=framework_version,
        entry_point=os.path.join(script_dir, "train.py"),
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        role=role,
        hyperparameters={
            "batch-size": 500,
            "epochs": 5,
        },
    )

    step_train = TrainingStep(
        name="my-train",
        display_name="TrainingStep",
        description="description for Training step",
        estimator=sklearn_train,
        # Training consumes the processing step's "train_data" output via a
        # property reference, which must render as a Get expression below.
        inputs=TrainingInput(
            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                "train_data"
            ].S3Output.S3Uri
        ),
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_type, instance_count, output_prefix],
        steps=[step_process, step_train],
        sagemaker_session=sagemaker_session,
    )

    pipeline_def = json.loads(pipeline.definition())
    assert pipeline_def["Version"] == "2020-12-01"

    step_defs = pipeline_def["Steps"]
    assert len(step_defs) == 2

    # Pull the training step's arguments; default to {} (as the original loop
    # did) so a missing step surfaces as a KeyError in the asserts below.
    train_args = next(
        (step["Arguments"] for step in step_defs if step["Type"] == "Training"), {}
    )

    # The property reference must serialize to a Get expression, and integer
    # hyperparameters must be stringified in the definition.
    assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
    }
    assert train_args["HyperParameters"]["batch-size"] == "500"
    assert train_args["HyperParameters"]["epochs"] == "5"

    try:
        response = pipeline.create(role)
        arn = response["PipelineArn"]
        assert re.match(
            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            arn,
        )
    finally:
        # Best-effort cleanup: the pipeline may not exist if create() failed.
        try:
            pipeline.delete()
        except Exception:
            pass
858
965
def test_two_step_callback_pipeline_with_output_reference (
859
966
sagemaker_session , role , pipeline_name , region_name
860
967
):
0 commit comments