diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 91034ca5c9..596e4fba5d 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -327,15 +327,17 @@ def describe(self): class _LocalEndpointConfig(object): - def __init__(self, config_name, production_variants): + def __init__(self, config_name, production_variants, tags=None): self.name = config_name self.production_variants = production_variants + self.tags = tags self.creation_time = datetime.datetime.now() def describe(self): response = { 'EndpointConfigName': self.name, 'EndpointConfigArn': _UNUSED_ARN, + 'Tags': self.tags, 'CreationTime': self.creation_time, 'ProductionVariants': self.production_variants } @@ -348,7 +350,7 @@ class _LocalEndpoint(object): _IN_SERVICE = 'InService' _FAILED = 'Failed' - def __init__(self, endpoint_name, endpoint_config_name, local_session=None): + def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None): # runtime import since there is a cyclic dependency between entities and local_session from sagemaker.local import LocalSession self.local_session = local_session or LocalSession() @@ -357,6 +359,7 @@ def __init__(self, endpoint_name, endpoint_config_name, local_session=None): self.name = endpoint_name self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name) self.production_variant = self.endpoint_config['ProductionVariants'][0] + self.tags = tags model_name = self.production_variant['ModelName'] self.primary_container = local_client.describe_model(model_name)['PrimaryContainer'] @@ -392,6 +395,7 @@ def describe(self): 'EndpointConfigName': self.endpoint_config['EndpointConfigName'], 'CreationTime': self.create_time, 'ProductionVariants': self.endpoint_config['ProductionVariants'], + 'Tags': self.tags, 'EndpointName': self.name, 'EndpointArn': _UNUSED_ARN, 'EndpointStatus': self.state diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 47d62cb1fb..22da783c72 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -127,9 +127,9 @@ def describe_endpoint_config(self, EndpointConfigName): 'Code': 'ValidationException', 'Message': 'Could not find local endpoint config'}} raise ClientError(error_response, 'describe_endpoint_config') - def create_endpoint_config(self, EndpointConfigName, ProductionVariants): + def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None): LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig( - EndpointConfigName, ProductionVariants) + EndpointConfigName, ProductionVariants, Tags) def describe_endpoint(self, EndpointName): if EndpointName not in LocalSagemakerClient._endpoints: @@ -138,8 +138,8 @@ def describe_endpoint(self, EndpointName): else: return LocalSagemakerClient._endpoints[EndpointName].describe() - def create_endpoint(self, EndpointName, EndpointConfigName): - endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, self.sagemaker_session) + def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): + endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session) LocalSagemakerClient._endpoints[EndpointName] = endpoint endpoint.serve() diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index eae96e0155..3896145375 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e model_name=self.name, initial_instance_count=initial_instance_count, instance_type=instance_type, - accelerator_type=accelerator_type) + accelerator_type=accelerator_type, + tags=tags) self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name) else: self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index acc5173cc8..d98d93a5aa 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta ) return name - def create_endpoint(self, endpoint_name, config_name, wait=True): + def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): """Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request. Once the ``Endpoint`` is created, client applications can send requests to obtain inferences. @@ -764,7 +764,10 @@ def create_endpoint(self, endpoint_name, config_name, wait=True): str: Name of the Amazon SageMaker ``Endpoint`` created. """ LOGGER.info('Creating endpoint with name {}'.format(endpoint_name)) - self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name) + + tags = tags or [] + + self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags) if wait: self.wait_for_endpoint(endpoint_name) return endpoint_name @@ -1052,7 +1055,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None config_options['Tags'] = tags self.sagemaker_client.create_endpoint_config(**config_options) - return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait) + return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait) def expand_role(self, role): """Expand an IAM role name into an ARN. diff --git a/tests/integ/test_mxnet_train.py b/tests/integ/test_mxnet_train.py index 40178a730f..5803cd3ba3 100644 --- a/tests/integ/test_mxnet_train.py +++ b/tests/integ/test_mxnet_train.py @@ -78,6 +78,37 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version) assert 'Could not find model' in str(exception.value) +def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version): + endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp()) + + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job) + model_data = desc['ModelArtifacts']['S3ModelArtifacts'] + script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py') + model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path, + py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session, + framework_version=mxnet_full_version) + tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}] + model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags) + + returned_model = sagemaker_session.describe_model(EndpointName=model.name) + returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags'] + + endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name) + endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags'] + + endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName']) + endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags'] + + production_variants = endpoint_config['ProductionVariants'] + + assert returned_model_tags == tags + assert endpoint_config_tags == tags + assert endpoint_tags == tags + assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge' + assert production_variants[0]['InitialInstanceCount'] == 1 + + def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version): endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp()) diff --git a/tests/unit/test_create_deploy_entities.py b/tests/unit/test_create_deploy_entities.py index 1b7e170317..cad572ffe2 100644 --- a/tests/unit/test_create_deploy_entities.py +++ b/tests/unit/test_create_deploy_entities.py @@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session): assert returned_name == ENDPOINT_NAME sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with( - EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME) + EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[]) def test_create_endpoint_wait(sagemaker_session): @@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session): assert returned_name == ENDPOINT_NAME sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with( - EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME) + EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[]) sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index c7474c5889..818a08063c 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir): model_name=model.name, initial_instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE, - accelerator_type=ACCELERATOR_TYPE + accelerator_type=ACCELERATOR_TYPE, + tags=None ) config_name = sagemaker_session.create_endpoint_config( name=model.name, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 3a318e63a2..4f34bde068 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session): ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs) sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint', - EndpointName='some-endpoint') + EndpointName='some-endpoint', + Tags=[]) sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName='some-endpoint', ProductionVariants=pvs) @@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session): tags = [{'ModelName': 'TestModel'}] sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags) sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint', - EndpointName='some-endpoint') + EndpointName='some-endpoint', + Tags=tags) sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName='some-endpoint', ProductionVariants=pvs, @@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi tags = [{'ModelName': 'TestModel'}] sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags) sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint', - EndpointName='some-endpoint') + EndpointName='some-endpoint', + Tags=tags) sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with( EndpointConfigName='some-endpoint', ProductionVariants=pvs,