Skip to content

Add VPC config to estimator for training job creation #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ CHANGELOG

* bug-fix: Estimators: Fix serialization of single records
* bug-fix: deprecate enable_cloudwatch_metrics from Framework Estimators.
* enhancement: Enable VPC config in training job creation

1.9.0
=====
Expand Down
15 changes: 12 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):

def __init__(self, role, train_instance_count, train_instance_type,
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None):
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
subnets=None, security_group_ids=None):
"""Initialize an ``EstimatorBase`` instance.

Args:
Expand Down Expand Up @@ -78,6 +79,9 @@ def __init__(self, role, train_instance_count, train_instance_type,
using the default AWS configuration chain.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config.
security_group_ids (list[str]): List of security group ids. If not specified training job will be created
without VPC config.
"""
self.role = role
self.train_instance_count = train_instance_count
Expand All @@ -100,6 +104,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.output_kms_key = output_kms_key
self.latest_training_job = None

# VPC configurations
self.subnets = subnets
self.security_group_ids = security_group_ids

@abstractmethod
def train_image(self):
"""Return the Docker image to use for training.
Expand Down Expand Up @@ -399,8 +407,9 @@ def start_new(cls, estimator, inputs):
estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode,
input_config=config['input_config'], role=config['role'],
job_name=estimator._current_job_name, output_config=config['output_config'],
resource_config=config['resource_config'], hyperparameters=hyperparameters,
stop_condition=config['stop_condition'], tags=estimator.tags)
resource_config=config['resource_config'], vpc_config=config['vpc_config'],
hyperparameters=hyperparameters, stop_condition=config['stop_condition'],
tags=estimator.tags)

return cls(estimator.sagemaker_session, estimator._current_job_name)

Expand Down
11 changes: 10 additions & 1 deletion src/sagemaker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ def _load_config(inputs, estimator):
estimator.train_instance_type,
estimator.train_volume_size)
stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids)

return {'input_config': input_config,
'role': role,
'output_config': output_config,
'resource_config': resource_config,
'stop_condition': stop_condition}
'stop_condition': stop_condition,
'vpc_config': vpc_config}

@staticmethod
def _format_inputs_to_input_config(inputs):
Expand Down Expand Up @@ -143,6 +145,13 @@ def _prepare_resource_config(instance_count, instance_type, volume_size):
'InstanceType': instance_type,
'VolumeSizeInGB': volume_size}

@staticmethod
def _prepare_vpc_config(subnets, security_group_ids):
if subnets is None or security_group_ids is None:
return None
return {'Subnets': subnets,
'SecurityGroupIds': security_group_ids}

@staticmethod
def _prepare_stop_condition(max_run):
return {'MaxRuntimeInSeconds': max_run}
Expand Down
12 changes: 11 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def default_bucket(self):
return self._default_bucket

def train(self, image, input_mode, input_config, role, job_name, output_config,
resource_config, hyperparameters, stop_condition, tags):
resource_config, vpc_config, hyperparameters, stop_condition, tags):
"""Create an Amazon SageMaker training job.

Args:
Expand All @@ -228,6 +228,13 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
* instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
The key in resource_config is 'InstanceType'.

vpc_config (dict): Contains values for VpcConfig:

* subnets (list[str]): List of subnet ids.
The key in vpc_config is 'Subnets'.
* security_group_ids (list[str]): List of security group ids.
The key in vpc_config is 'SecurityGroupIds'.

hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible as
a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
keys and values, but ``str()`` will be called to convert them before training.
Expand Down Expand Up @@ -259,6 +266,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
if tags is not None:
train_request['Tags'] = tags

if vpc_config is not None:
train_request['VpcConfig'] = vpc_config

LOGGER.info('Creating training-job with name: {}'.format(job_name))
LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
self.sagemaker_client.create_training_job(**train_request)
Expand Down
13 changes: 12 additions & 1 deletion tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from sagemaker.tensorflow import TensorFlow
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout
from tests.integ.vpc_utils import get_or_create_subnet_and_security_group

DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
VPC_NAME = 'training-job-test'


@pytest.mark.continuous_testing
Expand Down Expand Up @@ -90,6 +92,8 @@ def test_tf_async(sagemaker_session):
def test_failed_tf_training(sagemaker_session, tf_full_version):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py')
ec2_client = sagemaker_session.boto_session.client('ec2')
subnet, security_group_id = get_or_create_subnet_and_security_group(ec2_client, VPC_NAME)
estimator = TensorFlow(entry_point=script_path,
role='SageMakerRole',
framework_version=tf_full_version,
Expand All @@ -98,10 +102,17 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
hyperparameters={'input_tensor_name': 'inputs'},
train_instance_count=1,
train_instance_type='ml.c4.xlarge',
sagemaker_session=sagemaker_session)
sagemaker_session=sagemaker_session,
subnets=[subnet],
security_group_ids=[security_group_id])

inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')

with pytest.raises(ValueError) as e:
estimator.fit(inputs)
assert 'This failure is expected' in str(e.value)

job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=estimator.latest_training_job.name)
assert [subnet] == job_desc['VpcConfig']['Subnets']
assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds']
71 changes: 71 additions & 0 deletions tests/integ/vpc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import


def _get_subnet_id_by_name(ec2_client, name):
desc = ec2_client.describe_subnets(Filters=[
{'Name': 'tag-value', 'Values': [name]}
])
if len(desc['Subnets']) == 0:
return None
else:
return desc['Subnets'][0]['SubnetId']


def _get_security_id_by_name(ec2_client, name):
desc = ec2_client.describe_security_groups(Filters=[
{'Name': 'tag-value', 'Values': [name]}
])
if len(desc['SecurityGroups']) == 0:
return None
else:
return desc['SecurityGroups'][0]['GroupId']


def _vpc_exists(ec2_client, name):
desc = ec2_client.describe_vpcs(Filters=[
{'Name': 'tag-value', 'Values': [name]}
])
return len(desc['Vpcs']) > 0


def _get_route_table_id(ec2_client, vpc_id):
desc = ec2_client.describe_route_tables(Filters=[
{'Name': 'vpc-id', 'Values': [vpc_id]}
])
return desc['RouteTables'][0]['RouteTableId']


def create_vpc_with_name(ec2_client, name):
vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId']

subnet_id = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id)['Subnet']['SubnetId']

s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])

security_group_id = ec2_client.create_security_group(GroupName='TrainingJobTestGroup', Description='Testing',
VpcId=vpc_id)['GroupId']

ec2_client.create_tags(Resources=[vpc_id, subnet_id, security_group_id], Tags=[{'Key': 'Name', 'Value': name}])

return subnet_id, security_group_id


def get_or_create_subnet_and_security_group(ec2_client, name):
if _vpc_exists(ec2_client, name):
return _get_subnet_id_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name)
else:
return create_vpc_with_name(ec2_client, name)
1 change: 1 addition & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _create_train_job(version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
}


Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def test_unsupported_type_in_dict():
},
'stop_condition': {'MaxRuntimeInSeconds': 86400},
'tags': None,
'vpc_config': None
}

HYPERPARAMS = {'x': 1, 'y': 'hello'}
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _create_train_job(version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
}


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def _create_train_job(version):
'stop_condition': {
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None
'tags': None,
'vpc_config': None
}


Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def test_s3_input_all_arguments():
MAX_TIME = 3 * 60 * 60
JOB_NAME = 'jobname'
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
VPC_CONFIG = {'Subnets': 'subnet', 'SecurityGroupIds': 'sgi-blahblah'}

DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
'OutputDataConfig': {
Expand Down Expand Up @@ -259,7 +260,7 @@ def test_train_pack_to_request(sagemaker_session):

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
hyperparameters=None, stop_condition=stop_cond, tags=None)
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=None)

assert sagemaker_session.sagemaker_client.method_calls[0] == (
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
Expand Down Expand Up @@ -322,12 +323,13 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS)
hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, vpc_config=VPC_CONFIG)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]

assert actual_train_args['HyperParameters'] == hyperparameters
assert actual_train_args['Tags'] == TAGS
assert actual_train_args['VpcConfig'] == VPC_CONFIG


def test_transform_pack_to_request(sagemaker_session):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _create_train_job(tf_version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
}


Expand Down