diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 22f03c347d..ffee9b463c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,7 @@ CHANGELOG 1.9.3 ===== +* enhancement: Add support for volume KMS key to Transformer * bug-fix: Local Mode: Create output/data directory expected by SageMaker Container. * bug-fix: Estimator accepts the vpc configs made capable by 1.9.1 diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 25a3f044fb..b33ea25edc 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -331,7 +331,7 @@ def delete_endpoint(self): def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None): + max_payload=None, tags=None, role=None, volume_kms_key=None): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -353,6 +353,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit the training job are used for the transform job. role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during transform jobs. If not specified, the role from the Estimator will be used. + volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML + compute instance (default: None). """ self._ensure_latest_training_job() @@ -363,7 +365,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit output_path=output_path, output_kms_key=output_kms_key, accept=accept, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=env, tags=tags, base_transform_job_name=self.base_job_name, - sagemaker_session=self.sagemaker_session) + volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) @property def training_job_analytics(self): @@ -767,7 +769,7 @@ def _update_init_params(cls, hp, tf_arguments): def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None, - max_payload=None, tags=None, role=None, model_server_workers=None): + max_payload=None, tags=None, role=None, model_server_workers=None, volume_kms_key=None): """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the SageMaker Session and base job name used by the Estimator. @@ -791,6 +793,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit transform jobs. If not specified, the role from the Estimator will be used. model_server_workers (int): Optional. The number of worker processes used by the inference server. If None, server will use one worker per vCPU. + volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML + compute instance (default: None). """ self._ensure_latest_training_job() role = role or self.role @@ -810,7 +814,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit output_path=output_path, output_kms_key=output_kms_key, accept=accept, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, env=transform_env, tags=tags, base_transform_job_name=self.base_job_name, - sagemaker_session=self.sagemaker_session) + volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session) def _s3_uri_prefix(channel_name, s3_data): diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 2d707eaa54..a9c17afdd0 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -23,7 +23,7 @@ class Transformer(object): def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None, output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None, - env=None, base_transform_job_name=None, sagemaker_session=None): + env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None): """Initialize a ``Transformer``. Args: @@ -50,6 +50,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML + compute instance (default: None). """ self.model_name = model_name self.strategy = strategy @@ -62,6 +64,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass self.instance_count = instance_count self.instance_type = instance_type + self.volume_kms_key = volume_kms_key self.max_concurrent_transforms = max_concurrent_transforms self.max_payload = max_payload @@ -159,6 +162,7 @@ def _prepare_init_params_from_job_description(cls, job_details): init_params['model_name'] = job_details['ModelName'] init_params['instance_count'] = job_details['TransformResources']['InstanceCount'] init_params['instance_type'] = job_details['TransformResources']['InstanceType'] + init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId') init_params['strategy'] = job_details.get('BatchStrategy') init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith') init_params['output_path'] = job_details['TransformOutput']['S3OutputPath'] @@ -200,7 +204,8 @@ def _load_config(data, data_type, content_type, compression_type, split_type, tr output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key, transformer.assemble_with, transformer.accept) - resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type) + resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type, + transformer.volume_kms_key) return {'input_config': input_config, 'output_config': output_config, @@ -241,5 +246,10 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept): return config @staticmethod - def _prepare_resource_config(instance_count, instance_type): - return {'InstanceCount': instance_count, 'InstanceType': instance_type} + def _prepare_resource_config(instance_count, instance_type, volume_kms_key): + config = {'InstanceCount': instance_count, 'InstanceType': instance_type} + + if volume_kms_key is not None: + config['VolumeKmsKeyId'] = volume_kms_key + + return config diff --git a/tests/integ/kms_utils.py b/tests/integ/kms_utils.py new file mode 100644 index 0000000000..047a585980 --- /dev/null +++ b/tests/integ/kms_utils.py @@ -0,0 +1,91 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +KEY_ALIAS = "SageMakerKmsKey" +KEY_POLICY = ''' +{{ + "Version": "2012-10-17", + "Id": "sagemaker-kms-integ-test-policy", + "Statement": [ + {{ + "Sid": "Enable IAM User Permissions", + "Effect": "Allow", + "Principal": {{ + "AWS": "arn:aws:iam::{account_id}:root" + }}, + "Action": "kms:*", + "Resource": "*" + }}, + {{ + "Sid": "Allow use of the key", + "Effect": "Allow", + "Principal": {{ + "AWS": "arn:aws:iam::{account_id}:role/SageMakerRole" + }}, + "Action": [ + "kms:Encrypt", + "kms:Decrypt", + "kms:ReEncrypt*", + "kms:GenerateDataKey*", + "kms:DescribeKey" + ], + "Resource": "*" + }}, + {{ + "Sid": "Allow attachment of persistent resources", + "Effect": "Allow", + "Principal": {{ + "AWS": "arn:aws:iam::{account_id}:role/SageMakerRole" + }}, + "Action": [ + "kms:CreateGrant", + "kms:ListGrants", + "kms:RevokeGrant" + ], + "Resource": "*", + "Condition": {{ + "Bool": {{ + "kms:GrantIsForAWSResource": "true" + }} + }} + }} + ] +}} +''' + + +def _get_kms_key_arn(kms_client, alias): + try: + response = kms_client.describe_key(KeyId='alias/' + alias) + return response['KeyMetadata']['Arn'] + except kms_client.exceptions.NotFoundException: + return None + + +def _create_kms_key(kms_client, account_id): + response = kms_client.create_key( + Policy=KEY_POLICY.format(account_id=account_id), + Description='KMS key for SageMaker Python SDK integ tests', + ) + key_arn = response['KeyMetadata']['Arn'] + response = kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn) + return key_arn + + +def get_or_create_kms_key(kms_client, account_id): + kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS) + if kms_key_arn is not None: + return kms_key_arn + else: + return _create_kms_key(kms_client, account_id) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index 1b5a5160fd..dc626b23d6 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -23,6 +23,7 @@ from sagemaker.mxnet import MXNet from sagemaker.transformer import Transformer from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES +from tests.integ.kms_utils import get_or_create_kms_key from tests.integ.timeout import timeout @@ -47,8 +48,16 @@ def test_transform_mxnet(sagemaker_session): transform_input = mx.sagemaker_session.upload_data(path=transform_input_path, key_prefix=transform_input_key_prefix) - transformer = _create_transformer_and_transform_job(mx, transform_input) + sts_client = sagemaker_session.boto_session.client('sts') + account_id = sts_client.get_caller_identity()['Account'] + kms_client = sagemaker_session.boto_session.client('kms') + kms_key_arn = get_or_create_kms_key(kms_client, account_id) + + transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn) transformer.wait() + job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job( + TransformJobName=transformer.latest_transform_job.name) + assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId'] @pytest.mark.continuous_testing @@ -90,7 +99,7 @@ def test_attach_transform_kmeans(sagemaker_session): attached_transformer.wait() -def _create_transformer_and_transform_job(estimator, transform_input): - transformer = estimator.transformer(1, 'ml.m4.xlarge') +def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None): + transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) transformer.transform(transform_input, content_type='text/csv') return transformer diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 4ecb8ee85b..062475e8f2 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -494,7 +494,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE, strategy=strategy, assemble_with=assemble_with, output_path=OUTPUT_PATH, output_kms_key=kms_key, accept=accept, tags=TAGS, max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload, - env=env, role=new_role, model_server_workers=1) + volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1) sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF) assert transformer.strategy == strategy @@ -507,6 +507,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa assert transformer.env == env assert transformer.base_transform_job_name == base_name assert transformer.tags == TAGS + assert transformer.volume_kms_key == kms_key def test_ensure_latest_training_job(sagemaker_session): diff --git a/tests/unit/test_transformer.py b/tests/unit/test_transformer.py index 006970fd39..200d5df52d 100644 --- a/tests/unit/test_transformer.py +++ b/tests/unit/test_transformer.py @@ -23,6 +23,7 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = 'ml.m4.xlarge' +KMS_KEY_ID = 'kms-key-id' S3_DATA_TYPE = 'S3Prefix' S3_BUCKET = 'bucket' @@ -48,7 +49,8 @@ def sagemaker_session(): @pytest.fixture() def transformer(sagemaker_session): return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE, - output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session) + output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session, + volume_kms_key=KMS_KEY_ID) @patch('sagemaker.transformer._TransformJob.start_new') @@ -178,7 +180,8 @@ def test_prepare_init_params_from_job_description_all_keys(transformer): 'ModelName': MODEL_NAME, 'TransformResources': { 'InstanceCount': INSTANCE_COUNT, - 'InstanceType': INSTANCE_TYPE + 'InstanceType': INSTANCE_TYPE, + 'VolumeKmsKeyId': KMS_KEY_ID }, 'BatchStrategy': None, 'TransformOutput': { @@ -197,6 +200,7 @@ def test_prepare_init_params_from_job_description_all_keys(transformer): assert init_params['model_name'] == MODEL_NAME assert init_params['instance_count'] == INSTANCE_COUNT assert init_params['instance_type'] == INSTANCE_TYPE + assert init_params['volume_kms_key'] == KMS_KEY_ID # _TransformJob tests @@ -227,6 +231,7 @@ def test_load_config(transformer): 'resource_config': { 'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, + 'VolumeKmsKeyId': KMS_KEY_ID, }, } @@ -292,8 +297,8 @@ def test_prepare_output_config_with_optional_params(): def test_prepare_resource_config(): - config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE) - assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE} + config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID) + assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID} def test_transform_job_wait(sagemaker_session):