Skip to content

Commit 87937ad

Browse files
authored
feature: add 'Domain' property to RegisterModel step (#3118)
1 parent 255a339 commit 87937ad

File tree

12 files changed

+57
-1
lines changed

12 files changed

+57
-1
lines changed

src/sagemaker/estimator.py

+4
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ def register(
12801280
model_name=None,
12811281
drift_check_baselines=None,
12821282
customer_metadata_properties=None,
1283+
domain=None,
12831284
**kwargs,
12841285
):
12851286
"""Creates a model package for creating SageMaker models or listing on Marketplace.
@@ -1311,6 +1312,8 @@ def register(
13111312
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
13121313
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
13131314
metadata properties (default: None).
1315+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
1316+
"MACHINE_LEARNING" (default: None).
13141317
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
13151318
``create_model()`` to accept ``**kwargs`` to customize model creation during
13161319
deploy. For more, see the implementation docs.
@@ -1342,6 +1345,7 @@ def register(
13421345
description,
13431346
drift_check_baselines=drift_check_baselines,
13441347
customer_metadata_properties=customer_metadata_properties,
1348+
domain=domain,
13451349
)
13461350

13471351
@property

src/sagemaker/huggingface/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def register(
304304
approval_status=None,
305305
description=None,
306306
drift_check_baselines=None,
307+
domain=None,
307308
):
308309
"""Creates a model package for creating SageMaker models or listing on Marketplace.
309310
@@ -331,6 +332,8 @@ def register(
331332
or "PendingManualApproval". Defaults to ``PendingManualApproval``.
332333
description (str): Model Package description. Defaults to ``None``.
333334
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
335+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
336+
"MACHINE_LEARNING" (default: None).
334337
335338
Returns:
336339
A `sagemaker.model.ModelPackage` instance.
@@ -359,6 +362,7 @@ def register(
359362
approval_status,
360363
description,
361364
drift_check_baselines=drift_check_baselines,
365+
domain=domain,
362366
)
363367

364368
def prepare_container_def(

src/sagemaker/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def register(
309309
drift_check_baselines=None,
310310
customer_metadata_properties=None,
311311
validation_specification=None,
312+
domain=None,
312313
):
313314
"""Creates a model package for creating SageMaker models or listing on Marketplace.
314315
@@ -336,6 +337,8 @@ def register(
336337
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
337338
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
338339
metadata properties (default: None).
340+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
341+
"MACHINE_LEARNING" (default: None).
339342
340343
Returns:
341344
A `sagemaker.model.ModelPackage` instance.
@@ -365,6 +368,7 @@ def register(
365368
drift_check_baselines=drift_check_baselines,
366369
customer_metadata_properties=customer_metadata_properties,
367370
validation_specification=validation_specification,
371+
domain=domain,
368372
)
369373
model_package = self.sagemaker_session.create_model_package_from_containers(
370374
**model_pkg_args

src/sagemaker/mxnet/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def register(
158158
description=None,
159159
drift_check_baselines=None,
160160
customer_metadata_properties=None,
161+
domain=None,
161162
):
162163
"""Creates a model package for creating SageMaker models or listing on Marketplace.
163164
@@ -185,6 +186,8 @@ def register(
185186
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
186187
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
187188
metadata properties (default: None).
189+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
190+
"MACHINE_LEARNING" (default: None).
188191
189192
Returns:
190193
A `sagemaker.model.ModelPackage` instance.
@@ -214,6 +217,7 @@ def register(
214217
description,
215218
drift_check_baselines=drift_check_baselines,
216219
customer_metadata_properties=customer_metadata_properties,
220+
domain=domain,
217221
)
218222

219223
def prepare_container_def(

src/sagemaker/pytorch/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def register(
159159
description=None,
160160
drift_check_baselines=None,
161161
customer_metadata_properties=None,
162+
domain=None,
162163
):
163164
"""Creates a model package for creating SageMaker models or listing on Marketplace.
164165
@@ -186,6 +187,8 @@ def register(
186187
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
187188
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
188189
metadata properties (default: None).
190+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
191+
"MACHINE_LEARNING" (default: None).
189192
190193
Returns:
191194
A `sagemaker.model.ModelPackage` instance.
@@ -215,6 +218,7 @@ def register(
215218
description,
216219
drift_check_baselines=drift_check_baselines,
217220
customer_metadata_properties=customer_metadata_properties,
221+
domain=domain,
218222
)
219223

220224
def prepare_container_def(

src/sagemaker/session.py

+14
Original file line numberDiff line numberDiff line change
@@ -2803,6 +2803,7 @@ def create_model_package_from_containers(
28032803
drift_check_baselines=None,
28042804
customer_metadata_properties=None,
28052805
validation_specification=None,
2806+
domain=None,
28062807
):
28072808
"""Get request dictionary for CreateModelPackage API.
28082809
@@ -2830,6 +2831,8 @@ def create_model_package_from_containers(
28302831
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
28312832
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
28322833
metadata properties (default: None).
2834+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
2835+
"MACHINE_LEARNING" (default: None).
28332836
"""
28342837

28352838
model_pkg_request = get_create_model_package_request(
@@ -2848,6 +2851,7 @@ def create_model_package_from_containers(
28482851
drift_check_baselines=drift_check_baselines,
28492852
customer_metadata_properties=customer_metadata_properties,
28502853
validation_specification=validation_specification,
2854+
domain=domain,
28512855
)
28522856

28532857
def submit(request):
@@ -4218,6 +4222,7 @@ def get_model_package_args(
42184222
drift_check_baselines=None,
42194223
customer_metadata_properties=None,
42204224
validation_specification=None,
4225+
domain=None,
42214226
):
42224227
"""Get arguments for create_model_package method.
42234228
@@ -4248,6 +4253,8 @@ def get_model_package_args(
42484253
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
42494254
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
42504255
metadata properties (default: None).
4256+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
4257+
"MACHINE_LEARNING" (default: None).
42514258
Returns:
42524259
dict: A dictionary of method argument names and values.
42534260
"""
@@ -4289,6 +4296,8 @@ def get_model_package_args(
42894296
model_package_args["customer_metadata_properties"] = customer_metadata_properties
42904297
if validation_specification is not None:
42914298
model_package_args["validation_specification"] = validation_specification
4299+
if domain is not None:
4300+
model_package_args["domain"] = domain
42924301
return model_package_args
42934302

42944303

@@ -4309,6 +4318,7 @@ def get_create_model_package_request(
43094318
drift_check_baselines=None,
43104319
customer_metadata_properties=None,
43114320
validation_specification=None,
4321+
domain=None,
43124322
):
43134323
"""Get request dictionary for CreateModelPackage API.
43144324
@@ -4337,6 +4347,8 @@ def get_create_model_package_request(
43374347
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
43384348
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
43394349
metadata properties (default: None).
4350+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
4351+
"MACHINE_LEARNING" (default: None).
43404352
"""
43414353

43424354
if all([model_package_name, model_package_group_name]):
@@ -4362,6 +4374,8 @@ def get_create_model_package_request(
43624374
request_dict["CustomerMetadataProperties"] = customer_metadata_properties
43634375
if validation_specification:
43644376
request_dict["ValidationSpecification"] = validation_specification
4377+
if domain is not None:
4378+
request_dict["Domain"] = domain
43654379
if containers is not None:
43664380
if not all([content_types, response_types, inference_instances, transform_instances]):
43674381
raise ValueError(

src/sagemaker/sklearn/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def register(
151151
marketplace_cert=False,
152152
approval_status=None,
153153
description=None,
154+
domain=None,
154155
):
155156
"""Creates a model package for creating SageMaker models or listing on Marketplace.
156157
@@ -175,6 +176,8 @@ def register(
175176
approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
176177
or "PendingManualApproval" (default: "PendingManualApproval").
177178
description (str): Model Package description (default: None).
179+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
180+
"MACHINE_LEARNING" (default: None).
178181
179182
Returns:
180183
A `sagemaker.model.ModelPackage` instance.
@@ -202,6 +205,7 @@ def register(
202205
marketplace_cert,
203206
approval_status,
204207
description,
208+
domain=domain,
205209
)
206210

207211
def prepare_container_def(

src/sagemaker/tensorflow/model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def register(
205205
description=None,
206206
drift_check_baselines=None,
207207
customer_metadata_properties=None,
208+
domain=None,
208209
):
209210
"""Creates a model package for creating SageMaker models or listing on Marketplace.
210211
@@ -232,7 +233,8 @@ def register(
232233
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
233234
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
234235
metadata properties (default: None).
235-
236+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
237+
"MACHINE_LEARNING" (default: None).
236238
237239
Returns:
238240
A `sagemaker.model.ModelPackage` instance.
@@ -262,6 +264,7 @@ def register(
262264
description,
263265
drift_check_baselines=drift_check_baselines,
264266
customer_metadata_properties=customer_metadata_properties,
267+
domain=domain,
265268
)
266269

267270
def deploy(

src/sagemaker/workflow/_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def __init__(
284284
container_def_list=None,
285285
drift_check_baselines=None,
286286
customer_metadata_properties=None,
287+
domain=None,
287288
**kwargs,
288289
):
289290
"""Constructor of a register model step.
@@ -326,6 +327,8 @@ def __init__(
326327
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
327328
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
328329
metadata properties (default: None).
330+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
331+
"MACHINE_LEARNING" (default: None).
329332
**kwargs: additional arguments to `create_model`.
330333
"""
331334
super(_RegisterModelStep, self).__init__(
@@ -356,6 +359,7 @@ def __init__(
356359
self.model_metrics = model_metrics
357360
self.drift_check_baselines = drift_check_baselines
358361
self.customer_metadata_properties = customer_metadata_properties
362+
self.domain = domain
359363
self.metadata_properties = metadata_properties
360364
self.approval_status = approval_status
361365
self.image_uri = image_uri
@@ -433,6 +437,7 @@ def arguments(self) -> RequestType:
433437
tags=self.tags,
434438
container_def_list=self.container_def_list,
435439
customer_metadata_properties=self.customer_metadata_properties,
440+
domain=self.domain,
436441
)
437442

438443
request_dict = get_create_model_package_request(**model_package_args)

src/sagemaker/workflow/step_collections.py

+4
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
model: Union[Model, PipelineModel] = None,
8282
drift_check_baselines=None,
8383
customer_metadata_properties=None,
84+
domain=None,
8485
**kwargs,
8586
):
8687
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -122,6 +123,8 @@ def __init__(
122123
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
123124
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
124125
metadata properties (default: None).
126+
domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING",
127+
"MACHINE_LEARNING" (default: None).
125128
126129
**kwargs: additional arguments to `create_model`.
127130
"""
@@ -241,6 +244,7 @@ def __init__(
241244
container_def_list=self.container_def_list,
242245
retry_policies=register_model_step_retry_policies,
243246
customer_metadata_properties=customer_metadata_properties,
247+
domain=domain,
244248
**kwargs,
245249
)
246250
if not repack_model:

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

+3
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,7 @@ def test_model_registration_with_drift_check_baselines(
550550
),
551551
)
552552
customer_metadata_properties = {"key1": "value1"}
553+
domain = "COMPUTER_VISION"
553554
estimator = XGBoost(
554555
entry_point="training.py",
555556
source_dir=os.path.join(DATA_DIR, "sip"),
@@ -572,6 +573,7 @@ def test_model_registration_with_drift_check_baselines(
572573
model_metrics=model_metrics,
573574
drift_check_baselines=drift_check_baselines,
574575
customer_metadata_properties=customer_metadata_properties,
576+
domain=domain,
575577
)
576578

577579
pipeline = Pipeline(
@@ -643,6 +645,7 @@ def test_model_registration_with_drift_check_baselines(
643645
== "application/json"
644646
)
645647
assert response["CustomerMetadataProperties"] == customer_metadata_properties
648+
assert response["Domain"] == domain
646649
break
647650
finally:
648651
try:

tests/unit/test_session.py

+3
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
23862386
approval_status = ("Approved",)
23872387
description = "description"
23882388
customer_metadata_properties = {"key1": "value1"}
2389+
domain = "COMPUTER_VISION"
23892390
sagemaker_session.create_model_package_from_containers(
23902391
containers=containers,
23912392
content_types=content_types,
@@ -2400,6 +2401,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24002401
description=description,
24012402
drift_check_baselines=drift_check_baselines,
24022403
customer_metadata_properties=customer_metadata_properties,
2404+
domain=domain,
24032405
)
24042406
expected_args = {
24052407
"ModelPackageName": model_package_name,
@@ -2417,6 +2419,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session):
24172419
"ModelApprovalStatus": approval_status,
24182420
"DriftCheckBaselines": drift_check_baselines,
24192421
"CustomerMetadataProperties": customer_metadata_properties,
2422+
"Domain": domain,
24202423
}
24212424
sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)
24222425

0 commit comments

Comments
 (0)