Skip to content

Add volume KMS key to transformer #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 27, 2018
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CHANGELOG
=====

* feature: add support for TensorFlow 1.9
* enhancement: Add support for volume KMS key to Transformer
Copy link
Contributor

@nadiaya nadiaya Aug 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move it to the next section (1.9.3dev I believe)? This one already was released

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, also updated transformer integ tests to include the volume KMS key


1.9.1
=====
Expand Down
13 changes: 9 additions & 4 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Transformer(object):

def __init__(self, model_name, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
output_kms_key=None, accept=None, max_concurrent_transforms=None, max_payload=None, tags=None,
env=None, base_transform_job_name=None, sagemaker_session=None):
env=None, base_transform_job_name=None, sagemaker_session=None, volume_kms_key=None):
"""Initialize a ``Transformer``.

Args:
Expand All @@ -50,6 +50,8 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
compute instance (default: None).
"""
self.model_name = model_name
self.strategy = strategy
Expand All @@ -62,6 +64,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass

self.instance_count = instance_count
self.instance_type = instance_type
self.volume_kms_key = volume_kms_key

self.max_concurrent_transforms = max_concurrent_transforms
self.max_payload = max_payload
Expand Down Expand Up @@ -159,6 +162,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
init_params['model_name'] = job_details['ModelName']
init_params['instance_count'] = job_details['TransformResources']['InstanceCount']
init_params['instance_type'] = job_details['TransformResources']['InstanceType']
init_params['volume_kms_key'] = job_details['TransformResources'].get('VolumeKmsKeyId')
init_params['strategy'] = job_details.get('BatchStrategy')
init_params['assemble_with'] = job_details['TransformOutput'].get('AssembleWith')
init_params['output_path'] = job_details['TransformOutput']['S3OutputPath']
Expand Down Expand Up @@ -200,7 +204,8 @@ def _load_config(data, data_type, content_type, compression_type, split_type, tr
output_config = _TransformJob._prepare_output_config(transformer.output_path, transformer.output_kms_key,
transformer.assemble_with, transformer.accept)

resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type)
resource_config = _TransformJob._prepare_resource_config(transformer.instance_count, transformer.instance_type,
transformer.volume_kms_key)

return {'input_config': input_config,
'output_config': output_config,
Expand Down Expand Up @@ -241,5 +246,5 @@ def _prepare_output_config(s3_path, kms_key_id, assemble_with, accept):
return config

@staticmethod
def _prepare_resource_config(instance_count, instance_type):
return {'InstanceCount': instance_count, 'InstanceType': instance_type}
def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
return {'InstanceCount': instance_count, 'InstanceType': instance_type, 'VolumeKmsKeyId': volume_kms_key}
13 changes: 9 additions & 4 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

INSTANCE_COUNT = 1
INSTANCE_TYPE = 'ml.m4.xlarge'
KMS_KEY_ID = 'kms-key-id'

S3_DATA_TYPE = 'S3Prefix'
S3_BUCKET = 'bucket'
Expand All @@ -48,7 +49,8 @@ def sagemaker_session():
@pytest.fixture()
def transformer(sagemaker_session):
return Transformer(MODEL_NAME, INSTANCE_COUNT, INSTANCE_TYPE,
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session)
output_path=OUTPUT_PATH, sagemaker_session=sagemaker_session,
volume_kms_key=KMS_KEY_ID)


@patch('sagemaker.transformer._TransformJob.start_new')
Expand Down Expand Up @@ -178,7 +180,8 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
'ModelName': MODEL_NAME,
'TransformResources': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID
},
'BatchStrategy': None,
'TransformOutput': {
Expand All @@ -197,6 +200,7 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
assert init_params['model_name'] == MODEL_NAME
assert init_params['instance_count'] == INSTANCE_COUNT
assert init_params['instance_type'] == INSTANCE_TYPE
assert init_params['volume_kms_key'] == KMS_KEY_ID


# _TransformJob tests
Expand Down Expand Up @@ -227,6 +231,7 @@ def test_load_config(transformer):
'resource_config': {
'InstanceCount': INSTANCE_COUNT,
'InstanceType': INSTANCE_TYPE,
'VolumeKmsKeyId': KMS_KEY_ID,
},
}

Expand Down Expand Up @@ -292,8 +297,8 @@ def test_prepare_output_config_with_optional_params():


def test_prepare_resource_config():
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE}
config = _TransformJob._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, KMS_KEY_ID)
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}


def test_transform_job_wait(sagemaker_session):
Expand Down