@@ -544,7 +544,8 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
544
544
self .sagemaker_client .create_transform_job (** transform_request )
545
545
546
546
def create_model (self , name , role , container_defs , vpc_config = None ,
547
- enable_network_isolation = False , primary_container = None ):
547
+ enable_network_isolation = False , primary_container = None ,
548
+ tags = None ):
548
549
"""Create an Amazon SageMaker ``Model``.
549
550
Specify the S3 location of the model artifacts and Docker image containing
550
551
the inference code. Amazon SageMaker uses this information to deploy the
@@ -570,6 +571,11 @@ def create_model(self, name, role, container_defs, vpc_config=None,
570
571
You can also specify the return value of ``sagemaker.container_def()``, which is used to create
571
572
more advanced container configurations, including model containers which need artifacts from S3. This
572
573
field is deprecated, please use container_defs instead.
574
+ tags(List[dict[str, str]]): Optional. The list of tags to add to the model. Example:
575
+ >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
576
+ For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
577
+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
578
+
573
579
574
580
Returns:
575
581
str: Name of the Amazon SageMaker ``Model`` created.
@@ -583,12 +589,16 @@ def create_model(self, name, role, container_defs, vpc_config=None,
583
589
container_defs = primary_container
584
590
585
591
role = self .expand_role (role )
586
- create_model_request = {}
592
+
587
593
if isinstance (container_defs , list ):
588
- create_model_request = _create_model_request ( name = name , role = role , container_def = container_defs )
594
+ container_definition = container_defs
589
595
else :
590
- primary_container = _expand_container_def (container_defs )
591
- create_model_request = _create_model_request (name = name , role = role , container_def = primary_container )
596
+ container_definition = _expand_container_def (container_defs )
597
+
598
+ create_model_request = _create_model_request (name = name ,
599
+ role = role ,
600
+ container_def = container_definition ,
601
+ tags = tags )
592
602
593
603
if vpc_config :
594
604
create_model_request ['VpcConfig' ] = vpc_config
@@ -702,7 +712,8 @@ def wait_for_model_package(self, model_package_name, poll=5):
702
712
model_package_name , status , reason ))
703
713
return desc
704
714
705
- def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type , accelerator_type = None ):
715
+ def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type ,
716
+ accelerator_type = None , tags = None ):
706
717
"""Create an Amazon SageMaker endpoint configuration.
707
718
708
719
The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -717,17 +728,24 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
717
728
instance_type (str): Type of EC2 instance to launch, for example, 'ml.c4.xlarge'.
718
729
accelerator_type (str): Type of Elastic Inference accelerator to attach to the instance. For example,
719
730
'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
731
+ tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint config. Example:
732
+ >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
733
+ For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
734
+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
720
735
721
736
722
737
Returns:
723
738
str: Name of the endpoint point configuration created.
724
739
"""
725
740
LOGGER .info ('Creating endpoint-config with name {}' .format (name ))
726
741
742
+ tags = tags or []
743
+
727
744
self .sagemaker_client .create_endpoint_config (
728
745
EndpointConfigName = name ,
729
746
ProductionVariants = [production_variant (model_name , instance_type , initial_instance_count ,
730
- accelerator_type = accelerator_type )]
747
+ accelerator_type = accelerator_type )],
748
+ Tags = tags
731
749
)
732
750
return name
733
751
@@ -1383,19 +1401,18 @@ def __init__(self, model_data, image, env=None):
1383
1401
self .env = env
1384
1402
1385
1403
1386
- def _create_model_request (name , role , container_def = None ): # pylint: disable=redefined-outer-name
1404
+ def _create_model_request (name , role , container_def = None , tags = None ): # pylint: disable=redefined-outer-name
1405
+ request = {'ModelName' : name , 'ExecutionRoleArn' : role }
1406
+
1387
1407
if isinstance (container_def , list ):
1388
- return {
1389
- 'ModelName' : name ,
1390
- 'Containers' : container_def ,
1391
- 'ExecutionRoleArn' : role
1392
- }
1408
+ request ['Containers' ] = container_def
1393
1409
else :
1394
- return {
1395
- 'ModelName' : name ,
1396
- 'PrimaryContainer' : container_def ,
1397
- 'ExecutionRoleArn' : role
1398
- }
1410
+ request ['PrimaryContainer' ] = container_def
1411
+
1412
+ if tags :
1413
+ request ['Tags' ] = tags
1414
+
1415
+ return request
1399
1416
1400
1417
1401
1418
def _deployment_entity_exists (describe_fn ):
0 commit comments