Skip to content

Commit e65c833

Browse files
committed
Add VPC config to estimator for training job creation
CreateTraningJob api supports vpc. This change adds vpc config as an optional argument to Estimator.
1 parent 146e171 commit e65c833

File tree

11 files changed

+44
-9
lines changed

11 files changed

+44
-9
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
========
77

88
* bug-fix: Estimators: Fix serialization of single records
9+
* enhancement: Enable VPC config in training job creation
910

1011
1.9.0
1112
=====

src/sagemaker/estimator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4646

4747
def __init__(self, role, train_instance_count, train_instance_type,
4848
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
49-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
49+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
50+
subnets=None, security_group_ids=None):
5051
"""Initialize an ``EstimatorBase`` instance.
5152
5253
Args:
@@ -99,6 +100,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
99100
self.output_kms_key = output_kms_key
100101
self.latest_training_job = None
101102

103+
# VPC configurations
104+
self.subnets = subnets
105+
self.security_group_ids = security_group_ids
106+
102107
@abstractmethod
103108
def train_image(self):
104109
"""Return the Docker image to use for training.
@@ -398,8 +403,9 @@ def start_new(cls, estimator, inputs):
398403
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
399404
input_config=config['input_config'], role=config['role'],
400405
job_name=estimator._current_job_name, output_config=config['output_config'],
401-
resource_config=config['resource_config'], hyperparameters=hyperparameters,
402-
stop_condition=config['stop_condition'], tags=estimator.tags)
406+
resource_config=config['resource_config'], vpc_config=config['vpc_config'],
407+
hyperparameters=hyperparameters, stop_condition=config['stop_condition'],
408+
tags=estimator.tags)
403409

404410
return cls(estimator.sagemaker_session, estimator._current_job_name)
405411

src/sagemaker/job.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ def _load_config(inputs, estimator):
5959
estimator.train_instance_type,
6060
estimator.train_volume_size)
6161
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
62+
vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids)
6263

6364
return {'input_config': input_config,
6465
'role': role,
6566
'output_config': output_config,
6667
'resource_config': resource_config,
67-
'stop_condition': stop_condition}
68+
'stop_condition': stop_condition,
69+
'vpc_config': vpc_config}
6870

6971
@staticmethod
7072
def _format_inputs_to_input_config(inputs):
@@ -143,6 +145,13 @@ def _prepare_resource_config(instance_count, instance_type, volume_size):
143145
'InstanceType': instance_type,
144146
'VolumeSizeInGB': volume_size}
145147

148+
@staticmethod
149+
def _prepare_vpc_config(subnets, security_group_ids):
150+
if subnets is None or security_group_ids is None:
151+
return None
152+
return {'Subnets': subnets,
153+
'SecurityGroupIds': security_group_ids}
154+
146155
@staticmethod
147156
def _prepare_stop_condition(max_run):
148157
return {'MaxRuntimeInSeconds': max_run}

src/sagemaker/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def default_bucket(self):
202202
return self._default_bucket
203203

204204
def train(self, image, input_mode, input_config, role, job_name, output_config,
205-
resource_config, hyperparameters, stop_condition, tags):
205+
resource_config, vpc_config, hyperparameters, stop_condition, tags):
206206
"""Create an Amazon SageMaker training job.
207207
208208
Args:
@@ -259,6 +259,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
259259
if tags is not None:
260260
train_request['Tags'] = tags
261261

262+
if vpc_config is not None:
263+
train_request['VpcConfig'] = vpc_config
264+
262265
LOGGER.info('Creating training-job with name: {}'.format(job_name))
263266
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
264267
self.sagemaker_client.create_training_job(**train_request)

tests/integ/test_tf.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout
2323

2424
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
25+
VPC_SUBNETS = ['subnet-06b8537735fac3757']
26+
VPC_SECURITY_GROUP_IDS = ['sg-0a1008de6e1f384c3']
2527

2628

2729
@pytest.mark.continuous_testing
@@ -98,10 +100,17 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
98100
hyperparameters={'input_tensor_name': 'inputs'},
99101
train_instance_count=1,
100102
train_instance_type='ml.c4.xlarge',
101-
sagemaker_session=sagemaker_session)
103+
sagemaker_session=sagemaker_session,
104+
subnets=VPC_SUBNETS,
105+
security_group_ids=VPC_SECURITY_GROUP_IDS)
102106

103107
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
104108

105109
with pytest.raises(ValueError) as e:
106110
estimator.fit(inputs)
107111
assert 'This failure is expected' in str(e.value)
112+
113+
job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job(
114+
TrainingJobName=estimator.latest_training_job.name)
115+
assert VPC_SUBNETS == job_desc['VpcConfig']['Subnets']
116+
assert VPC_SECURITY_GROUP_IDS == job_desc['VpcConfig']['SecurityGroupIds']

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def _create_train_job(version):
122122
'MaxRuntimeInSeconds': 24 * 60 * 60
123123
},
124124
'tags': None,
125+
'vpc_config': {'SecurityGroupIds': None, 'Subnets': None}
125126
}
126127

127128

tests/unit/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def test_unsupported_type_in_dict():
658658
},
659659
'stop_condition': {'MaxRuntimeInSeconds': 86400},
660660
'tags': None,
661+
'vpc_config': {'SecurityGroupIds': None, 'Subnets': None}
661662
}
662663

663664
HYPERPARAMS = {'x': 1, 'y': 'hello'}

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def _create_train_job(version):
9595
'MaxRuntimeInSeconds': 24 * 60 * 60
9696
},
9797
'tags': None,
98+
'vpc_config': {'SecurityGroupIds': None, 'Subnets': None}
9899
}
99100

100101

tests/unit/test_pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def _create_train_job(version):
112112
'stop_condition': {
113113
'MaxRuntimeInSeconds': 24 * 60 * 60
114114
},
115-
'tags': None
115+
'tags': None,
116+
'vpc_config': {'SecurityGroupIds': None, 'Subnets': None}
116117
}
117118

118119

tests/unit/test_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def test_s3_input_all_arguments():
176176
MAX_TIME = 3 * 60 * 60
177177
JOB_NAME = 'jobname'
178178
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
179+
VPC_CONFIG = {'Subnets': 'subnet', 'SecurityGroupIds': 'sgi-blahblah'}
179180

180181
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
181182
'OutputDataConfig': {
@@ -259,7 +260,7 @@ def test_train_pack_to_request(sagemaker_session):
259260

260261
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
261262
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
262-
hyperparameters=None, stop_condition=stop_cond, tags=None)
263+
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=None)
263264

264265
assert sagemaker_session.sagemaker_client.method_calls[0] == (
265266
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
@@ -322,12 +323,13 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
322323

323324
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
324325
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
325-
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS)
326+
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, vpc_config=VPC_CONFIG)
326327

327328
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
328329

329330
assert actual_train_args['HyperParameters'] == hyperparameters
330331
assert actual_train_args['Tags'] == TAGS
332+
assert actual_train_args['VpcConfig'] == VPC_CONFIG
331333

332334

333335
def test_transform_pack_to_request(sagemaker_session):

tests/unit/test_tf_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _create_train_job(tf_version):
103103
'MaxRuntimeInSeconds': 24 * 60 * 60
104104
},
105105
'tags': None,
106+
'vpc_config': {'SecurityGroupIds': None, 'Subnets': None}
106107
}
107108

108109

0 commit comments

Comments
 (0)