Skip to content

Commit faf4ad5

Browse files
feature: Add tests for RegisterModel with repack output
1 parent ee6afcf commit faf4ad5

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,32 @@ def test_register_model_with_model_repack_with_pipeline_model(
802802
raise Exception("A step exists in the collection of an invalid type.")
803803

804804

805+
def test_register_model_with_model_repack_with_repack_output_path(model):
806+
repack_output_path = "s3://{BUCKET}/repack_output"
807+
register_model = RegisterModel(
808+
name="RegisterModelStep",
809+
model=model,
810+
content_types=["content_type"],
811+
response_types=["response_type"],
812+
inference_instances=["inference_instance"],
813+
transform_instances=["transform_instance"],
814+
model_package_group_name="mpg",
815+
approval_status="Approved",
816+
description="description",
817+
depends_on=["TestStep"],
818+
tags=[{"Key": "myKey", "Value": "myValue"}],
819+
repack_output_path=repack_output_path,
820+
)
821+
822+
request_dicts = register_model.request_dicts()
823+
824+
for request_dict in request_dicts:
825+
if request_dict["Type"] == "Training":
826+
arguments = request_dict["Arguments"]
827+
assert arguments["DebugHookConfig"]["S3OutputPath"] == repack_output_path
828+
assert arguments["OutputDataConfig"]["S3OutputPath"] == repack_output_path
829+
830+
805831
def test_estimator_transformer(estimator):
806832
model_data = f"s3://{BUCKET}/model.tar.gz"
807833
model_inputs = CreateModelInput(
@@ -983,3 +1009,33 @@ def test_estimator_transformer_with_model_repack(estimator):
9831009
)
9841010
else:
9851011
raise Exception("A step exists in the collection of an invalid type.")
1012+
1013+
1014+
def test_estimator_transformer_with_model_repack_with_repack_output_path(estimator):
1015+
repack_output_path = "s3://{BUCKET}/repack_output"
1016+
model_data = f"s3://{BUCKET}/model.tar.gz"
1017+
model_inputs = CreateModelInput(
1018+
instance_type="c4.4xlarge",
1019+
accelerator_type="ml.eia1.medium",
1020+
)
1021+
transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
1022+
estimator_transformer = EstimatorTransformer(
1023+
name="EstimatorTransformerStep",
1024+
estimator=estimator,
1025+
model_data=model_data,
1026+
model_inputs=model_inputs,
1027+
instance_count=1,
1028+
instance_type="ml.c4.4xlarge",
1029+
transform_inputs=transform_inputs,
1030+
depends_on=["TestStep"],
1031+
entry_point=f"{DATA_DIR}/dummy_script.py",
1032+
repack_output_path=repack_output_path,
1033+
)
1034+
1035+
request_dicts = estimator_transformer.request_dicts()
1036+
1037+
for request_dict in request_dicts:
1038+
if request_dict["Type"] == "Training":
1039+
arguments = request_dict["Arguments"]
1040+
assert arguments["DebugHookConfig"]["S3OutputPath"] == repack_output_path
1041+
assert arguments["OutputDataConfig"]["S3OutputPath"] == repack_output_path

0 commit comments

Comments
 (0)