Skip to content

Commit cbf4d46

Browse files
authored
change: add KMS key option for Endpoint Configs (#762)
1 parent 34d2849 commit cbf4d46

File tree

7 files changed

+68
-22
lines changed

7 files changed

+68
-22
lines changed

src/sagemaker/model.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def compile(self, target_instance_family, input_shape, output_path, role,
209209
return self
210210

211211
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
212-
update_endpoint=False, tags=None):
212+
update_endpoint=False, tags=None, kms_key=None):
213213
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
214214
215215
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -235,6 +235,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
235235
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
236236
corresponding to the previous EndpointConfig. If False, a new endpoint will be created. Default: False
237237
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
238+
kms_key (str): The ARN of the KMS key that is used to encrypt the data on the
239+
storage volume attached to the instance hosting the endpoint.
238240
239241
Returns:
240242
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
@@ -270,10 +272,12 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
270272
initial_instance_count=initial_instance_count,
271273
instance_type=instance_type,
272274
accelerator_type=accelerator_type,
273-
tags=tags)
275+
tags=tags,
276+
kms_key=kms_key)
274277
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
275278
else:
276-
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
279+
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant],
280+
tags, kms_key)
277281

278282
if self.predictor_cls:
279283
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)

src/sagemaker/session.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def wait_for_model_package(self, model_package_name, poll=5):
710710
return desc
711711

712712
def create_endpoint_config(self, name, model_name, initial_instance_count, instance_type,
713-
accelerator_type=None, tags=None):
713+
accelerator_type=None, tags=None, kms_key=None):
714714
"""Create an Amazon SageMaker endpoint configuration.
715715
716716
The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -738,12 +738,21 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
738738

739739
tags = tags or []
740740

741-
self.sagemaker_client.create_endpoint_config(
742-
EndpointConfigName=name,
743-
ProductionVariants=[production_variant(model_name, instance_type, initial_instance_count,
744-
accelerator_type=accelerator_type)],
745-
Tags=tags
746-
)
741+
request = {
742+
'EndpointConfigName': name,
743+
'ProductionVariants': [
744+
production_variant(model_name, instance_type, initial_instance_count,
745+
accelerator_type=accelerator_type)
746+
],
747+
}
748+
749+
if tags is not None:
750+
request['Tags'] = tags
751+
752+
if kms_key is not None:
753+
request['KmsKeyId'] = kms_key
754+
755+
self.sagemaker_client.create_endpoint_config(**request)
747756
return name
748757

749758
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
@@ -1032,13 +1041,15 @@ def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_
10321041
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
10331042
return name
10341043

1035-
def endpoint_from_production_variants(self, name, production_variants, tags=None, wait=True):
1044+
def endpoint_from_production_variants(self, name, production_variants, tags=None, kms_key=None, wait=True):
10361045
"""Create an SageMaker ``Endpoint`` from a list of production variants.
10371046
10381047
Args:
10391048
name (str): The name of the ``Endpoint`` to create.
10401049
production_variants (list[dict[str, str]]): The list of production variants to deploy.
10411050
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None).
1051+
kms_key (str): The KMS key that is used to encrypt the data on the storage volume attached
1052+
to the instance hosting the endpoint.
10421053
wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True).
10431054
10441055
Returns:
@@ -1050,6 +1061,8 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
10501061
config_options = {'EndpointConfigName': name, 'ProductionVariants': production_variants}
10511062
if tags:
10521063
config_options['Tags'] = tags
1064+
if kms_key:
1065+
config_options['KmsKeyId'] = kms_key
10531066

10541067
self.sagemaker_client.create_endpoint_config(**config_options)
10551068
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)

tests/integ/kms_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,16 @@ def _add_role_to_policy(kms_client,
101101
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal))
102102

103103

104-
def get_or_create_kms_key(kms_client,
105-
account_id,
104+
def get_or_create_kms_key(sagemaker_session,
106105
role_arn=None,
107106
alias=KEY_ALIAS,
108107
sagemaker_role='SageMakerRole'):
108+
kms_client = sagemaker_session.boto_session.client('kms')
109109
kms_key_arn = _get_kms_key_arn(kms_client, alias)
110110

111+
sts_client = sagemaker_session.boto_session.client('sts')
112+
account_id = sts_client.get_caller_identity()['Account']
113+
111114
if kms_key_arn is None:
112115
return _create_kms_key(kms_client, account_id, role_arn, sagemaker_role, alias)
113116

tests/integ/test_mxnet_train.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.mxnet.model import MXNetModel
2424
from sagemaker.utils import sagemaker_timestamp
2525
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
26+
from tests.integ.kms_utils import get_or_create_kms_key
2627
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2728

2829

@@ -78,7 +79,7 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
7879
assert 'Could not find model' in str(exception.value)
7980

8081

81-
def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
82+
def test_deploy_model_with_tags_and_kms(mxnet_training_job, sagemaker_session, mxnet_full_version):
8283
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
8384

8485
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
@@ -88,8 +89,11 @@ def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_ful
8889
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
8990
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
9091
framework_version=mxnet_full_version)
92+
9193
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
92-
model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)
94+
kms_key_arn = get_or_create_kms_key(sagemaker_session)
95+
96+
model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags, kms_key=kms_key_arn)
9397

9498
returned_model = sagemaker_session.describe_model(EndpointName=model.name)
9599
returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags']
@@ -107,6 +111,7 @@ def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_ful
107111
assert endpoint_tags == tags
108112
assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
109113
assert production_variants[0]['InitialInstanceCount'] == 1
114+
assert endpoint_config['KmsKeyId'] == kms_key_arn
110115

111116

112117
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):

tests/integ/test_transformer.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
5050
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
5151
key_prefix=transform_input_key_prefix)
5252

53-
sts_client = sagemaker_session.boto_session.client('sts')
54-
account_id = sts_client.get_caller_identity()['Account']
55-
kms_client = sagemaker_session.boto_session.client('kms')
56-
kms_key_arn = get_or_create_kms_key(kms_client, account_id)
53+
kms_key_arn = get_or_create_kms_key(sagemaker_session)
5754

5855
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
5956
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,

tests/unit/test_estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,8 @@ def test_fit_deploy_keep_tags(sagemaker_session):
902902
job_name = estimator._current_job_name
903903
sagemaker_session.endpoint_from_production_variants.assert_called_with(job_name,
904904
variant,
905-
tags)
905+
tags,
906+
None)
906907

907908
sagemaker_session.create_model.assert_called_with(
908909
ANY,

tests/unit/test_model.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def test_deploy(sagemaker_session, tmpdir):
165165
'InstanceType': INSTANCE_TYPE,
166166
'InitialInstanceCount': 1,
167167
'VariantName': 'AllTraffic'}],
168+
None,
168169
None)
169170

170171

@@ -180,6 +181,7 @@ def test_deploy_endpoint_name(sagemaker_session, tmpdir):
180181
'InstanceType': INSTANCE_TYPE,
181182
'InitialInstanceCount': 55,
182183
'VariantName': 'AllTraffic'}],
184+
None,
183185
None)
184186

185187

@@ -196,7 +198,8 @@ def test_deploy_tags(sagemaker_session, tmpdir):
196198
'InstanceType': INSTANCE_TYPE,
197199
'InitialInstanceCount': 1,
198200
'VariantName': 'AllTraffic'}],
199-
tags)
201+
tags,
202+
None)
200203

201204

202205
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -213,9 +216,28 @@ def test_deploy_accelerator_type(tfo, time, sagemaker_session):
213216
'InitialInstanceCount': 1,
214217
'VariantName': 'AllTraffic',
215218
'AcceleratorType': ACCELERATOR_TYPE}],
219+
None,
216220
None)
217221

218222

223+
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
224+
@patch('tarfile.open')
225+
@patch('time.strftime', return_value=TIMESTAMP)
226+
def test_deploy_kms_key(tfo, time, sagemaker_session):
227+
key = 'some-key-arn'
228+
model = DummyFrameworkModel(sagemaker_session)
229+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key)
230+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
231+
MODEL_NAME,
232+
[{'InitialVariantWeight': 1,
233+
'ModelName': MODEL_NAME,
234+
'InstanceType': INSTANCE_TYPE,
235+
'InitialInstanceCount': 1,
236+
'VariantName': 'AllTraffic'}],
237+
None,
238+
key)
239+
240+
219241
@patch('sagemaker.session.Session')
220242
@patch('sagemaker.local.LocalSession')
221243
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
@@ -246,7 +268,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
246268
initial_instance_count=INSTANCE_COUNT,
247269
instance_type=INSTANCE_TYPE,
248270
accelerator_type=ACCELERATOR_TYPE,
249-
tags=None
271+
tags=None,
272+
kms_key=None,
250273
)
251274
config_name = sagemaker_session.create_endpoint_config(
252275
name=model.name,

0 commit comments

Comments
 (0)