Skip to content

Commit 6a42335

Browse files
authored
Propagate Tags from estimator to model, endpoint, and endpoint config (#699)
* Propagate Tags from estimator to model, endpoint, and endpoint config
1 parent f1a34c2 commit 6a42335

File tree

6 files changed

+106
-24
lines changed

6 files changed

+106
-24
lines changed

src/sagemaker/estimator.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
350350
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
351351
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
352352
corresponding to the previous EndpointConfig. Default: False
353+
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific endpoint. Example:
354+
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
355+
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
356+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
357+
353358
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
354359
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
355360
For more, see the implementation docs.
@@ -374,7 +379,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
374379
initial_instance_count=initial_instance_count,
375380
accelerator_type=accelerator_type,
376381
endpoint_name=endpoint_name,
377-
update_endpoint=update_endpoint)
382+
update_endpoint=update_endpoint,
383+
tags=self.tags)
378384

379385
@property
380386
def model_data(self):

src/sagemaker/model.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def enable_network_isolation(self):
9696
"""
9797
return False
9898

99-
def _create_sagemaker_model(self, instance_type, accelerator_type=None):
99+
def _create_sagemaker_model(self, instance_type, accelerator_type=None, tags=None):
100100
"""Create a SageMaker Model Entity
101101
102102
Args:
@@ -105,6 +105,11 @@ def _create_sagemaker_model(self, instance_type, accelerator_type=None):
105105
accelerator_type (str): Type of Elastic Inference accelerator to attach to an endpoint for model loading
106106
and inference, for example, 'ml.eia1.medium'. If not specified, no Elastic Inference accelerator
107107
will be attached to the endpoint.
108+
tags(List[dict[str, str]]): Optional. The list of tags to add to the model. Example:
109+
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
110+
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
111+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
112+
108113
"""
109114
container_def = self.prepare_container_def(instance_type, accelerator_type=accelerator_type)
110115
self.name = self.name or utils.name_from_image(container_def['Image'])

src/sagemaker/session.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,8 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
544544
self.sagemaker_client.create_transform_job(**transform_request)
545545

546546
def create_model(self, name, role, container_defs, vpc_config=None,
547-
enable_network_isolation=False, primary_container=None):
547+
enable_network_isolation=False, primary_container=None,
548+
tags=None):
548549
"""Create an Amazon SageMaker ``Model``.
549550
Specify the S3 location of the model artifacts and Docker image containing
550551
the inference code. Amazon SageMaker uses this information to deploy the
@@ -570,6 +571,11 @@ def create_model(self, name, role, container_defs, vpc_config=None,
570571
You can also specify the return value of ``sagemaker.container_def()``, which is used to create
571572
more advanced container configurations, including model containers which need artifacts from S3. This
572573
field is deprecated, please use container_defs instead.
574+
tags(List[dict[str, str]]): Optional. The list of tags to add to the model. Example:
575+
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
576+
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
577+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
578+
573579
574580
Returns:
575581
str: Name of the Amazon SageMaker ``Model`` created.
@@ -583,12 +589,16 @@ def create_model(self, name, role, container_defs, vpc_config=None,
583589
container_defs = primary_container
584590

585591
role = self.expand_role(role)
586-
create_model_request = {}
592+
587593
if isinstance(container_defs, list):
588-
create_model_request = _create_model_request(name=name, role=role, container_def=container_defs)
594+
container_definition = container_defs
589595
else:
590-
primary_container = _expand_container_def(container_defs)
591-
create_model_request = _create_model_request(name=name, role=role, container_def=primary_container)
596+
container_definition = _expand_container_def(container_defs)
597+
598+
create_model_request = _create_model_request(name=name,
599+
role=role,
600+
container_def=container_definition,
601+
tags=tags)
592602

593603
if vpc_config:
594604
create_model_request['VpcConfig'] = vpc_config
@@ -702,7 +712,8 @@ def wait_for_model_package(self, model_package_name, poll=5):
702712
model_package_name, status, reason))
703713
return desc
704714

705-
def create_endpoint_config(self, name, model_name, initial_instance_count, instance_type, accelerator_type=None):
715+
def create_endpoint_config(self, name, model_name, initial_instance_count, instance_type,
716+
accelerator_type=None, tags=None):
706717
"""Create an Amazon SageMaker endpoint configuration.
707718
708719
The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -717,17 +728,24 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
717728
instance_type (str): Type of EC2 instance to launch, for example, 'ml.c4.xlarge'.
718729
accelerator_type (str): Type of Elastic Inference accelerator to attach to the instance. For example,
719730
'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
731+
tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint config. Example:
732+
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
733+
For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
734+
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
720735
721736
722737
Returns:
723738
str: Name of the endpoint point configuration created.
724739
"""
725740
LOGGER.info('Creating endpoint-config with name {}'.format(name))
726741

742+
tags = tags or []
743+
727744
self.sagemaker_client.create_endpoint_config(
728745
EndpointConfigName=name,
729746
ProductionVariants=[production_variant(model_name, instance_type, initial_instance_count,
730-
accelerator_type=accelerator_type)]
747+
accelerator_type=accelerator_type)],
748+
Tags=tags
731749
)
732750
return name
733751

@@ -1383,19 +1401,18 @@ def __init__(self, model_data, image, env=None):
13831401
self.env = env
13841402

13851403

1386-
def _create_model_request(name, role, container_def=None): # pylint: disable=redefined-outer-name
1404+
def _create_model_request(name, role, container_def=None, tags=None): # pylint: disable=redefined-outer-name
1405+
request = {'ModelName': name, 'ExecutionRoleArn': role}
1406+
13871407
if isinstance(container_def, list):
1388-
return {
1389-
'ModelName': name,
1390-
'Containers': container_def,
1391-
'ExecutionRoleArn': role
1392-
}
1408+
request['Containers'] = container_def
13931409
else:
1394-
return {
1395-
'ModelName': name,
1396-
'PrimaryContainer': container_def,
1397-
'ExecutionRoleArn': role
1398-
}
1410+
request['PrimaryContainer'] = container_def
1411+
1412+
if tags:
1413+
request['Tags'] = tags
1414+
1415+
return request
13991416

14001417

14011418
def _deployment_entity_exists(describe_fn):

tests/unit/test_create_deploy_entities.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_create_endpoint_config(sagemaker_session):
7070
'InitialVariantWeight': 1,
7171
'VariantName': 'AllTraffic'}]
7272
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_once_with(
73-
EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs)
73+
EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[])
7474

7575

7676
def test_create_endpoint_config_with_accelerator(sagemaker_session):
@@ -87,7 +87,7 @@ def test_create_endpoint_config_with_accelerator(sagemaker_session):
8787
'VariantName': 'AllTraffic',
8888
'AcceleratorType': ACCELERATOR_TYPE}]
8989
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_once_with(
90-
EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs)
90+
EndpointConfigName=ENDPOINT_CONFIG_NAME, ProductionVariants=expected_pvs, Tags=[])
9191

9292

9393
def test_create_endpoint_no_wait(sagemaker_session):

tests/unit/test_estimator.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from time import sleep
1919

2020
import pytest
21-
from mock import MagicMock, Mock, patch
21+
from mock import ANY, MagicMock, Mock, patch
2222

2323
from sagemaker.amazon.amazon_estimator import registry
2424
from sagemaker.algorithm import AlgorithmEstimator
@@ -882,6 +882,36 @@ def test_unsupported_type_in_dict():
882882
HP_TRAIN_CALL.update({'hyperparameters': STRINGIFIED_HYPERPARAMS})
883883

884884

885+
def test_fit_deploy_keep_tags(sagemaker_session):
886+
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
887+
estimator = Estimator(IMAGE_NAME,
888+
ROLE,
889+
INSTANCE_COUNT,
890+
INSTANCE_TYPE,
891+
tags=tags,
892+
sagemaker_session=sagemaker_session)
893+
894+
estimator.fit()
895+
896+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
897+
898+
variant = [{'InstanceType': 'c4.4xlarge', 'VariantName': 'AllTraffic',
899+
'ModelName': ANY, 'InitialVariantWeight': 1,
900+
'InitialInstanceCount': 1}]
901+
902+
job_name = estimator._current_job_name
903+
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
904+
variant,
905+
tags)
906+
907+
sagemaker_session.create_model.assert_called_with(
908+
ANY,
909+
'DummyRole',
910+
{'ModelDataUrl': 's3://bucket/model.tar.gz', 'Environment': {}, 'Image': 'fakeimage'},
911+
enable_network_isolation=False,
912+
vpc_config=None)
913+
914+
885915
def test_generic_to_fit_no_input(sagemaker_session):
886916
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
887917
sagemaker_session=sagemaker_session)

tests/unit/test_session.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020
import six
2121
from botocore.exceptions import ClientError
22-
from mock import MagicMock, Mock, patch, call, mock_open
22+
from mock import ANY, MagicMock, Mock, patch, call, mock_open
2323

2424
import sagemaker
2525
from sagemaker import s3_input, Session, get_execution_role
@@ -758,6 +758,19 @@ def test_create_model(expand_container_def, sagemaker_session):
758758
PrimaryContainer=PRIMARY_CONTAINER)
759759

760760

761+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
762+
def test_create_model_with_tags(expand_container_def, sagemaker_session):
763+
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
764+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER, tags=tags)
765+
766+
assert model == MODEL_NAME
767+
tags = [{'Value': 'TagtestValue', 'Key': 'TagtestKey'}]
768+
sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
769+
ModelName=MODEL_NAME,
770+
PrimaryContainer=PRIMARY_CONTAINER,
771+
Tags=tags)
772+
773+
761774
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
762775
def test_create_model_with_primary_container(expand_container_def, sagemaker_session):
763776
model = sagemaker_session.create_model(MODEL_NAME, ROLE, container_defs=PRIMARY_CONTAINER)
@@ -903,6 +916,17 @@ def test_endpoint_from_production_variants(sagemaker_session):
903916
ProductionVariants=pvs)
904917

905918

919+
def test_create_endpoint_config_with_tags(sagemaker_session):
920+
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
921+
922+
sagemaker_session.create_endpoint_config('endpoint-test', 'simple-model', 1, 'local', tags=tags)
923+
924+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
925+
EndpointConfigName='endpoint-test',
926+
ProductionVariants=ANY,
927+
Tags=tags)
928+
929+
906930
def test_endpoint_from_production_variants_with_tags(sagemaker_session):
907931
ims = sagemaker_session
908932
ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})

0 commit comments

Comments
 (0)