Skip to content

Add volume KMS key to transformer #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 27, 2018
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ CHANGELOG
1.9.3
=====

* enhancement: Add support for volume KMS key to Transformer
* bug-fix: Local Mode: Create output/data directory expected by SageMaker Container.
* bug-fix: Estimator accepts the vpc configs made capable by 1.9.1

Expand Down
12 changes: 8 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def delete_endpoint(self):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None, role=None):
max_payload=None, tags=None, role=None, volume_kms_key=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

Expand All @@ -353,6 +353,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
the training job are used for the transform job.
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
transform jobs. If not specified, the role from the Estimator will be used.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self._ensure_latest_training_job()

Expand All @@ -363,7 +365,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, tags=tags, base_transform_job_name=self.base_job_name,
sagemaker_session=self.sagemaker_session)
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)

@property
def training_job_analytics(self):
Expand Down Expand Up @@ -767,7 +769,7 @@ def _update_init_params(cls, hp, tf_arguments):

def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
max_payload=None, tags=None, role=None, model_server_workers=None):
max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None):
"""Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
SageMaker Session and base job name used by the Estimator.

Expand All @@ -791,6 +793,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
transform jobs. If not specified, the role from the Estimator will be used.
model_server_workers (int): Optional. The number of worker processes used by the inference server.
If None, server will use one worker per vCPU.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self._ensure_latest_training_job()
role = role or self.role
Expand All @@ -810,7 +814,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=transform_env, tags=tags, base_transform_job_name=self.base_job_name,
sagemaker_session=self.sagemaker_session)
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)


def _s3_uri_prefix(channel_name, s3_data):
Expand Down
18 changes: 14 additions & 4 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Transformer(object):

def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None,
env=None, base_transform_job_name=None, sagemaker_session=None):
env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None):
"""Initialize a ``Transformer``.

Args:
Expand All @@ -50,6 +50,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self.model_name = model_name
self.strategy = strategy
Expand All @@ -62,6 +64,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass

self.instance_count = instance_count
self.instance_type = instance_type
self.volume_kms_key = volume_kms_key

self.max_concurrent_transforms = max_concurrent_transforms
self.max_payload = max_payload
Expand Down Expand Up @@ -159,6 +162,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
init_params['model_name'] = job_details['ModelName']
init_params['instance_count'] = job_details['TransformResources']['InstanceCount']
init_params['instance_type'] = job_details['TransformResources']['InstanceType']
init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId')
init_params['strategy'] = job_details.get('BatchStrategy')
init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith')
init_params['output_path'] = job_details['TransformOutput']['S3OutputPath']
Expand Down Expand Up @@ -200,7 +204,8 @@ def _load_config(data, data_type, content_type, compression_type, split_type, tr
output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key,
transformer.assemble_with, transformer.accept)

resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type)
resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type,
transformer.volume_kms_key)

return {'input_config': input_config,
'output_config': output_config,
Expand Down Expand Up @@ -241,5 +246,10 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept):
return config

@staticmethod
def _prepare_resource_config(instance_count, instance_type):
return {'InstanceCount': instance_count, 'InstanceType': instance_type}
def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
config = {'InstanceCount': instance_count, 'InstanceType': instance_type}

if volume_kms_key is not None:
config['VolumeKmsKeyId'] = volume_kms_key

return config
91 changes: 91 additions & 0 deletions tests/integ/kms_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

KEY_ALIAS = "SageMakerKmsKey"
KEY_POLICY = '''
{{
"Version": "2012-10-17",
"Id": "sagemaker-kms-integ-test-policy",
"Statement": [
{{
"Sid": "Enable IAM User Permissions",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:root"
}},
"Action": "kms:*",
"Resource": "*"
}},
{{
"Sid": "Allow use of the key",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
}},
"Action": [
"kms:Encrypt",
"kms:Decrypt",
"kms:ReEncrypt*",
"kms:GenerateDataKey*",
"kms:DescribeKey"
],
"Resource": "*"
}},
{{
"Sid": "Allow attachment of persistent resources",
"Effect": "Allow",
"Principal": {{
"AWS": "arn:aws:iam::{account_id}:role/SageMakerRole"
}},
"Action": [
"kms:CreateGrant",
"kms:ListGrants",
"kms:RevokeGrant"
],
"Resource": "*",
"Condition": {{
"Bool": {{
"kms:GrantIsForAWSResource": "true"
}}
}}
}}
]
}}
'''


def _get_kms_key_arn(kms_client, alias):
try:
response = kms_client.describe_key(KeyId='alias/' + alias)
return response['KeyMetadata']['Arn']
except kms_client.exceptions.NotFoundException:
return None


def _create_kms_key(kms_client, account_id):
response = kms_client.create_key(
Policy=KEY_POLICY.format(account_id=account_id),
Description='KMS key for SageMaker Python SDK integ tests',
)
key_arn = response['KeyMetadata']['Arn']
response = kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn)
return key_arn


def get_or_create_kms_key(kms_client, account_id):
kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS)
if kms_key_arn is not None:
return kms_key_arn
else:
return _create_kms_key(kms_client, account_id)
15 changes: 12 additions & 3 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sagemaker.mxnet import MXNet
from sagemaker.transformer import Transformer
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ.timeout import timeout


Expand All @@ -47,8 +48,16 @@ def test_transform_mxnet(sagemaker_session):
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
key_prefix=transform_input_key_prefix)

transformer = _create_transformer_and_transform_job(mx, transform_input)
sts_client = sagemaker_session.boto_session.client('sts')
account_id = sts_client.get_caller_identity()['Account']
kms_client = sagemaker_session.boto_session.client('kms')
kms_key_arn = get_or_create_kms_key(kms_client, account_id)

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
transformer.wait()
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=transformer.latest_transform_job.name)
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']


@pytest.mark.continuous_testing
Expand Down Expand Up @@ -90,7 +99,7 @@ def test_attach_transform_kmeans(sagemaker_session):
attached_transformer.wait()


def _create_transformer_and_transform_job(estimator, transform_input):
transformer = estimator.transformer(1, 'ml.m4.xlarge')
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
transformer.transform(transform_input, content_type='text/csv')
return transformer
3 changes: 2 additions & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with,
output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS,
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
env=env, role=new_role, model_server_workers=1)
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)

sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF)
assert transformer.strategy == strategy
Expand All @@ -507,6 +507,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
assert transformer.env == env
assert transformer.base_transform_job_name == base_name
assert transformer.tags == TAGS
assert transformer.volume_kms_key == kms_key


def test_ensure_latest_training_job(sagemaker_session):
Expand Down
13 changes: 9 additions & 4 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

INSTANCE_COUNT = 1
INSTANCE_TYPE = 'ml.m4.xlarge'
KMS_KEY_ID = 'kms-key-id'

S3_DATA_TYPE = 'S3Prefix'
S3_BUCKET = 'bucket'
Expand All @@ -48,7 +49,8 @@ def sagemaker_session():
@pytest.fixture()
def transformer(sagemaker_session):
return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE,
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session)
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session,
volume_kms_key=KMS_KEY_ID)


@patch('sagemaker.transformer._TransformJob.start_new')
Expand Down Expand Up @@ -178,7 +180,8 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
'ModelName': MODEL_NAME,
'TransformResources': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID
},
'BatchStrategy': None,
'TransformOutput': {
Expand All @@ -197,6 +200,7 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
assert init_params['model_name'] == MODEL_NAME
assert init_params['instance_count'] == INSTANCE_COUNT
assert init_params['instance_type'] == INSTANCE_TYPE
assert init_params['volume_kms_key'] == KMS_KEY_ID


# _TransformJob tests
Expand Down Expand Up @@ -227,6 +231,7 @@ def test_load_config(transformer):
'resource_config': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID,
},
}

Expand Down Expand Up @@ -292,8 +297,8 @@ def test_prepare_output_config_with_optional_params():


def test_prepare_resource_config():
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE}
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}


def test_transform_job_wait(sagemaker_session):
Expand Down