Skip to content

Commit ba40cb7

Browse files
author
Eugene Teoh
committed
fix: add estimator transformer with model repack test
1 parent 6317caa commit ba40cb7

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,3 +867,123 @@ def test_estimator_transformer(sagemaker_session):
867867
}
868868
else:
869869
raise Exception("A step exists in the collection of an invalid type.")
870+
871+
def test_estimator_transformer_with_model_repack(sagemaker_session):
872+
model_data = f"s3://{BUCKET}/model.tar.gz"
873+
model_inputs = CreateModelInput(
874+
instance_type="c4.4xlarge",
875+
accelerator_type="ml.eia1.medium",
876+
)
877+
service_fault_retry_policy = StepRetryPolicy(
878+
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10
879+
)
880+
transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
881+
dummy_requirements = f"{DATA_DIR}/dummy_requirements.txt"
882+
estimator_transformer = EstimatorTransformer(
883+
name="EstimatorTransformerStep",
884+
model_data=model_data,
885+
model_inputs=model_inputs,
886+
instance_count=1,
887+
instance_type="ml.c4.4xlarge",
888+
transform_inputs=transform_inputs,
889+
depends_on=["TestStep"],
890+
model_step_retry_policies=[service_fault_retry_policy],
891+
transform_step_retry_policies=[service_fault_retry_policy],
892+
repack_model_step_retry_policies=[service_fault_retry_policy],
893+
image_uri=IMAGE_URI,
894+
sagemaker_session=sagemaker_session,
895+
role=ROLE,
896+
entry_point=f"{DATA_DIR}/dummy_script.py",
897+
dependencies=[dummy_requirements]
898+
)
899+
request_dicts = estimator_transformer.request_dicts()
900+
assert len(request_dicts) == 3
901+
902+
for request_dict in request_dicts:
903+
if request_dict["Type"] == "Training":
904+
assert request_dict["Name"] == "EstimatorTransformerStepRepackModel"
905+
assert len(request_dict["DependsOn"]) == 1
906+
assert request_dict["DependsOn"][0] == "TestStep"
907+
arguments = request_dict["Arguments"]
908+
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
909+
assert ordered(arguments) == ordered(
910+
{
911+
"AlgorithmSpecification": {
912+
"TrainingImage": MODEL_REPACKING_IMAGE_URI,
913+
"TrainingInputMode": "File",
914+
},
915+
"DebugHookConfig": {
916+
"CollectionConfigurations": [],
917+
"S3OutputPath": f"s3://{BUCKET}/",
918+
},
919+
"HyperParameters": {
920+
"inference_script": '"dummy_script.py"',
921+
"dependencies": f'"{dummy_requirements}"',
922+
"model_archive": '"model.tar.gz"',
923+
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
924+
BUCKET, repacker_job_name.replace('"', "")
925+
),
926+
"sagemaker_program": '"_repack_model.py"',
927+
"sagemaker_container_log_level": "20",
928+
"sagemaker_job_name": repacker_job_name,
929+
"sagemaker_region": f'"{REGION}"',
930+
"source_dir": "null",
931+
},
932+
"InputDataConfig": [
933+
{
934+
"ChannelName": "training",
935+
"DataSource": {
936+
"S3DataSource": {
937+
"S3DataDistributionType": "FullyReplicated",
938+
"S3DataType": "S3Prefix",
939+
"S3Uri": f"s3://{BUCKET}",
940+
}
941+
},
942+
}
943+
],
944+
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
945+
"ResourceConfig": {
946+
"InstanceCount": 1,
947+
"InstanceType": "ml.m5.large",
948+
"VolumeSizeInGB": 30,
949+
},
950+
"RoleArn": ROLE,
951+
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
952+
}
953+
)
954+
955+
elif request_dict["Type"] == "Model":
956+
assert request_dict["Name"] == "EstimatorTransformerStepCreateModelStep"
957+
assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()]
958+
arguments = request_dict["Arguments"]
959+
assert isinstance(arguments["PrimaryContainer"]["ModelDataUrl"], Properties)
960+
arguments["PrimaryContainer"].pop("ModelDataUrl")
961+
assert arguments == {
962+
"ExecutionRoleArn": "DummyRole",
963+
"PrimaryContainer": {
964+
"Environment": {},
965+
"Image": "fakeimage",
966+
}
967+
}
968+
969+
elif request_dict["Type"] == "Transform":
970+
assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
971+
assert request_dict["RetryPolicies"] == [service_fault_retry_policy.to_request()]
972+
arguments = request_dict["Arguments"]
973+
assert isinstance(arguments["ModelName"], Properties)
974+
arguments.pop("ModelName")
975+
assert "DependsOn" not in request_dict
976+
assert arguments == {
977+
"TransformInput": {
978+
"DataSource": {
979+
"S3DataSource": {
980+
"S3DataType": "S3Prefix",
981+
"S3Uri": f"s3://{BUCKET}/transform_manifest",
982+
}
983+
}
984+
},
985+
"TransformOutput": {"S3OutputPath": None},
986+
"TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
987+
}
988+
else:
989+
raise Exception("A step exists in the collection of an invalid type.")

0 commit comments

Comments
 (0)