@@ -802,6 +802,32 @@ def test_register_model_with_model_repack_with_pipeline_model(
802
802
raise Exception ("A step exists in the collection of an invalid type." )
803
803
804
804
805
+ def test_register_model_with_model_repack_with_repack_output_path (model ):
806
+ repack_output_path = "s3://{BUCKET}/repack_output"
807
+ register_model = RegisterModel (
808
+ name = "RegisterModelStep" ,
809
+ model = model ,
810
+ content_types = ["content_type" ],
811
+ response_types = ["response_type" ],
812
+ inference_instances = ["inference_instance" ],
813
+ transform_instances = ["transform_instance" ],
814
+ model_package_group_name = "mpg" ,
815
+ approval_status = "Approved" ,
816
+ description = "description" ,
817
+ depends_on = ["TestStep" ],
818
+ tags = [{"Key" : "myKey" , "Value" : "myValue" }],
819
+ repack_output_path = repack_output_path ,
820
+ )
821
+
822
+ request_dicts = register_model .request_dicts ()
823
+
824
+ for request_dict in request_dicts :
825
+ if request_dict ["Type" ] == "Training" :
826
+ arguments = request_dict ["Arguments" ]
827
+ assert arguments ["DebugHookConfig" ]["S3OutputPath" ] == repack_output_path
828
+ assert arguments ["OutputDataConfig" ]["S3OutputPath" ] == repack_output_path
829
+
830
+
805
831
def test_estimator_transformer (estimator ):
806
832
model_data = f"s3://{ BUCKET } /model.tar.gz"
807
833
model_inputs = CreateModelInput (
@@ -983,3 +1009,33 @@ def test_estimator_transformer_with_model_repack(estimator):
983
1009
)
984
1010
else :
985
1011
raise Exception ("A step exists in the collection of an invalid type." )
1012
+
1013
+
1014
+ def test_estimator_transformer_with_model_repack_with_repack_output_path (estimator ):
1015
+ repack_output_path = "s3://{BUCKET}/repack_output"
1016
+ model_data = f"s3://{ BUCKET } /model.tar.gz"
1017
+ model_inputs = CreateModelInput (
1018
+ instance_type = "c4.4xlarge" ,
1019
+ accelerator_type = "ml.eia1.medium" ,
1020
+ )
1021
+ transform_inputs = TransformInput (data = f"s3://{ BUCKET } /transform_manifest" )
1022
+ estimator_transformer = EstimatorTransformer (
1023
+ name = "EstimatorTransformerStep" ,
1024
+ estimator = estimator ,
1025
+ model_data = model_data ,
1026
+ model_inputs = model_inputs ,
1027
+ instance_count = 1 ,
1028
+ instance_type = "ml.c4.4xlarge" ,
1029
+ transform_inputs = transform_inputs ,
1030
+ depends_on = ["TestStep" ],
1031
+ entry_point = f"{ DATA_DIR } /dummy_script.py" ,
1032
+ repack_output_path = repack_output_path ,
1033
+ )
1034
+
1035
+ request_dicts = estimator_transformer .request_dicts ()
1036
+
1037
+ for request_dict in request_dicts :
1038
+ if request_dict ["Type" ] == "Training" :
1039
+ arguments = request_dict ["Arguments" ]
1040
+ assert arguments ["DebugHookConfig" ]["S3OutputPath" ] == repack_output_path
1041
+ assert arguments ["OutputDataConfig" ]["S3OutputPath" ] == repack_output_path
0 commit comments