Skip to content

Commit 9517334

Browse files
author
Basil Beirouti
committed
add validation specification
1 parent 1d151c2 commit 9517334

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/sagemaker/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def register(
305305
description=None,
306306
drift_check_baselines=None,
307307
customer_metadata_properties=None,
308+
validation_specification=None
308309
):
309310
"""Creates a model package for creating SageMaker models or listing on Marketplace.
310311
@@ -360,6 +361,7 @@ def register(
360361
container_def_list=[container_def],
361362
drift_check_baselines=drift_check_baselines,
362363
customer_metadata_properties=customer_metadata_properties,
364+
validation_specification=validation_specification
363365
)
364366
model_package = self.sagemaker_session.create_model_package_from_containers(
365367
**model_pkg_args

src/sagemaker/session.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2801,6 +2801,7 @@ def create_model_package_from_containers(
28012801
description=None,
28022802
drift_check_baselines=None,
28032803
customer_metadata_properties=None,
2804+
validation_specification=None
28042805
):
28052806
"""Get request dictionary for CreateModelPackage API.
28062807
@@ -2846,6 +2847,7 @@ def create_model_package_from_containers(
28462847
description,
28472848
drift_check_baselines=drift_check_baselines,
28482849
customer_metadata_properties=customer_metadata_properties,
2850+
validation_specification=validation_specification
28492851
)
28502852
if model_package_group_name is not None:
28512853
try:
@@ -4206,6 +4208,7 @@ def get_model_package_args(
42064208
container_def_list=None,
42074209
drift_check_baselines=None,
42084210
customer_metadata_properties=None,
4211+
validation_specification=None
42094212
):
42104213
"""Get arguments for create_model_package method.
42114214
@@ -4275,6 +4278,8 @@ def get_model_package_args(
42754278
model_package_args["tags"] = tags
42764279
if customer_metadata_properties is not None:
42774280
model_package_args["customer_metadata_properties"] = customer_metadata_properties
4281+
if validation_specification is not None:
4282+
model_package_args["validation_specification"] = validation_specification
42784283
return model_package_args
42794284

42804285

@@ -4294,6 +4299,7 @@ def get_create_model_package_request(
42944299
tags=None,
42954300
drift_check_baselines=None,
42964301
customer_metadata_properties=None,
4302+
validation_specification=None
42974303
):
42984304
"""Get request dictionary for CreateModelPackage API.
42994305
@@ -4345,6 +4351,8 @@ def get_create_model_package_request(
43454351
request_dict["MetadataProperties"] = metadata_properties
43464352
if customer_metadata_properties is not None:
43474353
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
4354+
if validation_specification:
4355+
request_dict["ValidationSpecification"] = validation_specification
43484356
if containers is not None:
43494357
if not all([content_types, response_types, inference_instances, transform_instances]):
43504358
raise ValueError(

0 commit comments

Comments
 (0)