Skip to content

Commit 52e4d76

Browse files
authored
Add train_volume_kms_key parameter (#389)
* Add train_volume_kms_key parameter
* Fix flake8
1 parent a00d5ad commit 52e4d76

File tree

6 files changed

+75
-46
lines changed

6 files changed

+75
-46
lines changed

src/sagemaker/amazon/knn.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,22 @@ def __init__(self, role, train_instance_count, train_instance_type, k, sample_si
4343
dimension_reduction_type=None, dimension_reduction_target=None, index_type=None,
4444
index_metric=None, faiss_index_ivf_nlists=None, faiss_index_pq_m=None, **kwargs):
4545
"""k-nearest neighbors (KNN) is :class:`Estimator` used for classification and regression.
46-
4746
This Estimator may be fit via calls to
4847
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon
4948
:class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3.
5049
There is a utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
5150
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed
5251
to the `fit` call.
53-
5452
To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please
5553
consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html
56-
5754
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
5855
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint,
5956
deploy returns a :class:`~sagemaker.amazon.knn.KNNPredictor` object that can be used
6057
for inference calls using the trained model hosted in the SageMaker Endpoint.
61-
6258
KNN Estimators can be configured by setting hyperparameters. The available hyperparameters for
6359
KNN are documented below.
64-
6560
For further information on the AWS KNN algorithm,
6661
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/knn.html
67-
6862
Args:
6963
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
7064
APIs that create Amazon SageMaker endpoints use this role to access
@@ -76,17 +70,17 @@ def __init__(self, role, train_instance_count, train_instance_type, k, sample_si
7670
predictor_type (str): Required. Type of inference to use on the data's labels,
7771
allowed values are 'classifier' and 'regressor'.
7872
dimension_reduction_type (str): Optional. Type of dimension reduction technique to use.
79-
Valid values: sign”, “fjlt
73+
Valid values: "sign", "fjlt"
8074
dimension_reduction_target (int): Optional. Target dimension to reduce to. Required when
8175
dimension_reduction_type is specified.
8276
index_type (str): Optional. Type of index to use. Valid values are
83-
faiss.Flat”, “faiss.IVFFlat”, “faiss.IVFPQ.
77+
"faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ".
8478
index_metric (str): Optional. Distance metric to measure between points when finding nearest neighbors.
8579
Valid values are "COSINE", "INNER_PRODUCT", "L2"
8680
faiss_index_ivf_nlists (str): Optional. Number of centroids to construct in the index if
87-
index_type is faiss.IVFFlat or faiss.IVFPQ.
81+
index_type is "faiss.IVFFlat" or "faiss.IVFPQ".
8882
faiss_index_pq_m (int): Optional. Number of vector sub-components to construct in the index,
89-
if index_type is faiss.IVFPQ.
83+
if index_type is "faiss.IVFPQ".
9084
**kwargs: base class keyword argument values.
9185
"""
9286

src/sagemaker/estimator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4646
"""
4747

4848
def __init__(self, role, train_instance_count, train_instance_type,
49-
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
49+
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File',
5050
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
5151
subnets=None, security_group_ids=None):
5252
"""Initialize an ``EstimatorBase`` instance.
@@ -61,6 +61,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
6161
train_volume_size (int): Size in GB of the EBS volume to use for storing input data
6262
during training (default: 30). Must be large enough to store training data if File Mode is used
6363
(which is the default).
64+
train_volume_kms_key (str): Optional. KMS key ID for encrypting the EBS volume attached to the
65+
training instance (default: None).
6466
train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
6567
After this amount of time Amazon SageMaker terminates the job regardless of its current status.
6668
input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
@@ -87,6 +89,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
8789
self.train_instance_count = train_instance_count
8890
self.train_instance_type = train_instance_type
8991
self.train_volume_size = train_volume_size
92+
self.train_volume_kms_key = train_volume_kms_key
9093
self.train_max_run = train_max_run
9194
self.input_mode = input_mode
9295
self.tags = tags
@@ -427,9 +430,9 @@ class Estimator(EstimatorBase):
427430
"""
428431

429432
def __init__(self, image_name, role, train_instance_count, train_instance_type,
430-
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
431-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
432-
hyperparameters=None, tags=None, subnets=None, security_group_ids=None):
433+
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60,
434+
input_mode='File', output_path=None, output_kms_key=None, base_job_name=None,
435+
sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None):
433436
"""Initialize an ``Estimator`` instance.
434437
435438
Args:
@@ -443,6 +446,8 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
443446
train_volume_size (int): Size in GB of the EBS volume to use for storing input data
444447
during training (default: 30). Must be large enough to store training data if File Mode is used
445448
(which is the default).
449+
train_volume_kms_key (str): Optional. KMS key ID for encrypting the EBS volume attached to the
450+
training instance (default: None).
446451
train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
447452
After this amount of time Amazon SageMaker terminates the job regardless of its current status.
448453
input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
@@ -462,11 +467,16 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
462467
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
463468
using the default AWS configuration chain.
464469
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
470+
tags (list[dict]): List of tags for labeling a training job. For more, see
471+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
472+
subnets (list[str]): List of subnet ids. If not specified, the training job will be created without VPC config.
473+
security_group_ids (list[str]): List of security group ids. If not specified, the training job will be created
474+
without VPC config.
465475
"""
466476
self.image_name = image_name
467477
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
468478
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
469-
train_volume_size, train_max_run, input_mode,
479+
train_volume_size, train_volume_kms_key, train_max_run, input_mode,
470480
output_path, output_kms_key, base_job_name, sagemaker_session,
471481
tags, subnets, security_group_ids)
472482

src/sagemaker/job.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def _load_config(inputs, estimator):
5757
output_config = _Job._prepare_output_config(estimator.output_path, estimator.output_kms_key)
5858
resource_config = _Job._prepare_resource_config(estimator.train_instance_count,
5959
estimator.train_instance_type,
60-
estimator.train_volume_size)
60+
estimator.train_volume_size,
61+
estimator.train_volume_kms_key)
6162
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
6263
vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids)
6364

@@ -140,10 +141,14 @@ def _prepare_output_config(s3_path, kms_key_id):
140141
return config
141142

142143
@staticmethod
143-
def _prepare_resource_config(instance_count, instance_type, volume_size):
144-
return {'InstanceCount': instance_count,
145-
'InstanceType': instance_type,
146-
'VolumeSizeInGB': volume_size}
144+
def _prepare_resource_config(instance_count, instance_type, volume_size, train_volume_kms_key):
145+
resource_config = {'InstanceCount': instance_count,
146+
'InstanceType': instance_type,
147+
'VolumeSizeInGB': volume_size}
148+
if train_volume_kms_key is not None:
149+
resource_config['VolumeKmsKeyId'] = train_volume_kms_key
150+
151+
return resource_config
147152

148153
@staticmethod
149154
def _prepare_vpc_config(subnets, security_group_ids):

src/sagemaker/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,9 +259,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
259259
'InputDataConfig': input_config,
260260
'OutputDataConfig': output_config,
261261
'TrainingJobName': job_name,
262-
"StoppingCondition": stop_condition,
263-
"ResourceConfig": resource_config,
264-
"RoleArn": role,
262+
'StoppingCondition': stop_condition,
263+
'ResourceConfig': resource_config,
264+
'RoleArn': role,
265265
}
266266

267267
if hyperparameters and len(hyperparameters) > 0:

tests/unit/test_estimator.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,6 @@
4444
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
4545
OUTPUT_PATH = 's3://bucket/prefix'
4646

47-
COMMON_TRAIN_ARGS = {
48-
'volume_size': 30,
49-
'hyperparameters': {
50-
'sagemaker_program': 'dummy_script.py',
51-
'sagemaker_enable_cloudwatch_metrics': False,
52-
'sagemaker_container_log_level': logging.INFO,
53-
},
54-
'input_mode': 'File',
55-
'instance_type': 'c4.4xlarge',
56-
'inputs': 's3://mybucket/train',
57-
'instance_count': 1,
58-
'role': 'DummyRole',
59-
'kms_key_id': None,
60-
'max_run': 24,
61-
'wait': True,
62-
}
63-
6447
DESCRIBE_TRAINING_JOB_RESULT = {
6548
'ModelArtifacts': {
6649
'S3ModelArtifacts': MODEL_DATA
@@ -119,6 +102,29 @@ def sagemaker_session():
119102
return sms
120103

121104

105+
def test_framework_all_init_args(sagemaker_session):
106+
f = DummyFramework('my_script.py', role='DummyRole', train_instance_count=3, train_instance_type='ml.m4.xlarge',
107+
sagemaker_session=sagemaker_session, train_volume_size=123, train_volume_kms_key='volumekms',
108+
train_max_run=456, input_mode='inputmode', output_path='outputpath', output_kms_key='outputkms',
109+
base_job_name='basejobname', tags=[{'foo': 'bar'}], subnets=['123', '456'],
110+
security_group_ids=['789', '012'])
111+
_TrainingJob.start_new(f, 's3://mydata')
112+
sagemaker_session.train.assert_called_once()
113+
_, args = sagemaker_session.train.call_args
114+
assert args == {'input_mode': 'inputmode', 'tags': [{'foo': 'bar'}], 'hyperparameters': {}, 'image': 'fakeimage',
115+
'input_config': [{'ChannelName': 'training',
116+
'DataSource': {
117+
'S3DataSource': {'S3DataType': 'S3Prefix',
118+
'S3DataDistributionType': 'FullyReplicated',
119+
'S3Uri': 's3://mydata'}}}],
120+
'output_config': {'KmsKeyId': 'outputkms', 'S3OutputPath': 'outputpath'},
121+
'vpc_config': {'Subnets': ['123', '456'], 'SecurityGroupIds': ['789', '012']},
122+
'stop_condition': {'MaxRuntimeInSeconds': 456},
123+
'role': sagemaker_session.expand_role(), 'job_name': None,
124+
'resource_config': {'VolumeSizeInGB': 123, 'InstanceCount': 3, 'VolumeKmsKeyId': 'volumekms',
125+
'InstanceType': 'ml.m4.xlarge'}}
126+
127+
122128
def test_sagemaker_s3_uri_invalid(sagemaker_session):
123129
with pytest.raises(ValueError) as error:
124130
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,

tests/unit/test_job.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
ROLE = 'DummyRole'
3131
IMAGE_NAME = 'fakeimage'
3232
JOB_NAME = 'fakejob'
33+
VOLUME_KMS_KEY = 'volkmskey'
3334

3435

3536
@pytest.fixture()
3637
def estimator(sagemaker_session):
37-
return Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, MAX_RUNTIME,
38-
output_path=S3_OUTPUT_PATH, sagemaker_session=sagemaker_session)
38+
return Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, train_volume_size=VOLUME_SIZE,
39+
train_max_run=MAX_RUNTIME, output_path=S3_OUTPUT_PATH, sagemaker_session=sagemaker_session)
3940

4041

4142
@pytest.fixture()
@@ -282,11 +283,24 @@ def test_prepare_output_config_kms_key_none():
282283

283284

284285
def test_prepare_resource_config():
285-
resource_config = _Job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE)
286+
resource_config = _Job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, None)
286287

287-
assert resource_config['InstanceCount'] == INSTANCE_COUNT
288-
assert resource_config['InstanceType'] == INSTANCE_TYPE
289-
assert resource_config['VolumeSizeInGB'] == VOLUME_SIZE
288+
assert resource_config == {
289+
'InstanceCount': INSTANCE_COUNT,
290+
'InstanceType': INSTANCE_TYPE,
291+
'VolumeSizeInGB': VOLUME_SIZE
292+
}
293+
294+
295+
def test_prepare_resource_config_with_volume_kms():
296+
resource_config = _Job._prepare_resource_config(INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, VOLUME_KMS_KEY)
297+
298+
assert resource_config == {
299+
'InstanceCount': INSTANCE_COUNT,
300+
'InstanceType': INSTANCE_TYPE,
301+
'VolumeSizeInGB': VOLUME_SIZE,
302+
'VolumeKmsKeyId': VOLUME_KMS_KEY
303+
}
290304

291305

292306
def test_prepare_stop_condition():

0 commit comments

Comments
 (0)