Skip to content

Commit 400f25e

Browse files
apacker authored and Piali Das committed
Add volume KMS key to transformer (aws#368)
* Add volume KMS key to transformer. This adds support for the optional VolumeKmsKeyId parameter of CreateTransformJob.
* Fix changelog and add integ test. This also adds volume KMS key to Estimator.transformer().
* Fixing flake errors
* Fix resource config preparation for volume kms key
1 parent 230f3f6 commit 400f25e

File tree

7 files changed

+137
-16
lines changed

7 files changed

+137
-16
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ CHANGELOG
2828
1.9.3
2929
=====
3030

31+
* enhancement: Add support for volume KMS key to Transformer
3132
* bug-fix: Local Mode: Create output/data directory expected by SageMaker Container.
3233
* bug-fix: Estimator accepts the vpc configs made capable by 1.9.1
3334

src/sagemaker/estimator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def delete_endpoint(self):
331331

332332
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
333333
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
334-
max_payload=None, tags=None, role=None):
334+
max_payload=None, tags=None, role=None, volume_kms_key=None):
335335
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
336336
SageMaker Session and base job name used by the Estimator.
337337
@@ -353,6 +353,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
353353
the training job are used for the transform job.
354354
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
355355
transform jobs. If not specified, the role from the Estimator will be used.
356+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
357+
compute instance (default: None).
356358
"""
357359
self._ensure_latest_training_job()
358360

@@ -363,7 +365,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
363365
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
364366
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
365367
env=env, tags=tags, base_transform_job_name=self.base_job_name,
366-
sagemaker_session=self.sagemaker_session)
368+
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
367369

368370
@property
369371
def training_job_analytics(self):
@@ -767,7 +769,7 @@ def _update_init_params(cls, hp, tf_arguments):
767769

768770
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
769771
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
770-
max_payload=None, tags=None, role=None, model_server_workers=None):
772+
max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None):
771773
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
772774
SageMaker Session and base job name used by the Estimator.
773775
@@ -791,6 +793,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
791793
transform jobs. If not specified, the role from the Estimator will be used.
792794
model_server_workers (int): Optional. The number of worker processes used by the inference server.
793795
If None, server will use one worker per vCPU.
796+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
797+
compute instance (default: None).
794798
"""
795799
self._ensure_latest_training_job()
796800
role = role or self.role
@@ -810,7 +814,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
810814
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
811815
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
812816
env=transform_env, tags=tags, base_transform_job_name=self.base_job_name,
813-
sagemaker_session=self.sagemaker_session)
817+
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
814818

815819

816820
def _s3_uri_prefix(channel_name, s3_data):

src/sagemaker/transformer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Transformer(object):
2323

2424
def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
2525
output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None,
26-
env=None, base_transform_job_name=None, sagemaker_session=None):
26+
env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None):
2727
"""Initialize a ``Transformer``.
2828
2929
Args:
@@ -50,6 +50,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
5050
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
5151
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
5252
using the default AWS configuration chain.
53+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
54+
compute instance (default: None).
5355
"""
5456
self.model_name = model_name
5557
self.strategy = strategy
@@ -62,6 +64,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
6264

6365
self.instance_count = instance_count
6466
self.instance_type = instance_type
67+
self.volume_kms_key = volume_kms_key
6568

6669
self.max_concurrent_transforms = max_concurrent_transforms
6770
self.max_payload = max_payload
@@ -159,6 +162,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
159162
init_params['model_name'] = job_details['ModelName']
160163
init_params['instance_count'] = job_details['TransformResources']['InstanceCount']
161164
init_params['instance_type'] = job_details['TransformResources']['InstanceType']
165+
init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId')
162166
init_params['strategy'] = job_details.get('BatchStrategy')
163167
init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith')
164168
init_params['output_path'] = job_details['TransformOutput']['S3OutputPath']
@@ -200,7 +204,8 @@ def _load_config(data, data_type, content_type, compression_type, split_type, tr
200204
output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key,
201205
transformer.assemble_with, transformer.accept)
202206

203-
resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type)
207+
resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type,
208+
transformer.volume_kms_key)
204209

205210
return {'input_config': input_config,
206211
'output_config': output_config,
@@ -241,5 +246,10 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept):
241246
return config
242247

243248
@staticmethod
244-
def _prepare_resource_config(instance_count, instance_type):
245-
return {'InstanceCount': instance_count, 'InstanceType': instance_type}
249+
def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
250+
config = {'InstanceCount': instance_count, 'InstanceType': instance_type}
251+
252+
if volume_kms_key is not None:
253+
config['VolumeKmsKeyId'] = volume_kms_key
254+
255+
return config

tests/integ/kms_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
KEY_ALIAS = "SageMakerKmsKey"
16+
KEY_POLICY = '''
17+
{{
18+
"Version": "2012-10-17",
19+
"Id": "sagemaker-kms-integ-test-policy",
20+
"Statement": [
21+
{{
22+
"Sid": "Enable IAM User Permissions",
23+
"Effect": "Allow",
24+
"Principal": {{
25+
"AWS": "arn:aws:iam::{account_id}:root"
26+
}},
27+
"Action": "kms:*",
28+
"Resource": "*"
29+
}},
30+
{{
31+
"Sid": "Allow use of the key",
32+
"Effect": "Allow",
33+
"Principal": {{
34+
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
35+
}},
36+
"Action": [
37+
"kms:Encrypt",
38+
"kms:Decrypt",
39+
"kms:ReEncrypt*",
40+
"kms:GenerateDataKey*",
41+
"kms:DescribeKey"
42+
],
43+
"Resource": "*"
44+
}},
45+
{{
46+
"Sid": "Allow attachment of persistent resources",
47+
"Effect": "Allow",
48+
"Principal": {{
49+
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
50+
}},
51+
"Action": [
52+
"kms:CreateGrant",
53+
"kms:ListGrants",
54+
"kms:RevokeGrant"
55+
],
56+
"Resource": "*",
57+
"Condition": {{
58+
"Bool": {{
59+
"kms:GrantIsForAWSResource": "true"
60+
}}
61+
}}
62+
}}
63+
]
64+
}}
65+
'''
66+
67+
68+
def _get_kms_key_arn(kms_client, alias):
69+
try:
70+
response = kms_client.describe_key(KeyId='alias/' + alias)
71+
return response['KeyMetadata']['Arn']
72+
except kms_client.exceptions.NotFoundException:
73+
return None
74+
75+
76+
def _create_kms_key(kms_client, account_id):
77+
response = kms_client.create_key(
78+
Policy=KEY_POLICY.format(account_id=account_id),
79+
Description='KMS key for SageMaker Python SDK integ tests',
80+
)
81+
key_arn = response['KeyMetadata']['Arn']
82+
response = kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn)
83+
return key_arn
84+
85+
86+
def get_or_create_kms_key(kms_client, account_id):
87+
kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS)
88+
if kms_key_arn is not None:
89+
return kms_key_arn
90+
else:
91+
return _create_kms_key(kms_client, account_id)

tests/integ/test_transformer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.mxnet import MXNet
2424
from sagemaker.transformer import Transformer
2525
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
26+
from tests.integ.kms_utils import get_or_create_kms_key
2627
from tests.integ.timeout import timeout
2728

2829

@@ -47,8 +48,16 @@ def test_transform_mxnet(sagemaker_session):
4748
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
4849
key_prefix=transform_input_key_prefix)
4950

50-
transformer = _create_transformer_and_transform_job(mx, transform_input)
51+
sts_client = sagemaker_session.boto_session.client('sts')
52+
account_id = sts_client.get_caller_identity()['Account']
53+
kms_client = sagemaker_session.boto_session.client('kms')
54+
kms_key_arn = get_or_create_kms_key(kms_client, account_id)
55+
56+
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
5157
transformer.wait()
58+
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
59+
TransformJobName=transformer.latest_transform_job.name)
60+
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']
5261

5362

5463
@pytest.mark.continuous_testing
@@ -90,7 +99,7 @@ def test_attach_transform_kmeans(sagemaker_session):
9099
attached_transformer.wait()
91100

92101

93-
def _create_transformer_and_transform_job(estimator, transform_input):
94-
transformer = estimator.transformer(1, 'ml.m4.xlarge')
102+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
103+
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
95104
transformer.transform(transform_input, content_type='text/csv')
96105
return transformer

tests/unit/test_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
494494
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
495495
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
496496
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
497-
env=env, role=new_role, model_server_workers=1)
497+
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)
498498

499499
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
500500
assert transformer.strategy == strategy
@@ -507,6 +507,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
507507
assert transformer.env == env
508508
assert transformer.base_transform_job_name == base_name
509509
assert transformer.tags == TAGS
510+
assert transformer.volume_kms_key == kms_key
510511

511512

512513
def test_ensure_latest_training_job(sagemaker_session):

tests/unit/test_transformer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
INSTANCE_COUNT = 1
2525
INSTANCE_TYPE = 'ml.m4.xlarge'
26+
KMS_KEY_ID = 'kms-key-id'
2627

2728
S3_DATA_TYPE = 'S3Prefix'
2829
S3_BUCKET = 'bucket'
@@ -48,7 +49,8 @@ def sagemaker_session():
4849
@pytest.fixture()
4950
def transformer(sagemaker_session):
5051
return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE,
51-
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session)
52+
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session,
53+
volume_kms_key=KMS_KEY_ID)
5254

5355

5456
@patch('sagemaker.transformer._TransformJob.start_new')
@@ -178,7 +180,8 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
178180
'ModelName': MODEL_NAME,
179181
'TransformResources': {
180182
'InstanceCount': INSTANCE_COUNT,
181-
'InstanceType': INSTANCE_TYPE
183+
'InstanceType': INSTANCE_TYPE,
184+
'VolumeKmsKeyId': KMS_KEY_ID
182185
},
183186
'BatchStrategy': None,
184187
'TransformOutput': {
@@ -197,6 +200,7 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
197200
assert init_params['model_name'] == MODEL_NAME
198201
assert init_params['instance_count'] == INSTANCE_COUNT
199202
assert init_params['instance_type'] == INSTANCE_TYPE
203+
assert init_params['volume_kms_key'] == KMS_KEY_ID
200204

201205

202206
# _TransformJob tests
@@ -227,6 +231,7 @@ def test_load_config(transformer):
227231
'resource_config': {
228232
'InstanceCount': INSTANCE_COUNT,
229233
'InstanceType': INSTANCE_TYPE,
234+
'VolumeKmsKeyId': KMS_KEY_ID,
230235
},
231236
}
232237

@@ -292,8 +297,8 @@ def test_prepare_output_config_with_optional_params():
292297

293298

294299
def test_prepare_resource_config():
295-
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE)
296-
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE}
300+
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID)
301+
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}
297302

298303

299304
def test_transform_job_wait(sagemaker_session):

0 commit comments

Comments
 (0)