Skip to content

Commit 5e61265

Browse files
authored
Add VPC config to estimator for training job creation (#353)
* Add VPC config to estimator for training job creation CreateTraningJob api supports vpc. This change adds vpc config as an optional argument to Estimator.
1 parent d7665de commit 5e61265

12 files changed

+127
-9
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CHANGELOG
77

88
* bug-fix: Estimators: Fix serialization of single records
99
* bug-fix: deprecate enable_cloudwatch_metrics from Framework Estimators.
10+
* enhancement: Enable VPC config in training job creation
1011

1112
1.9.0
1213
=====

src/sagemaker/estimator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
4747

4848
def __init__(self, role, train_instance_count, train_instance_type,
4949
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
50-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
50+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
51+
subnets=None, security_group_ids=None):
5152
"""Initialize an ``EstimatorBase`` instance.
5253
5354
Args:
@@ -78,6 +79,9 @@ def __init__(self, role, train_instance_count, train_instance_type,
7879
using the default AWS configuration chain.
7980
tags (list[dict]): List of tags for labeling a training job. For more, see
8081
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
82+
subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
83+
security_group_ids (list[str]): List of security group ids. If not specified training job will be created
84+
without VPC config.
8185
"""
8286
self.role = role
8387
self.train_instance_count = train_instance_count
@@ -100,6 +104,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
100104
self.output_kms_key = output_kms_key
101105
self.latest_training_job = None
102106

107+
# VPC configurations
108+
self.subnets = subnets
109+
self.security_group_ids = security_group_ids
110+
103111
@abstractmethod
104112
def train_image(self):
105113
"""Return the Docker image to use for training.
@@ -399,8 +407,9 @@ def start_new(cls, estimator, inputs):
399407
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
400408
input_config=config['input_config'], role=config['role'],
401409
job_name=estimator._current_job_name, output_config=config['output_config'],
402-
resource_config=config['resource_config'], hyperparameters=hyperparameters,
403-
stop_condition=config['stop_condition'], tags=estimator.tags)
410+
resource_config=config['resource_config'], vpc_config=config['vpc_config'],
411+
hyperparameters=hyperparameters, stop_condition=config['stop_condition'],
412+
tags=estimator.tags)
404413

405414
return cls(estimator.sagemaker_session, estimator._current_job_name)
406415

src/sagemaker/job.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ def _load_config(inputs, estimator):
5959
estimator.train_instance_type,
6060
estimator.train_volume_size)
6161
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
62+
vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids)
6263

6364
return {'input_config': input_config,
6465
'role': role,
6566
'output_config': output_config,
6667
'resource_config': resource_config,
67-
'stop_condition': stop_condition}
68+
'stop_condition': stop_condition,
69+
'vpc_config': vpc_config}
6870

6971
@staticmethod
7072
def _format_inputs_to_input_config(inputs):
@@ -143,6 +145,13 @@ def _prepare_resource_config(instance_count, instance_type, volume_size):
143145
'InstanceType': instance_type,
144146
'VolumeSizeInGB': volume_size}
145147

148+
@staticmethod
149+
def _prepare_vpc_config(subnets, security_group_ids):
150+
if subnets is None or security_group_ids is None:
151+
return None
152+
return {'Subnets': subnets,
153+
'SecurityGroupIds': security_group_ids}
154+
146155
@staticmethod
147156
def _prepare_stop_condition(max_run):
148157
return {'MaxRuntimeInSeconds': max_run}

src/sagemaker/session.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def default_bucket(self):
202202
return self._default_bucket
203203

204204
def train(self, image, input_mode, input_config, role, job_name, output_config,
205-
resource_config, hyperparameters, stop_condition, tags):
205+
resource_config, vpc_config, hyperparameters, stop_condition, tags):
206206
"""Create an Amazon SageMaker training job.
207207
208208
Args:
@@ -228,6 +228,13 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
228228
* instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
229229
The key in resource_config is 'InstanceType'.
230230
231+
vpc_config (dict): Contains values for VpcConfig:
232+
233+
* subnets (list[str]): List of subnet ids.
234+
The key in vpc_config is 'Subnets'.
235+
* security_group_ids (list[str]): List of security group ids.
236+
The key in vpc_config is 'SecurityGroupIds'.
237+
231238
hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible as
232239
a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
233240
keys and values, but ``str()`` will be called to convert them before training.
@@ -259,6 +266,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
259266
if tags is not None:
260267
train_request['Tags'] = tags
261268

269+
if vpc_config is not None:
270+
train_request['VpcConfig'] = vpc_config
271+
262272
LOGGER.info('Creating training-job with name: {}'.format(job_name))
263273
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
264274
self.sagemaker_client.create_training_job(**train_request)

tests/integ/test_tf.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from sagemaker.tensorflow import TensorFlow
2121
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2222
from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout
23+
from tests.integ.vpc_utils import get_or_create_subnet_and_security_group
2324

2425
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
26+
VPC_NAME = 'training-job-test'
2527

2628

2729
@pytest.mark.continuous_testing
@@ -90,6 +92,8 @@ def test_tf_async(sagemaker_session):
9092
def test_failed_tf_training(sagemaker_session, tf_full_version):
9193
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
9294
script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py')
95+
ec2_client = sagemaker_session.boto_session.client('ec2')
96+
subnet, security_group_id = get_or_create_subnet_and_security_group(ec2_client, VPC_NAME)
9397
estimator = TensorFlow(entry_point=script_path,
9498
role='SageMakerRole',
9599
framework_version=tf_full_version,
@@ -98,10 +102,17 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
98102
hyperparameters={'input_tensor_name': 'inputs'},
99103
train_instance_count=1,
100104
train_instance_type='ml.c4.xlarge',
101-
sagemaker_session=sagemaker_session)
105+
sagemaker_session=sagemaker_session,
106+
subnets=[subnet],
107+
security_group_ids=[security_group_id])
102108

103109
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
104110

105111
with pytest.raises(ValueError) as e:
106112
estimator.fit(inputs)
107113
assert 'This failure is expected' in str(e.value)
114+
115+
job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job(
116+
TrainingJobName=estimator.latest_training_job.name)
117+
assert [subnet] == job_desc['VpcConfig']['Subnets']
118+
assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds']

tests/integ/vpc_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
def _get_subnet_id_by_name(ec2_client, name):
17+
desc = ec2_client.describe_subnets(Filters=[
18+
{'Name': 'tag-value', 'Values': [name]}
19+
])
20+
if len(desc['Subnets']) == 0:
21+
return None
22+
else:
23+
return desc['Subnets'][0]['SubnetId']
24+
25+
26+
def _get_security_id_by_name(ec2_client, name):
27+
desc = ec2_client.describe_security_groups(Filters=[
28+
{'Name': 'tag-value', 'Values': [name]}
29+
])
30+
if len(desc['SecurityGroups']) == 0:
31+
return None
32+
else:
33+
return desc['SecurityGroups'][0]['GroupId']
34+
35+
36+
def _vpc_exists(ec2_client, name):
37+
desc = ec2_client.describe_vpcs(Filters=[
38+
{'Name': 'tag-value', 'Values': [name]}
39+
])
40+
return len(desc['Vpcs']) > 0
41+
42+
43+
def _get_route_table_id(ec2_client, vpc_id):
44+
desc = ec2_client.describe_route_tables(Filters=[
45+
{'Name': 'vpc-id', 'Values': [vpc_id]}
46+
])
47+
return desc['RouteTables'][0]['RouteTableId']
48+
49+
50+
def create_vpc_with_name(ec2_client, name):
51+
vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId']
52+
53+
subnet_id = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id)['Subnet']['SubnetId']
54+
55+
s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
56+
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
57+
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
58+
59+
security_group_id = ec2_client.create_security_group(GroupName='TrainingJobTestGroup', Description='Testing',
60+
VpcId=vpc_id)['GroupId']
61+
62+
ec2_client.create_tags(Resources=[vpc_id, subnet_id, security_group_id], Tags=[{'Key': 'Name', 'Value': name}])
63+
64+
return subnet_id, security_group_id
65+
66+
67+
def get_or_create_subnet_and_security_group(ec2_client, name):
68+
if _vpc_exists(ec2_client, name):
69+
return _get_subnet_id_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name)
70+
else:
71+
return create_vpc_with_name(ec2_client, name)

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def _create_train_job(version):
121121
'MaxRuntimeInSeconds': 24 * 60 * 60
122122
},
123123
'tags': None,
124+
'vpc_config': None
124125
}
125126

126127

tests/unit/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def test_unsupported_type_in_dict():
658658
},
659659
'stop_condition': {'MaxRuntimeInSeconds': 86400},
660660
'tags': None,
661+
'vpc_config': None
661662
}
662663

663664
HYPERPARAMS = {'x': 1, 'y': 'hello'}

tests/unit/test_mxnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def _create_train_job(version):
9595
'MaxRuntimeInSeconds': 24 * 60 * 60
9696
},
9797
'tags': None,
98+
'vpc_config': None
9899
}
99100

100101

tests/unit/test_pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def _create_train_job(version):
111111
'stop_condition': {
112112
'MaxRuntimeInSeconds': 24 * 60 * 60
113113
},
114-
'tags': None
114+
'tags': None,
115+
'vpc_config': None
115116
}
116117

117118

tests/unit/test_session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def test_s3_input_all_arguments():
176176
MAX_TIME = 3 * 60 * 60
177177
JOB_NAME = 'jobname'
178178
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
179+
VPC_CONFIG = {'Subnets': 'subnet', 'SecurityGroupIds': 'sgi-blahblah'}
179180

180181
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
181182
'OutputDataConfig': {
@@ -259,7 +260,7 @@ def test_train_pack_to_request(sagemaker_session):
259260

260261
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
261262
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
262-
hyperparameters=None, stop_condition=stop_cond, tags=None)
263+
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=None)
263264

264265
assert sagemaker_session.sagemaker_client.method_calls[0] == (
265266
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
@@ -322,12 +323,13 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
322323

323324
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
324325
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
325-
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS)
326+
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, vpc_config=VPC_CONFIG)
326327

327328
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
328329

329330
assert actual_train_args['HyperParameters'] == hyperparameters
330331
assert actual_train_args['Tags'] == TAGS
332+
assert actual_train_args['VpcConfig'] == VPC_CONFIG
331333

332334

333335
def test_transform_pack_to_request(sagemaker_session):

tests/unit/test_tf_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def _create_train_job(tf_version):
103103
'MaxRuntimeInSeconds': 24 * 60 * 60
104104
},
105105
'tags': None,
106+
'vpc_config': None
106107
}
107108

108109

0 commit comments

Comments
 (0)