Skip to content

Commit 29c95ac

Browse files
author
Basil Beirouti
committed
added unit test
1 parent 89493cf commit 29c95ac

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

tests/unit/sagemaker/model/test_model.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,47 @@
4848
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
4949

5050

51+
MODEL_DESCRIPTION = "a description"
52+
53+
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"]
54+
SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES = ["ml.m4.xlarge"]
55+
56+
SUPPORTED_CONTENT_TYPES = ["text/csv", "application/json", "application/jsonlines"]
57+
SUPPORTED_RESPONSE_MIME_TYPES = ["application/json", "text/csv", "application/jsonlines"]
58+
59+
VALIDATION_FILE_NAME = "input.csv"
60+
VALIDATION_INPUT_PATH = "s3://" + BUCKET_NAME + "/validation-input-csv/"
61+
VALIDATION_OUTPUT_PATH = "s3://" + BUCKET_NAME + "/validation-output-csv/"
62+
63+
VALIDATION_SPECIFICATION = {
64+
"ValidationRole": "some_role",
65+
"ValidationProfiles": [
66+
{
67+
"ProfileName": "Validation-test",
68+
"TransformJobDefinition": {
69+
"BatchStrategy": "SingleRecord",
70+
"TransformInput": {
71+
"DataSource": {
72+
"S3DataSource": {
73+
"S3DataType": "S3Prefix",
74+
"S3Uri": VALIDATION_INPUT_PATH,
75+
}
76+
},
77+
"ContentType": SUPPORTED_CONTENT_TYPES[0],
78+
},
79+
"TransformOutput": {
80+
"S3OutputPath": VALIDATION_OUTPUT_PATH,
81+
},
82+
"TransformResources": {
83+
"InstanceType": SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES[0],
84+
"InstanceCount": 1,
85+
},
86+
},
87+
},
88+
],
89+
}
90+
91+
5192
class DummyFrameworkModel(FrameworkModel):
5293
def __init__(self, **kwargs):
5394
super(DummyFrameworkModel, self).__init__(
@@ -687,3 +728,40 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
687728
]
688729
== "/opt/ml/model/code"
689730
)
731+
732+
733+
@patch("sagemaker.get_model_package_args")
734+
def test_call_to_get_model_package_args(get_model_package_args, sagemaker_session):
735+
736+
source_dir = "s3://blah/blah/blah"
737+
t = Model(
738+
entry_point=ENTRY_POINT_INFERENCE,
739+
role=ROLE,
740+
sagemaker_session=sagemaker_session,
741+
source_dir=source_dir,
742+
image_uri=IMAGE_URI,
743+
model_data=MODEL_DATA,
744+
)
745+
746+
t.register(
747+
SUPPORTED_CONTENT_TYPES,
748+
SUPPORTED_RESPONSE_MIME_TYPES,
749+
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES,
750+
SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES,
751+
marketplace_cert=True,
752+
description=MODEL_DESCRIPTION,
753+
model_package_name=MODEL_NAME,
754+
validation_specification=VALIDATION_SPECIFICATION,
755+
)
756+
757+
# check that the kwarg validation_specification was passed to the internal method 'get_model_package_args'
758+
assert (
759+
"validation_specification" in get_model_package_args.call_args_list[0][1]
760+
), "validation_specification kwarg was not passed to get_model_package_args"
761+
762+
# check that the kwarg validation_specification is identical to the one passed into the method 'register'
763+
assert (
764+
VALIDATION_SPECIFICATION
765+
== get_model_package_args.call_args_list[0][1]["validation_specification"]
766+
), """ValidationSpecification from model.register method is not identical to validation_spec from
767+
get_model_package_args"""

0 commit comments

Comments
 (0)