Skip to content

fix: fixing propagation of tags to SageMaker endpoint #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 17, 2019
8 changes: 6 additions & 2 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,17 @@ def describe(self):

class _LocalEndpointConfig(object):

def __init__(self, config_name, production_variants):
def __init__(self, config_name, production_variants, tags=None):
self.name = config_name
self.production_variants = production_variants
self.tags = tags
self.creation_time = datetime.datetime.now()

def describe(self):
response = {
'EndpointConfigName': self.name,
'EndpointConfigArn': _UNUSED_ARN,
'Tags': self.tags,
'CreationTime': self.creation_time,
'ProductionVariants': self.production_variants
}
Expand All @@ -348,7 +350,7 @@ class _LocalEndpoint(object):
_IN_SERVICE = 'InService'
_FAILED = 'Failed'

def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None):
# runtime import since there is a cyclic dependency between entities and local_session
from sagemaker.local import LocalSession
self.local_session = local_session or LocalSession()
Expand All @@ -357,6 +359,7 @@ def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
self.name = endpoint_name
self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
self.production_variant = self.endpoint_config['ProductionVariants'][0]
self.tags = tags

model_name = self.production_variant['ModelName']
self.primary_container = local_client.describe_model(model_name)['PrimaryContainer']
Expand Down Expand Up @@ -392,6 +395,7 @@ def describe(self):
'EndpointConfigName': self.endpoint_config['EndpointConfigName'],
'CreationTime': self.create_time,
'ProductionVariants': self.endpoint_config['ProductionVariants'],
'Tags': self.tags,
'EndpointName': self.name,
'EndpointArn': _UNUSED_ARN,
'EndpointStatus': self.state
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def describe_endpoint_config(self, EndpointConfigName):
'Code': 'ValidationException', 'Message': 'Could not find local endpoint config'}}
raise ClientError(error_response, 'describe_endpoint_config')

def create_endpoint_config(self, EndpointConfigName, ProductionVariants):
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
EndpointConfigName, ProductionVariants)
EndpointConfigName, ProductionVariants, Tags)

def describe_endpoint(self, EndpointName):
if EndpointName not in LocalSagemakerClient._endpoints:
Expand All @@ -138,8 +138,8 @@ def describe_endpoint(self, EndpointName):
else:
return LocalSagemakerClient._endpoints[EndpointName].describe()

def create_endpoint(self, EndpointName, EndpointConfigName):
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, self.sagemaker_session)
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session)
LocalSagemakerClient._endpoints[EndpointName] = endpoint
endpoint.serve()

Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type)
accelerator_type=accelerator_type,
tags=tags)
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
)
return name

def create_endpoint(self, endpoint_name, config_name, wait=True):
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
"""Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request.

Once the ``Endpoint`` is created, client applications can send requests to obtain inferences.
Expand All @@ -764,7 +764,10 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
str: Name of the Amazon SageMaker ``Endpoint`` created.
"""
LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)

tags = tags or []

self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
if wait:
self.wait_for_endpoint(endpoint_name)
return endpoint_name
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
config_options['Tags'] = tags

self.sagemaker_client.create_endpoint_config(**config_options)
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

def expand_role(self, role):
"""Expand an IAM role name into an ARN.
Expand Down
31 changes: 31 additions & 0 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,37 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
assert 'Could not find model' in str(exception.value)


def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version)
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)

returned_model = sagemaker_session.describe_model(EndpointName=model.name)
returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags']

endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']

endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName'])
endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']

production_variants = endpoint_config['ProductionVariants']

assert returned_model_tags == tags
assert endpoint_config_tags == tags
assert endpoint_tags == tags
assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
assert production_variants[0]['InitialInstanceCount'] == 1


def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])


def test_create_endpoint_wait(sagemaker_session):
Expand All @@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])
sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME)
3 changes: 2 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
accelerator_type=ACCELERATOR_TYPE
accelerator_type=ACCELERATOR_TYPE,
tags=None
)
config_name = sagemaker_session.create_endpoint_config(
name=model.name,
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session):
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=[])
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs)
Expand All @@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand All @@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand Down