@@ -855,6 +855,162 @@ def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, regi
             pass


+def test_steps_with_map_params_pipeline(
+    sagemaker_session, role, script_dir, pipeline_name, region_name, athena_dataset_definition
+):
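+    # Pipeline parameters are late-bound placeholders; these defaults apply unless
+    # they are overridden when a pipeline execution is started.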
+    instance_count = ParameterInteger(name="InstanceCount", default_value=2)
+    framework_version = "0.20.0"
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+    output_prefix = ParameterString(name="OutputPrefix", default_value="output")
+    input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
+
+    sklearn_processor = SKLearnProcessor(
+        framework_version=framework_version,
+        instance_type=instance_type,
+        instance_count=instance_count,
+        base_job_name="test-sklearn",
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    step_process = ProcessingStep(
+        name="my-process",
+        display_name="ProcessingStep",
+        description="description for Processing step",
+        processor=sklearn_processor,
+        inputs=[
+            ProcessingInput(source=input_data, destination="/opt/ml/processing/input"),
+            ProcessingInput(dataset_definition=athena_dataset_definition),
+        ],
+        outputs=[
+            ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
+            ProcessingOutput(
+                output_name="test_data",
+                source="/opt/ml/processing/test",
+                destination=Join(
+                    on="/",
+                    values=[
+                        "s3:/",
+                        sagemaker_session.default_bucket(),
+                        "test-sklearn",
+                        output_prefix,
+                        ExecutionVariables.PIPELINE_EXECUTION_ID,
+                    ],
+                ),
+            ),
+        ],
+        code=os.path.join(script_dir, "preprocessing.py"),
+    )
+
+    sklearn_train = SKLearn(
+        framework_version=framework_version,
+        entry_point=os.path.join(script_dir, "train.py"),
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        role=role,
+        hyperparameters={
+            "batch-size": 500,
+            "epochs": 5,
+        },
+    )
+    step_train = TrainingStep(
+        name="my-train",
+        display_name="TrainingStep",
+        description="description for Training step",
+        estimator=sklearn_train,
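+        # The property reference wires a data dependency on the processing step;
+        # the concrete S3 URI is resolved only while the pipeline executes.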
+        inputs=TrainingInput(
+            s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
+                "train_data"
+            ].S3Output.S3Uri
+        ),
+    )
+
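+    # Reuse the estimator's container image and point model_data at the training
+    # step's artifact property, which resolves once training completes.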
+    model = Model(
+        image_uri=sklearn_train.image_uri,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="my-model",
+        display_name="ModelStep",
+        description="description for Model step",
+        model=model,
+        inputs=model_inputs,
+    )
+
+    # Branch on a training hyperparameter reference; else_steps run only when the
+    # condition is false, so step_model is gated behind batch-size < 6.0.
+    cond_gte = ConditionGreaterThanOrEqualTo(
+        left=step_train.properties.HyperParameters["batch-size"],
+        right=6.0,
+    )
+
+    step_cond = ConditionStep(
+        name="CustomerChurnAccuracyCond",
+        conditions=[cond_gte],
+        if_steps=[],
+        else_steps=[step_model],
+    )
+
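+    # step_model is not passed in steps=[...]; it enters the graph through the
+    # condition's else_steps, so there are exactly three top-level steps.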
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_type, instance_count, output_prefix],
+        steps=[step_process, step_train, step_cond],
+        sagemaker_session=sagemaker_session,
+    )
+
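+    # In the serialized definition, property references appear as {"Get": "Steps.<name>..."}
+    # expressions rather than concrete values, which the assertions below verify.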
+    definition = json.loads(pipeline.definition())
+    assert definition["Version"] == "2020-12-01"
+
+    steps = definition["Steps"]
+    assert len(steps) == 3
+    training_args = {}
+    condition_args = {}
+    for step in steps:
+        if step["Type"] == "Training":
+            training_args = step["Arguments"]
+        if step["Type"] == "Condition":
+            condition_args = step["Arguments"]
+
+    assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
+        "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
+    }
+    assert condition_args["Conditions"][0]["LeftValue"] == {
+        "Get": "Steps.my-train.HyperParameters['batch-size']"
+    }
+
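+    # Create the pipeline for real; tear it down even if the ARN assertion fails.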
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
+            create_arn,
+        )
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_two_step_callback_pipeline_with_output_reference(
     sagemaker_session, role, pipeline_name, region_name
 ):