@@ -865,3 +865,121 @@ def test_estimator_transformer(estimator):
865
865
}
866
866
else :
867
867
raise Exception ("A step exists in the collection of an invalid type." )
868
+
869
+
870
def test_estimator_transformer_with_model_repack(estimator):
    """Check the request dicts of an EstimatorTransformer built with a custom script.

    The step collection is constructed with ``entry_point`` and ``dependencies``
    and is expected to yield exactly three request dicts: one of type
    ``Training`` (named ``...RepackModel``), one of type ``Model`` and one of
    type ``Transform``. Each is validated against its expected name,
    dependencies and arguments.

    Args:
        estimator: estimator fixture passed through to ``EstimatorTransformer``.

    Raises:
        Exception: if a request dict has a type other than the three above.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
    model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    estimator_transformer = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=model_data,
        model_inputs=model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=transform_inputs,
        depends_on=["TestStep"],
        entry_point=f"{DATA_DIR}/dummy_script.py",
        dependencies=[dummy_requirements],
    )
    request_dicts = estimator_transformer.request_dicts()
    assert len(request_dicts) == 3

    for request_dict in request_dicts:
        if request_dict["Type"] == "Training":
            # Only this step is expected to carry the collection's depends_on.
            assert request_dict["Name"] == "EstimatorTransformerStepRepackModel"
            assert len(request_dict["DependsOn"]) == 1
            assert request_dict["DependsOn"][0] == "TestStep"
            arguments = request_dict["Arguments"]
            # The job name is read back from the actual arguments and reused in
            # the expected dict, so the comparison does not depend on its value.
            repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
            assert ordered(arguments) == ordered(
                {
                    "AlgorithmSpecification": {
                        "TrainingImage": MODEL_REPACKING_IMAGE_URI,
                        "TrainingInputMode": "File",
                    },
                    "DebugHookConfig": {
                        "CollectionConfigurations": [],
                        "S3OutputPath": f"s3://{BUCKET}/",
                    },
                    "HyperParameters": {
                        "inference_script": '"dummy_script.py"',
                        "dependencies": f'"{dummy_requirements}"',
                        "model_archive": '"model.tar.gz"',
                        "sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
                            BUCKET, repacker_job_name.replace('"', "")
                        ),
                        "sagemaker_program": '"_repack_model.py"',
                        "sagemaker_container_log_level": "20",
                        "sagemaker_job_name": repacker_job_name,
                        "sagemaker_region": f'"{REGION}"',
                        "source_dir": "null",
                    },
                    "InputDataConfig": [
                        {
                            "ChannelName": "training",
                            "DataSource": {
                                "S3DataSource": {
                                    "S3DataDistributionType": "FullyReplicated",
                                    "S3DataType": "S3Prefix",
                                    "S3Uri": f"s3://{BUCKET}",
                                }
                            },
                        }
                    ],
                    "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
                    "ResourceConfig": {
                        "InstanceCount": 1,
                        "InstanceType": "ml.m5.large",
                        "VolumeSizeInGB": 30,
                    },
                    "RoleArn": ROLE,
                    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
                    # NOTE(review): expected as a list of (key, value) pairs
                    # rather than a dict -- presumably mirrors the shape of the
                    # actual value after ordered(); confirm against the fixture.
                    "VpcConfig": [
                        ("SecurityGroupIds", ["123", "456"]),
                        ("Subnets", ["abc", "def"]),
                    ],
                }
            )
        elif request_dict["Type"] == "Model":
            assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep"
            assert "DependsOn" not in request_dict
            arguments = request_dict["Arguments"]
            # ModelDataUrl is a Properties object, not a literal: check its
            # type, then drop it so the remaining arguments can be compared.
            assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties)
            del arguments["PrimaryContainer"]["ModelDataUrl"]
            assert ordered(arguments) == ordered(
                {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                    },
                }
            )
        elif request_dict["Type"] == "Transform":
            assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
            assert "DependsOn" not in request_dict
            arguments = request_dict["Arguments"]
            # ModelName is likewise a Properties object; verify and remove it
            # before comparing the rest.
            assert isinstance(arguments["ModelName"], Properties)
            arguments.pop("ModelName")
            assert ordered(arguments) == ordered(
                {
                    "TransformInput": {
                        "DataSource": {
                            "S3DataSource": {
                                "S3DataType": "S3Prefix",
                                "S3Uri": f"s3://{BUCKET}/transform_manifest",
                            }
                        }
                    },
                    "TransformOutput": {"S3OutputPath": None},
                    "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
                }
            )
        else:
            raise Exception("A step exists in the collection of an invalid type.")
0 commit comments