Skip to content

Commit 95cba4d

Browse files
BasilBeiroutiBasil Beirouti
and
Basil Beirouti
authored
feat: add validation specification (#3075)
Co-authored-by: Basil Beirouti <[email protected]>
1 parent d79a3f2 commit 95cba4d

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

src/sagemaker/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def register(
305305
description=None,
306306
drift_check_baselines=None,
307307
customer_metadata_properties=None,
308+
validation_specification=None,
308309
):
309310
"""Creates a model package for creating SageMaker models or listing on Marketplace.
310311
@@ -360,6 +361,7 @@ def register(
360361
container_def_list=[container_def],
361362
drift_check_baselines=drift_check_baselines,
362363
customer_metadata_properties=customer_metadata_properties,
364+
validation_specification=validation_specification,
363365
)
364366
model_package = self.sagemaker_session.create_model_package_from_containers(
365367
**model_pkg_args

src/sagemaker/session.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,6 +2801,7 @@ def create_model_package_from_containers(
28012801
description=None,
28022802
drift_check_baselines=None,
28032803
customer_metadata_properties=None,
2804+
validation_specification=None,
28042805
):
28052806
"""Get request dictionary for CreateModelPackage API.
28062807
@@ -2846,6 +2847,7 @@ def create_model_package_from_containers(
28462847
description,
28472848
drift_check_baselines=drift_check_baselines,
28482849
customer_metadata_properties=customer_metadata_properties,
2850+
validation_specification=validation_specification,
28492851
)
28502852
if model_package_group_name is not None:
28512853
try:
@@ -4206,6 +4208,7 @@ def get_model_package_args(
42064208
container_def_list=None,
42074209
drift_check_baselines=None,
42084210
customer_metadata_properties=None,
4211+
validation_specification=None,
42094212
):
42104213
"""Get arguments for create_model_package method.
42114214
@@ -4275,6 +4278,8 @@ def get_model_package_args(
42754278
model_package_args["tags"] = tags
42764279
if customer_metadata_properties is not None:
42774280
model_package_args["customer_metadata_properties"] = customer_metadata_properties
4281+
if validation_specification is not None:
4282+
model_package_args["validation_specification"] = validation_specification
42784283
return model_package_args
42794284

42804285

@@ -4294,6 +4299,7 @@ def get_create_model_package_request(
42944299
tags=None,
42954300
drift_check_baselines=None,
42964301
customer_metadata_properties=None,
4302+
validation_specification=None,
42974303
):
42984304
"""Get request dictionary for CreateModelPackage API.
42994305
@@ -4345,6 +4351,8 @@ def get_create_model_package_request(
43454351
request_dict["MetadataProperties"] = metadata_properties
43464352
if customer_metadata_properties is not None:
43474353
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
4354+
if validation_specification:
4355+
request_dict["ValidationSpecification"] = validation_specification
43484356
if containers is not None:
43494357
if not all([content_types, response_types, inference_instances, transform_instances]):
43504358
raise ValueError(

tests/unit/sagemaker/model/test_model.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,47 @@
4848
IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
4949

5050

51+
MODEL_DESCRIPTION = "a description"
52+
53+
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES = ["ml.m4.xlarge"]
54+
SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES = ["ml.m4.xlarge"]
55+
56+
SUPPORTED_CONTENT_TYPES = ["text/csv", "application/json", "application/jsonlines"]
57+
SUPPORTED_RESPONSE_MIME_TYPES = ["application/json", "text/csv", "application/jsonlines"]
58+
59+
VALIDATION_FILE_NAME = "input.csv"
60+
VALIDATION_INPUT_PATH = "s3://" + BUCKET_NAME + "/validation-input-csv/"
61+
VALIDATION_OUTPUT_PATH = "s3://" + BUCKET_NAME + "/validation-output-csv/"
62+
63+
VALIDATION_SPECIFICATION = {
64+
"ValidationRole": "some_role",
65+
"ValidationProfiles": [
66+
{
67+
"ProfileName": "Validation-test",
68+
"TransformJobDefinition": {
69+
"BatchStrategy": "SingleRecord",
70+
"TransformInput": {
71+
"DataSource": {
72+
"S3DataSource": {
73+
"S3DataType": "S3Prefix",
74+
"S3Uri": VALIDATION_INPUT_PATH,
75+
}
76+
},
77+
"ContentType": SUPPORTED_CONTENT_TYPES[0],
78+
},
79+
"TransformOutput": {
80+
"S3OutputPath": VALIDATION_OUTPUT_PATH,
81+
},
82+
"TransformResources": {
83+
"InstanceType": SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES[0],
84+
"InstanceCount": 1,
85+
},
86+
},
87+
},
88+
],
89+
}
90+
91+
5192
class DummyFrameworkModel(FrameworkModel):
5293
def __init__(self, **kwargs):
5394
super(DummyFrameworkModel, self).__init__(
@@ -687,3 +728,40 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
687728
]
688729
== "/opt/ml/model/code"
689730
)
731+
732+
733+
@patch("sagemaker.get_model_package_args")
734+
def test_register_calls_model_package_args(get_model_package_args, sagemaker_session):
735+
736+
source_dir = "s3://blah/blah/blah"
737+
t = Model(
738+
entry_point=ENTRY_POINT_INFERENCE,
739+
role=ROLE,
740+
sagemaker_session=sagemaker_session,
741+
source_dir=source_dir,
742+
image_uri=IMAGE_URI,
743+
model_data=MODEL_DATA,
744+
)
745+
746+
t.register(
747+
SUPPORTED_CONTENT_TYPES,
748+
SUPPORTED_RESPONSE_MIME_TYPES,
749+
SUPPORTED_REALTIME_INFERENCE_INSTANCE_TYPES,
750+
SUPPORTED_BATCH_TRANSFORM_INSTANCE_TYPES,
751+
marketplace_cert=True,
752+
description=MODEL_DESCRIPTION,
753+
model_package_name=MODEL_NAME,
754+
validation_specification=VALIDATION_SPECIFICATION,
755+
)
756+
757+
# check that the kwarg validation_specification was passed to the internal method 'get_model_package_args'
758+
assert (
759+
"validation_specification" in get_model_package_args.call_args_list[0][1]
760+
), "validation_specification kwarg was not passed to get_model_package_args"
761+
762+
# check that the kwarg validation_specification is identical to the one passed into the method 'register'
763+
assert (
764+
VALIDATION_SPECIFICATION
765+
== get_model_package_args.call_args_list[0][1]["validation_specification"]
766+
), """ValidationSpecification from model.register method is not identical to validation_spec from
767+
get_model_package_args"""

0 commit comments

Comments
 (0)