Skip to content

fix: fixing propagation of tags to SageMaker endpoint #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 17, 2019
3 changes: 2 additions & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
model_name=self.name,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
accelerator_type=accelerator_type)
accelerator_type=accelerator_type,
tags=tags)
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
else:
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
)
return name

def create_endpoint(self, endpoint_name, config_name, wait=True):
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
"""Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request.

Once the ``Endpoint`` is created, client applications can send requests to obtain inferences.
Expand All @@ -764,7 +764,7 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
str: Name of the Amazon SageMaker ``Endpoint`` created.
"""
LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
if wait:
self.wait_for_endpoint(endpoint_name)
return endpoint_name
Expand Down Expand Up @@ -1052,7 +1052,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
config_options['Tags'] = tags

self.sagemaker_client.create_endpoint_config(**config_options)
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

def expand_role(self, role):
"""Expand an IAM role name into an ARN.
Expand Down
25 changes: 25 additions & 0 deletions tests/integ/test_mxnet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
assert 'Could not find model' in str(exception.value)

def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
framework_version=mxnet_full_version)
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)

endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']

endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName'])
endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']

production_variants = endpoint_config['ProductionVariants']

assert endpoint_config_tags == tags
assert endpoint_tags == tags
assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
assert production_variants[0]['InitialInstanceCount'] == 1

def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=None)


def test_create_endpoint_wait(sagemaker_session):
Expand All @@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session):

assert returned_name == ENDPOINT_NAME
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=None)
sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME)
3 changes: 2 additions & 1 deletion tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
model_name=model.name,
initial_instance_count=INSTANCE_COUNT,
instance_type=INSTANCE_TYPE,
accelerator_type=ACCELERATOR_TYPE
accelerator_type=ACCELERATOR_TYPE,
tags=None
)
config_name = sagemaker_session.create_endpoint_config(
name=model.name,
Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session):
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=None)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs)
Expand All @@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand All @@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
tags = [{'ModelName': 'TestModel'}]
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
EndpointName='some-endpoint')
EndpointName='some-endpoint',
Tags=tags)
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
EndpointConfigName='some-endpoint',
ProductionVariants=pvs,
Expand Down