@@ -856,6 +856,148 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
             pass
 
 
+def test_steps_with_map_params_pipeline(
+    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
+    framework_version = "0.20.0"
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
+
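+    # Processor whose instance type and count come from pipeline parameters, so
+    # one pipeline definition can run on different hardware.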
+    sklearn_processor = SKLearnProcessor(
+        framework_version=framework_version,
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
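+    # Processing step: reads the raw CSV plus an Athena dataset definition and
+    # emits train/test outputs; the test_data destination is a Join over a
+    # parameter and an execution variable.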
+    step_process = ProcessingStep(
+        name="my-process",
+        display_name="ProcessingStep",
+        description="description for Processing step",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(
+                output_name="test_data",
+                source="/opt/ml/processing/test",
+                destination=Join(
+                    on="/",
+                    values=[
+                        "s3:/",
+                        sagemaker_session.default_bucket(),
+                        "test-sklearn",
+                        output_prefix,
+                        ExecutionVariables.PIPELINE_EXECUTION_ID,
+                    ],
+                ),
+            ),
+        ],
+        code=os.path.join(script_dir, "preprocessing.py"),
+    )
+
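+    # Estimator whose hyperparameters surface in the step's HyperParameters
+    # properties, which the condition below references.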
+    sklearn_train = SKLearn(
+        framework_version=framework_version,
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        role=role,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+    )
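+    # The training input references the processing step's "train_data" output,
+    # which is what creates the dependency between the two steps.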
+    step_train = TrainingStep(
+        name="my-train",
+        display_name="TrainingStep",
+        description="description for Training step",
+        estimator=sklearn_train,
+        inputs=TrainingInput(
+            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                "train_data"
+            ].S3Output.S3Uri
+        ),
+    )
+
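+    # Wrap the training artifacts in a Model and stage a CreateModel step.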
+    model = Model(
+        image_uri=sklearn_train.image_uri,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="my-model",
+        display_name="ModelStep",
+        description="description for Model step",
+        model=model,
+        inputs=model_inputs,
+    )
+
+    # Condition step branching on a training hyperparameter (a stand-in for a
+    # real model-quality check)
+    cond_lte = ConditionGreaterThanOrEqualTo(
+        left=step_train.properties.HyperParameters["batch-size"],
+        right=6.0,
+    )
+
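+    # With if_steps empty, only the else branch (the CreateModel step) hangs off
+    # this condition.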
+    step_cond = ConditionStep(
+        name="CustomerChurnAccuracyCond",
+        conditions=[cond_lte],
+        if_steps=[],
+        else_steps=[step_model],
+    )
+
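+    # Only parameters listed here can be overridden when starting an execution.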
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_type, instance_count, output_prefix],
+        steps=[step_process, step_train, step_cond],
+        sagemaker_session=sagemaker_session,
+    )
+
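+    # Render the definition locally: property and parameter references should
+    # serialize as {"Get": ...} expressions rather than concrete values.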
+    definition = json.loads(pipeline.definition())
+    assert definition["Version"] == "2020-12-01"
+
+    steps = definition["Steps"]
+    assert len(steps) == 3
+    training_args = {}
+    condition_args = {}
+    for step in steps:
+        if step["Type"] == "Training":
+            training_args = step["Arguments"]
+        if step["Type"] == "Condition":
+            condition_args = step["Arguments"]
+
+    assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
+        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
+    }
+    assert condition_args["Conditions"][0]["LeftValue"] == {
+        "Get": "Steps.my-train.HyperParameters['batch-size']"
+    }
+
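+    # Create the pipeline in the test account, then always clean it up.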
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_two_step_callback_pipeline_with_output_reference(
     sagemaker_session, role, pipeline_name, region_name
 ):