@@ -867,3 +867,123 @@ def test_estimator_transformer(sagemaker_session):
867
867
}
868
868
else :
869
869
raise Exception ("A step exists in the collection of an invalid type." )
870
+
871
+ def test_estimator_transformer_with_model_repack (sagemaker_session ):
872
+ model_data = f"s3://{ BUCKET } /model.tar.gz"
873
+ model_inputs = CreateModelInput (
874
+ instance_type = "c4.4xlarge" ,
875
+ accelerator_type = "ml.eia1.medium" ,
876
+ )
877
+ service_fault_retry_policy = StepRetryPolicy (
878
+ exception_types = [StepExceptionTypeEnum .SERVICE_FAULT ], max_attempts = 10
879
+ )
880
+ transform_inputs = TransformInput (data = f"s3://{ BUCKET } /transform_manifest" )
881
+ dummy_requirements = f"{ DATA_DIR } /dummy_requirements.txt"
882
+ estimator_transformer = EstimatorTransformer (
883
+ name = "EstimatorTransformerStep" ,
884
+ model_data = model_data ,
885
+ model_inputs = model_inputs ,
886
+ instance_count = 1 ,
887
+ instance_type = "ml.c4.4xlarge" ,
888
+ transform_inputs = transform_inputs ,
889
+ depends_on = ["TestStep" ],
890
+ model_step_retry_policies = [service_fault_retry_policy ],
891
+ transform_step_retry_policies = [service_fault_retry_policy ],
892
+ repack_model_step_retry_policies = [service_fault_retry_policy ],
893
+ image_uri = IMAGE_URI ,
894
+ sagemaker_session = sagemaker_session ,
895
+ role = ROLE ,
896
+ entry_point = f"{ DATA_DIR } /dummy_script.py" ,
897
+ dependencies = [dummy_requirements ]
898
+ )
899
+ request_dicts = estimator_transformer .request_dicts ()
900
+ assert len (request_dicts ) == 3
901
+
902
+ for request_dict in request_dicts :
903
+ if request_dict ["Type" ] == "Training" :
904
+ assert request_dict ["Name" ] == "EstimatorTransformerStepRepackModel"
905
+ assert len (request_dict ["DependsOn" ]) == 1
906
+ assert request_dict ["DependsOn" ][0 ] == "TestStep"
907
+ arguments = request_dict ["Arguments" ]
908
+ repacker_job_name = arguments ["HyperParameters" ]["sagemaker_job_name" ]
909
+ assert ordered (arguments ) == ordered (
910
+ {
911
+ "AlgorithmSpecification" : {
912
+ "TrainingImage" : MODEL_REPACKING_IMAGE_URI ,
913
+ "TrainingInputMode" : "File" ,
914
+ },
915
+ "DebugHookConfig" : {
916
+ "CollectionConfigurations" : [],
917
+ "S3OutputPath" : f"s3://{ BUCKET } /" ,
918
+ },
919
+ "HyperParameters" : {
920
+ "inference_script" : '"dummy_script.py"' ,
921
+ "dependencies" : f'"{ dummy_requirements } "' ,
922
+ "model_archive" : '"model.tar.gz"' ,
923
+ "sagemaker_submit_directory" : '"s3://{}/{}/source/sourcedir.tar.gz"' .format (
924
+ BUCKET , repacker_job_name .replace ('"' , "" )
925
+ ),
926
+ "sagemaker_program" : '"_repack_model.py"' ,
927
+ "sagemaker_container_log_level" : "20" ,
928
+ "sagemaker_job_name" : repacker_job_name ,
929
+ "sagemaker_region" : f'"{ REGION } "' ,
930
+ "source_dir" : "null" ,
931
+ },
932
+ "InputDataConfig" : [
933
+ {
934
+ "ChannelName" : "training" ,
935
+ "DataSource" : {
936
+ "S3DataSource" : {
937
+ "S3DataDistributionType" : "FullyReplicated" ,
938
+ "S3DataType" : "S3Prefix" ,
939
+ "S3Uri" : f"s3://{ BUCKET } " ,
940
+ }
941
+ },
942
+ }
943
+ ],
944
+ "OutputDataConfig" : {"S3OutputPath" : f"s3://{ BUCKET } /" },
945
+ "ResourceConfig" : {
946
+ "InstanceCount" : 1 ,
947
+ "InstanceType" : "ml.m5.large" ,
948
+ "VolumeSizeInGB" : 30 ,
949
+ },
950
+ "RoleArn" : ROLE ,
951
+ "StoppingCondition" : {"MaxRuntimeInSeconds" : 86400 },
952
+ }
953
+ )
954
+
955
+ elif request_dict ["Type" ] == "Model" :
956
+ assert request_dict ["Name" ] == "EstimatorTransformerStepCreateModelStep"
957
+ assert request_dict ["RetryPolicies" ] == [service_fault_retry_policy .to_request ()]
958
+ arguments = request_dict ["Arguments" ]
959
+ assert isinstance (arguments ["PrimaryContainer" ]["ModelDataUrl" ], Properties )
960
+ arguments ["PrimaryContainer" ].pop ("ModelDataUrl" )
961
+ assert arguments == {
962
+ "ExecutionRoleArn" : "DummyRole" ,
963
+ "PrimaryContainer" : {
964
+ "Environment" : {},
965
+ "Image" : "fakeimage" ,
966
+ }
967
+ }
968
+
969
+ elif request_dict ["Type" ] == "Transform" :
970
+ assert request_dict ["Name" ] == "EstimatorTransformerStepTransformStep"
971
+ assert request_dict ["RetryPolicies" ] == [service_fault_retry_policy .to_request ()]
972
+ arguments = request_dict ["Arguments" ]
973
+ assert isinstance (arguments ["ModelName" ], Properties )
974
+ arguments .pop ("ModelName" )
975
+ assert "DependsOn" not in request_dict
976
+ assert arguments == {
977
+ "TransformInput" : {
978
+ "DataSource" : {
979
+ "S3DataSource" : {
980
+ "S3DataType" : "S3Prefix" ,
981
+ "S3Uri" : f"s3://{ BUCKET } /transform_manifest" ,
982
+ }
983
+ }
984
+ },
985
+ "TransformOutput" : {"S3OutputPath" : None },
986
+ "TransformResources" : {"InstanceCount" : 1 , "InstanceType" : "ml.c4.4xlarge" },
987
+ }
988
+ else :
989
+ raise Exception ("A step exists in the collection of an invalid type." )
0 commit comments