diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0d666397b1..792ccdcf4b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,7 @@ CHANGELOG * bug-fix: Estimators: Fix serialization of single records * bug-fix: deprecate enable_cloudwatch_metrics from Framework Estimators. +* enhancement: Enable VPC config in training job creation 1.9.0 ===== diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 691f1f35d2..9febff5240 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -47,7 +47,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): def __init__(self, role, train_instance_count, train_instance_type, train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File', - output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None): + output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None, + subnets=None, security_group_ids=None): """Initialize an ``EstimatorBase`` instance. Args: @@ -78,6 +79,9 @@ def __init__(self, role, train_instance_count, train_instance_type, using the default AWS configuration chain. tags (list[dict]): List of tags for labeling a training job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. + subnets (list[str]): List of subnet ids. If not specified training job will be created without VPC config. + security_group_ids (list[str]): List of security group ids. If not specified training job will be created + without VPC config. """ self.role = role self.train_instance_count = train_instance_count @@ -100,6 +104,10 @@ def __init__(self, role, train_instance_count, train_instance_type, self.output_kms_key = output_kms_key self.latest_training_job = None + # VPC configurations + self.subnets = subnets + self.security_group_ids = security_group_ids + @abstractmethod def train_image(self): """Return the Docker image to use for training. @@ -399,8 +407,9 @@ def start_new(cls, estimator, inputs): estimator.sagemaker_session.train(image=estimator.train_image(), input_mode=estimator.input_mode, input_config=config['input_config'], role=config['role'], job_name=estimator._current_job_name, output_config=config['output_config'], - resource_config=config['resource_config'], hyperparameters=hyperparameters, - stop_condition=config['stop_condition'], tags=estimator.tags) + resource_config=config['resource_config'], vpc_config=config['vpc_config'], + hyperparameters=hyperparameters, stop_condition=config['stop_condition'], + tags=estimator.tags) return cls(estimator.sagemaker_session, estimator._current_job_name) diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 92e2314c4c..773350aa69 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -59,12 +59,14 @@ def _load_config(inputs, estimator): estimator.train_instance_type, estimator.train_volume_size) stop_condition = _Job._prepare_stop_condition(estimator.train_max_run) + vpc_config = _Job._prepare_vpc_config(estimator.subnets, estimator.security_group_ids) return {'input_config': input_config, 'role': role, 'output_config': output_config, 'resource_config': resource_config, - 'stop_condition': stop_condition} + 'stop_condition': stop_condition, + 'vpc_config': vpc_config} @staticmethod def _format_inputs_to_input_config(inputs): @@ -143,6 +145,13 @@ def _prepare_resource_config(instance_count, instance_type, volume_size): 'InstanceType': instance_type, 'VolumeSizeInGB': volume_size} + @staticmethod + def _prepare_vpc_config(subnets, security_group_ids): + if subnets is None or security_group_ids is None: + return None + return {'Subnets': subnets, + 'SecurityGroupIds': security_group_ids} + @staticmethod def _prepare_stop_condition(max_run): return {'MaxRuntimeInSeconds': max_run} diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ec62e09ac0..ea1c1c50f6 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -202,7 +202,7 @@ def default_bucket(self): return self._default_bucket def train(self, image, input_mode, input_config, role, job_name, output_config, - resource_config, hyperparameters, stop_condition, tags): + resource_config, vpc_config, hyperparameters, stop_condition, tags): """Create an Amazon SageMaker training job. Args: @@ -228,6 +228,13 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, * instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'. + vpc_config (dict): Contains values for VpcConfig: + + * subnets (list[str]): List of subnet ids. + The key in vpc_config is 'Subnets'. + * security_group_ids (list[str]): List of security group ids. + The key in vpc_config is 'SecurityGroupIds'. + hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for keys and values, but ``str()`` will be called to convert them before training. @@ -259,6 +266,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config, if tags is not None: train_request['Tags'] = tags + if vpc_config is not None: + train_request['VpcConfig'] = vpc_config + LOGGER.info('Creating training-job with name: {}'.format(job_name)) LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4))) self.sagemaker_client.create_training_job(**train_request) diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index dc9e886d62..329436ccd8 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -20,8 +20,10 @@ from sagemaker.tensorflow import TensorFlow from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout +from tests.integ.vpc_utils import get_or_create_subnet_and_security_group DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data') +VPC_NAME = 'training-job-test' @pytest.mark.continuous_testing @@ -90,6 +92,8 @@ def test_tf_async(sagemaker_session): def test_failed_tf_training(sagemaker_session, tf_full_version): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py') + ec2_client = sagemaker_session.boto_session.client('ec2') + subnet, security_group_id = get_or_create_subnet_and_security_group(ec2_client, VPC_NAME) estimator = TensorFlow(entry_point=script_path, role='SageMakerRole', framework_version=tf_full_version, @@ -98,10 +102,17 @@ def test_failed_tf_training(sagemaker_session, tf_full_version): hyperparameters={'input_tensor_name': 'inputs'}, train_instance_count=1, train_instance_type='ml.c4.xlarge', - sagemaker_session=sagemaker_session) + sagemaker_session=sagemaker_session, + subnets=[subnet], + security_group_ids=[security_group_id]) inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure') with pytest.raises(ValueError) as e: estimator.fit(inputs) assert 'This failure is expected' in str(e.value) + + job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=estimator.latest_training_job.name) + assert [subnet] == job_desc['VpcConfig']['Subnets'] + assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds'] diff --git a/tests/integ/vpc_utils.py b/tests/integ/vpc_utils.py new file mode 100644 index 0000000000..480c82f6a2 --- /dev/null +++ b/tests/integ/vpc_utils.py @@ -0,0 +1,71 @@ +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +def _get_subnet_id_by_name(ec2_client, name): + desc = ec2_client.describe_subnets(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + if len(desc['Subnets']) == 0: + return None + else: + return desc['Subnets'][0]['SubnetId'] + + +def _get_security_id_by_name(ec2_client, name): + desc = ec2_client.describe_security_groups(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + if len(desc['SecurityGroups']) == 0: + return None + else: + return desc['SecurityGroups'][0]['GroupId'] + + +def _vpc_exists(ec2_client, name): + desc = ec2_client.describe_vpcs(Filters=[ + {'Name': 'tag-value', 'Values': [name]} + ]) + return len(desc['Vpcs']) > 0 + + +def _get_route_table_id(ec2_client, vpc_id): + desc = ec2_client.describe_route_tables(Filters=[ + {'Name': 'vpc-id', 'Values': [vpc_id]} + ]) + return desc['RouteTables'][0]['RouteTableId'] + + +def create_vpc_with_name(ec2_client, name): + vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId'] + + subnet_id = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id)['Subnet']['SubnetId'] + + s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0] + ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service, + RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)]) + + security_group_id = ec2_client.create_security_group(GroupName='TrainingJobTestGroup', Description='Testing', + VpcId=vpc_id)['GroupId'] + + ec2_client.create_tags(Resources=[vpc_id, subnet_id, security_group_id], Tags=[{'Key': 'Name', 'Value': name}]) + + return subnet_id, security_group_id + + +def get_or_create_subnet_and_security_group(ec2_client, name): + if _vpc_exists(ec2_client, name): + return _get_subnet_id_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name) + else: + return create_vpc_with_name(ec2_client, name) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index e6ad5b4a4f..f7e8e71d8e 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -121,6 +121,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': None } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 574dbed629..d173f013fb 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -658,6 +658,7 @@ def test_unsupported_type_in_dict(): }, 'stop_condition': {'MaxRuntimeInSeconds': 86400}, 'tags': None, + 'vpc_config': None } HYPERPARAMS = {'x': 1, 'y': 'hello'} diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 1787fc00d3..b6b03c844f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -95,6 +95,7 @@ def _create_train_job(version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': None } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 2cdafe2aef..35611cd2fe 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -111,7 +111,8 @@ def _create_train_job(version): 'stop_condition': { 'MaxRuntimeInSeconds': 24 * 60 * 60 }, - 'tags': None + 'tags': None, + 'vpc_config': None } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 5118289a54..1920dddec1 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -176,6 +176,7 @@ def test_s3_input_all_arguments(): MAX_TIME = 3 * 60 * 60 JOB_NAME = 'jobname' TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}] +VPC_CONFIG = {'Subnets': 'subnet', 'SecurityGroupIds': 'sgi-blahblah'} DEFAULT_EXPECTED_TRAIN_JOB_ARGS = { 'OutputDataConfig': { @@ -259,7 +260,7 @@ def test_train_pack_to_request(sagemaker_session): sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=None, stop_condition=stop_cond, tags=None) + hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=None) assert sagemaker_session.sagemaker_client.method_calls[0] == ( 'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS) @@ -322,12 +323,13 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE, job_name=JOB_NAME, output_config=out_config, resource_config=resource_config, - hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS) + hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS, vpc_config=VPC_CONFIG) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] assert actual_train_args['HyperParameters'] == hyperparameters assert actual_train_args['Tags'] == TAGS + assert actual_train_args['VpcConfig'] == VPC_CONFIG def test_transform_pack_to_request(sagemaker_session): diff --git a/tests/unit/test_tf_estimator.py b/tests/unit/test_tf_estimator.py index 4861843ca0..681d45a0b1 100644 --- a/tests/unit/test_tf_estimator.py +++ b/tests/unit/test_tf_estimator.py @@ -103,6 +103,7 @@ def _create_train_job(tf_version): 'MaxRuntimeInSeconds': 24 * 60 * 60 }, 'tags': None, + 'vpc_config': None }