Skip to content

Commit 04cb783

Browse files
author
Andrew Packer
committed
Fix changelog and add integ test
This also adds volume KMS key to Estimator.transformer()
1 parent d477a33 commit 04cb783

File tree

5 files changed

+112
-10
lines changed

5 files changed

+112
-10
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ CHANGELOG
66
========
77

88
* bug-fix: Local Mode: Create output/data directory expected by SageMaker Container.
9+
* enhancement: Add support for volume KMS key to Transformer
910

1011
1.9.2
1112
=====
1213

1314
* feature: add support for TensorFlow 1.9
14-
* enhancement: Add support for volume KMS key to Transformer
1515

1616
1.9.1
1717
=====

src/sagemaker/estimator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def delete_endpoint(self):
328328

329329
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
330330
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
331-
max_payload=None, tags=None, role=None):
331+
max_payload=None, tags=None, role=None, volume_kms_key=None):
332332
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
333333
SageMaker Session and base job name used by the Estimator.
334334
@@ -350,6 +350,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
350350
the training job are used for the transform job.
351351
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
352352
transform jobs. If not specified, the role from the Estimator will be used.
353+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
354+
compute instance (default: None).
353355
"""
354356
self._ensure_latest_training_job()
355357

@@ -360,7 +362,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
360362
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
361363
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
362364
env=env, tags=tags, base_transform_job_name=self.base_job_name,
363-
sagemaker_session=self.sagemaker_session)
365+
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
364366

365367
@property
366368
def training_job_analytics(self):
@@ -756,7 +758,7 @@ def _update_init_params(cls, hp, tf_arguments):
756758

757759
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
758760
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
759-
max_payload=None, tags=None, role=None, model_server_workers=None):
761+
max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None):
760762
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
761763
SageMaker Session and base job name used by the Estimator.
762764
@@ -780,6 +782,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
780782
transform jobs. If not specified, the role from the Estimator will be used.
781783
model_server_workers (int): Optional. The number of worker processes used by the inference server.
782784
If None, server will use one worker per vCPU.
785+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
786+
compute instance (default: None).
783787
"""
784788
self._ensure_latest_training_job()
785789
role = role or self.role
@@ -799,7 +803,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
799803
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
800804
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
801805
env=transform_env, tags=tags, base_transform_job_name=self.base_job_name,
802-
sagemaker_session=self.sagemaker_session)
806+
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
803807

804808

805809
def _s3_uri_prefix(channel_name, s3_data):

tests/integ/kms_utils.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
# Alias (without the ``alias/`` prefix) under which the integ-test key is registered.
KEY_ALIAS = "SageMakerKmsKey"

# Key policy template, rendered with ``KEY_POLICY.format(account_id=...)``.
# Literal JSON braces are doubled (``{{`` / ``}}``) so that ``str.format`` only
# substitutes the ``{account_id}`` placeholder.
KEY_POLICY = '''
{{
  "Version": "2012-10-17",
  "Id": "sagemaker-kms-integ-test-policy",
  "Statement": [
    {{
      "Sid": "Enable IAM User Permissions",
      "Effect": "Allow",
      "Principal": {{
        "AWS": "arn:aws:iam::{account_id}:root"
      }},
      "Action": "kms:*",
      "Resource": "*"
    }},
    {{
      "Sid": "Allow use of the key",
      "Effect": "Allow",
      "Principal": {{
        "AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
      }},
      "Action": [
        "kms:Encrypt",
        "kms:Decrypt",
        "kms:ReEncrypt*",
        "kms:GenerateDataKey*",
        "kms:DescribeKey"
      ],
      "Resource": "*"
    }},
    {{
      "Sid": "Allow attachment of persistent resources",
      "Effect": "Allow",
      "Principal": {{
        "AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
      }},
      "Action": [
        "kms:CreateGrant",
        "kms:ListGrants",
        "kms:RevokeGrant"
      ],
      "Resource": "*",
      "Condition": {{
        "Bool": {{
          "kms:GrantIsForAWSResource": "true"
        }}
      }}
    }}
  ]
}}
'''


def _get_kms_key_arn(kms_client, alias):
    """Look up the ARN of the KMS key that ``alias`` points to.

    Args:
        kms_client: KMS client providing ``describe_key`` and
            ``exceptions.NotFoundException``.
        alias (str): Alias name without the ``alias/`` prefix.

    Returns:
        str: The key ARN, or None if ``describe_key`` raises ``NotFoundException``.
    """
    try:
        response = kms_client.describe_key(KeyId='alias/' + alias)
    except kms_client.exceptions.NotFoundException:
        return None
    return response['KeyMetadata']['Arn']


def _create_kms_key(kms_client, account_id):
    """Create a KMS key with the integ-test policy and register it under ``KEY_ALIAS``.

    Args:
        kms_client: KMS client providing ``create_key`` and ``create_alias``.
        account_id (str): AWS account ID substituted into ``KEY_POLICY``.

    Returns:
        str: The ARN of the newly created key.
    """
    response = kms_client.create_key(
        Policy=KEY_POLICY.format(account_id=account_id),
        Description='KMS key for SageMaker Python SDK integ tests',
    )
    key_arn = response['KeyMetadata']['Arn']
    # The alias is what later lookups use to find this key again.
    kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn)
    return key_arn


def get_or_create_kms_key(kms_client, account_id):
    """Return the ARN of the key aliased ``KEY_ALIAS``, creating the key if the alias is absent.

    Args:
        kms_client: KMS client used for the lookup and, if needed, the creation.
        account_id (str): AWS account ID substituted into the key policy when a
            new key has to be created.

    Returns:
        str: The ARN of the existing or newly created key.
    """
    kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS)
    if kms_key_arn is not None:
        return kms_key_arn
    return _create_kms_key(kms_client, account_id)

tests/integ/test_transformer.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sagemaker.mxnet import MXNet
2424
from sagemaker.transformer import Transformer
2525
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
26+
from tests.integ.kms_utils import get_or_create_kms_key
2627
from tests.integ.timeout import timeout
2728

2829

@@ -47,9 +48,16 @@ def test_transform_mxnet(sagemaker_session):
4748
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
4849
key_prefix=transform_input_key_prefix)
4950

50-
transformer = _create_transformer_and_transform_job(mx, transform_input)
51-
transformer.wait()
51+
sts_client = sagemaker_session.boto_session.client('sts')
52+
account_id = sts_client.get_caller_identity()['Account']
53+
kms_client = sagemaker_session.boto_session.client('kms')
54+
kms_key_arn = get_or_create_kms_key(kms_client, account_id)
5255

56+
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
57+
transformer.wait()
58+
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
59+
TransformJobName=transformer.latest_transform_job.name)
60+
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']
5361

5462
@pytest.mark.continuous_testing
5563
def test_attach_transform_kmeans(sagemaker_session):
@@ -90,7 +98,7 @@ def test_attach_transform_kmeans(sagemaker_session):
9098
attached_transformer.wait()
9199

92100

93-
def _create_transformer_and_transform_job(estimator, transform_input):
94-
transformer = estimator.transformer(1, 'ml.m4.xlarge')
101+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
102+
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
95103
transformer.transform(transform_input, content_type='text/csv')
96104
return transformer

tests/unit/test_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
488488
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
489489
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
490490
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
491-
env=env, role=new_role, model_server_workers=1)
491+
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)
492492

493493
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
494494
assert transformer.strategy == strategy
@@ -501,6 +501,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
501501
assert transformer.env == env
502502
assert transformer.base_transform_job_name == base_name
503503
assert transformer.tags == TAGS
504+
assert transformer.volume_kms_key == kms_key
504505

505506

506507
def test_ensure_latest_training_job(sagemaker_session):

0 commit comments

Comments
 (0)