From d4ff02acb920d0c9df9046eb652a4753a2f88cd1 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Tue, 14 Aug 2018 16:49:31 -0700 Subject: [PATCH 1/5] Add VPC config to estimator for training job creation The CreateTrainingJob API supports VPC configuration. This change adds VPC config as an optional argument to Estimator. --- CHANGELOG.rst | 1 + src/sagemaker/estimator.py | 12 +++++++++--- src/sagemaker/job.py | 11 ++++++++++- src/sagemaker/session.py | 5 ++++- tests/integ/test_tf.py | 11 ++++++++++- tests/unit/test_chainer.py | 1 + tests/unit/test_estimator.py | 1 + tests/unit/test_mxnet.py | 1 + tests/unit/test_pytorch.py | 3 ++- tests/unit/test_session.py | 6 ++++-- tests/unit/test_tf_estimator.py | 1 + 11 files changed, 44 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0d666397b1..792ccdcf4b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ CHANGELOG * bug-fix: Estimators: Fix serialization of single records * bug-fix: deprecate enable_cloudwatch_metrics from Framework Estimators. +* enhancement: Enable VPC config in training job creation 1.9.0 ===== diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 691f1f35d2..48c9771130 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -47,7 +47,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): def __init__(self, role, train_instance_count, train_instance_type, train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File', - output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None): + output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None, + subnets=None, security_group_ids=None): """Initialize an ``EstimatorBase`` instance. 
Args: @@ -100,6 +101,10 @@ def __init__(self, role, train_instance_count, train_instance_type, self.output_kms_key = output_kms_key self.latest_training_job = None + # VPC configurations + self.subnets = subnets + self.security_group_ids = security_group_ids + @abstractmethod def train_image(self): """Return the Docker image to use for training. @@ -399,8 +404,9 @@ def start_new(cls, estimator, inputs): estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode, input_config=config['input_config'], role=config['role'], job_name=estimator._current_job_name, output_config=config['output_config'], - resource_config=config['resource_config'], hyperparameters=hyperparameters, - stop_condition=config['stop_condition'], tags=estimator.tags) + resource_config=config['resource_config'], vpc_config=config['vpc_config'], + hyperparameters=hyperparameters, stop_condition=config['stop_condition'], + tags=estimator.tags) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 92e2314c4c..773350aa69 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -59,12 +59,14 @@ def _load_config(inputs, estimator): estimator.train_instance_type, estimator.train_volume_size) stop_condition = _Job._prepare_stop_condition(estimator.train_max_run) + vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids) return {'input_config': input_config, 'role': role, 'output_config': output_config, 'resource_config': resource_config, - 'stop_condition': stop_condition} + 'stop_condition': stop_condition, + 'vpc_config': vpc_config} @staticmethod def _format_inputs_to_input_config(inputs): @@ -143,6 +145,13 @@ def _prepare_resource_config(instance_count, instance_type, volume_size): 'InstanceType': instance_type, 'VolumeSizeInGB': volume_size} + @staticmethod + def _prepare_vpc_config(subnets, security_group_ids): + if subnets is None or 
security_group_ids is None: + return None + return {'Subnets': subnets, + 'SecurityGroupIds': security_group_ids} + @staticmethod def _prepare_stop_condition(max_run): return {'MaxRuntimeInSeconds': max_run} diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ec62e09ac0..bf3f7a346e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -202,7 +202,7 @@ def default_bucket(self): return self._default_bucket def train(self, image, input_mode, input_config, role, job_name, output_config, - resource_config, hyperparameters, stop_condition, tags): + resource_config, vpc_config, hyperparameters, stop_condition, tags): """Create an Amazon SageMaker training job. Args: @@ -259,6 +259,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, if tags is not None: train_request['Tags'] = tags + if vpc_config is not None: + train_request['VpcConfig'] = vpc_config + LOGGER.info('Creating training-job with name: {}'.format(job_name)) LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4))) self.sagemaker_client.create_training_job(**train_request) diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index dc9e886d62..992f6842a4 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -22,6 +22,8 @@ from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data') +VPC_SUBNETS = ['subnet-06b8537735fac3757'] +VPC_SECURITY_GROUP_IDS = ['sg-0a1008de6e1f384c3'] @pytest.mark.continuous_testing @@ -98,10 +100,17 @@ def test_failed_tf_training(sagemaker_session, tf_full_version): hyperparameters={'input_tensor_name': 'inputs'}, train_instance_count=1, train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + subnets=VPC_SUBNETS, + security_group_ids=VPC_SECURITY_GROUP_IDS) inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, 
key_prefix='integ-test-data/tf-failure') with pytest.raises(ValueError) as e: estimator.fit(inputs) assert 'This failure is expected' in str(e.value) + + job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=estimator.latest_training_job.name) + assert VPC_SUBNETS == job_desc['VpcConfig']['Subnets'] + assert VPC_SECURITY_GROUP_IDS == job_desc['VpcConfig']['SecurityGroupIds'] diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index e6ad5b4a4f..5a1537a38c 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -121,6 +121,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 574dbed629..1d3a251851 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -658,6 +658,7 @@ def test_unsupported_type_in_dict(): }, 'stop_condition': {'MaxRuntimeInSeconds': 86400}, 'tags': None, + 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} } HYPERPARAMS = {'x': 1, 'y': 'hello'} diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 1787fc00d3..640f840a02 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -95,6 +95,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 2cdafe2aef..43ffae5fbb 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -111,7 +111,8 @@ def _create_train_job(version): 'stop_condition': { 'MaxRuntimeInSeconds': 24 * 60 * 60 }, - 'tags': None + 'tags': None, + 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 5118289a54..1920dddec1 100644 --- 
a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -176,6 +176,7 @@ def test_s3_input_all_arguments(): MAX_TIME = 3 * 60 * 60 JOB_NAME = 'jobname' TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] +VPC_CONFIG = {'Subnets': 'subnet', 'SecurityGroupIds': 'sgi-blahblah'} DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { 'OutputDataConfig': { @@ -259,7 +260,7 @@ def test_train_pack_to_request(sagemaker_session): sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=None, stop_condition=stop_cond, tags=None) + hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=None) assert sagemaker_session.sagemaker_client.method_calls[0] == ( 'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS) @@ -322,12 +323,13 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS) + hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, vpc_config=VPC_CONFIG) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] assert actual_train_args['HyperParameters'] == hyperparameters assert actual_train_args['Tags'] == TAGS + assert actual_train_args['VpcConfig'] == VPC_CONFIG def test_transform_pack_to_request(sagemaker_session): diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 4861843ca0..dd7e7d86bd 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -103,6 +103,7 @@ def _create_train_job(tf_version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} } From 149ac7a03e2e95c1b02bffff8410378452979922 Mon 
Sep 17 00:00:00 2001 From: Rui Wang Date: Wed, 15 Aug 2018 13:46:48 -0700 Subject: [PATCH 2/5] Fix unit tests --- tests/unit/test_chainer.py | 2 +- tests/unit/test_estimator.py | 2 +- tests/unit/test_mxnet.py | 2 +- tests/unit/test_pytorch.py | 2 +- tests/unit/test_tf_estimator.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 5a1537a38c..f7e8e71d8e 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -121,7 +121,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, - 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} + 'vpc_config': None } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1d3a251851..d173f013fb 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -658,7 +658,7 @@ def test_unsupported_type_in_dict(): }, 'stop_condition': {'MaxRuntimeInSeconds': 86400}, 'tags': None, - 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} + 'vpc_config': None } HYPERPARAMS = {'x': 1, 'y': 'hello'} diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 640f840a02..b6b03c844f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -95,7 +95,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, - 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} + 'vpc_config': None } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 43ffae5fbb..35611cd2fe 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -112,7 +112,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, - 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} + 'vpc_config': None } diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index dd7e7d86bd..681d45a0b1 100644 --- a/tests/unit/test_tf_estimator.py +++ 
b/tests/unit/test_tf_estimator.py @@ -103,7 +103,7 @@ def _create_train_job(tf_version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, - 'vpc_config': {'SecurityGroupIds': None, 'Subnets': None} + 'vpc_config': None } From ff097db7c5a4261fe4b214691a73f62c94a1b5be Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Wed, 15 Aug 2018 14:26:17 -0700 Subject: [PATCH 3/5] Fix docstring --- src/sagemaker/estimator.py | 3 +++ src/sagemaker/session.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 48c9771130..9febff5240 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -79,6 +79,9 @@ def __init__(self, role, train_instance_count, train_instance_type, using the default AWS configuration chain. tags (list[dict]): List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config. + security_group_ids (list[str]): List of security group ids. If not specified training job will be created + without VPC config. """ self.role = role self.train_instance_count = train_instance_count diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index bf3f7a346e..ea1c1c50f6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -228,6 +228,13 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, * instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. + vpc_config (dict): Contains values for VpcConfig: + + * subnets (list[str]): List of subnet ids. + The key in vpc_config is 'Subnets'. + * security_group_ids (list[str]): List of security group ids. + The key in vpc_config is 'SecurityGroupIds'. + hyperparameters (dict): Hyperparameters for model training. 
The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for keys and values, but ``str()`` will be called to convert them before training. From f68b6febacdbe367ba7e94caff7f384221af254f Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Tue, 21 Aug 2018 13:01:29 -0700 Subject: [PATCH 4/5] Create VPC if it does not exist --- tests/integ/test_tf.py | 14 ++++---- tests/integ/vpc_utils.py | 69 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 6 deletions(-) create mode 100644 tests/integ/vpc_utils.py diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index 992f6842a4..329436ccd8 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -20,10 +20,10 @@ from sagemaker.tensorflow import TensorFlow from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout +from tests.integ.vpc_utils import get_or_create_subnet_and_security_group DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data') -VPC_SUBNETS = ['subnet-06b8537735fac3757'] -VPC_SECURITY_GROUP_IDS = ['sg-0a1008de6e1f384c3'] +VPC_NAME = 'training-job-test' @pytest.mark.continuous_testing @@ -92,6 +92,8 @@ def test_tf_async(sagemaker_session): def test_failed_tf_training(sagemaker_session, tf_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py') + ec2_client = sagemaker_session.boto_session.client('ec2') + subnet, security_group_id = get_or_create_subnet_and_security_group(ec2_client, VPC_NAME) estimator = TensorFlow(entry_point=script_path, role='SageMakerRole', framework_version=tf_full_version, @@ -101,8 +103,8 @@ def test_failed_tf_training(sagemaker_session, tf_full_version): train_instance_count=1, train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session, - subnets=VPC_SUBNETS, - security_group_ids=VPC_SECURITY_GROUP_IDS) + 
subnets=[subnet], + security_group_ids=[security_group_id]) inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure') @@ -112,5 +114,5 @@ def test_failed_tf_training(sagemaker_session, tf_full_version): job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=estimator.latest_training_job.name) - assert VPC_SUBNETS == job_desc['VpcConfig']['Subnets'] - assert VPC_SECURITY_GROUP_IDS == job_desc['VpcConfig']['SecurityGroupIds'] + assert [subnet] == job_desc['VpcConfig']['Subnets'] + assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds'] diff --git a/tests/integ/vpc_utils.py b/tests/integ/vpc_utils.py new file mode 100644 index 0000000000..fb951ad00a --- /dev/null +++ b/tests/integ/vpc_utils.py @@ -0,0 +1,69 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+ +def _get_subnet_id_by_name(ec2_client, name): + desc = ec2_client.describe_subnets(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + if len(desc['Subnets']) == 0: + return None + else: + return desc['Subnets'][0]['SubnetId'] + + +def _get_security_id_by_name(ec2_client, name): + desc = ec2_client.describe_security_groups(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + if len(desc['SecurityGroups']) == 0: + return None + else: + return desc['SecurityGroups'][0]['GroupId'] + + +def _vpc_exists(ec2_client, name): + desc = ec2_client.describe_vpcs(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + return len(desc['Vpcs']) > 0 + + +def _get_route_table_id(ec2_client, vpc_id): + desc = ec2_client.describe_route_tables(Filters=[ + {'Name': 'vpc-id', 'Values': [vpc_id]} + ]) + return desc['RouteTables'][0]['RouteTableId'] + + +def create_vpc_with_name(ec2_client, name): + vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId'] + + subnet_id = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id)['Subnet']['SubnetId'] + + s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0] + ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service, + RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)]) + + security_group_id = ec2_client.create_security_group(GroupName='TrainingJobTestGroup', Description='Testing', + VpcId=vpc_id)['GroupId'] + + ec2_client.create_tags(Resources=[vpc_id, subnet_id, security_group_id], Tags=[{'Key': 'Name', 'Value': name}]) + + return subnet_id, security_group_id + + +def get_or_create_subnet_and_security_group(ec2_client, name): + if _vpc_exists(ec2_client, name): + return _get_subnet_id_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name) + else: + return create_vpc_with_name(ec2_client, name) From ba9cac3785c15dd17c814f33afd5cd6967f268e4 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Tue, 21 Aug 2018 13:25:13 -0700 
Subject: [PATCH 5/5] Add absolute_import --- tests/integ/vpc_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integ/vpc_utils.py b/tests/integ/vpc_utils.py index fb951ad00a..480c82f6a2 100644 --- a/tests/integ/vpc_utils.py +++ b/tests/integ/vpc_utils.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +from __future__ import absolute_import + def _get_subnet_id_by_name(ec2_client, name): desc = ec2_client.describe_subnets(Filters=[