Skip to content

Commit f68b6fe

Browse files
committed
Create VPC if does not exist
1 parent ff097db commit f68b6fe

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

tests/integ/test_tf.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from sagemaker.tensorflow import TensorFlow
2121
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2222
from tests.integ.timeout import timeout_and_delete_endpoint_by_name, timeout
23+
from tests.integ.vpc_utils import get_or_create_subnet_and_security_group
2324

2425
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
25-
VPC_SUBNETS = ['subnet-06b8537735fac3757']
26-
VPC_SECURITY_GROUP_IDS = ['sg-0a1008de6e1f384c3']
26+
VPC_NAME = 'training-job-test'
2727

2828

2929
@pytest.mark.continuous_testing
@@ -92,6 +92,8 @@ def test_tf_async(sagemaker_session):
9292
def test_failed_tf_training(sagemaker_session, tf_full_version):
9393
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
9494
script_path = os.path.join(DATA_DIR, 'iris', 'failure_script.py')
95+
ec2_client = sagemaker_session.boto_session.client('ec2')
96+
subnet, security_group_id = get_or_create_subnet_and_security_group(ec2_client, VPC_NAME)
9597
estimator = TensorFlow(entry_point=script_path,
9698
role='SageMakerRole',
9799
framework_version=tf_full_version,
@@ -101,8 +103,8 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
101103
train_instance_count=1,
102104
train_instance_type='ml.c4.xlarge',
103105
sagemaker_session=sagemaker_session,
104-
subnets=VPC_SUBNETS,
105-
security_group_ids=VPC_SECURITY_GROUP_IDS)
106+
subnets=[subnet],
107+
security_group_ids=[security_group_id])
106108

107109
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
108110

@@ -112,5 +114,5 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
112114

113115
job_desc = estimator.sagemaker_session.sagemaker_client.describe_training_job(
114116
TrainingJobName=estimator.latest_training_job.name)
115-
assert VPC_SUBNETS == job_desc['VpcConfig']['Subnets']
116-
assert VPC_SECURITY_GROUP_IDS == job_desc['VpcConfig']['SecurityGroupIds']
117+
assert [subnet] == job_desc['VpcConfig']['Subnets']
118+
assert [security_group_id] == job_desc['VpcConfig']['SecurityGroupIds']

tests/integ/vpc_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
def _get_subnet_id_by_name(ec2_client, name):
15+
desc = ec2_client.describe_subnets(Filters=[
16+
{'Name': 'tag-value', 'Values': [name]}
17+
])
18+
if len(desc['Subnets']) == 0:
19+
return None
20+
else:
21+
return desc['Subnets'][0]['SubnetId']
22+
23+
24+
def _get_security_id_by_name(ec2_client, name):
25+
desc = ec2_client.describe_security_groups(Filters=[
26+
{'Name': 'tag-value', 'Values': [name]}
27+
])
28+
if len(desc['SecurityGroups']) == 0:
29+
return None
30+
else:
31+
return desc['SecurityGroups'][0]['GroupId']
32+
33+
34+
def _vpc_exists(ec2_client, name):
35+
desc = ec2_client.describe_vpcs(Filters=[
36+
{'Name': 'tag-value', 'Values': [name]}
37+
])
38+
return len(desc['Vpcs']) > 0
39+
40+
41+
def _get_route_table_id(ec2_client, vpc_id):
42+
desc = ec2_client.describe_route_tables(Filters=[
43+
{'Name': 'vpc-id', 'Values': [vpc_id]}
44+
])
45+
return desc['RouteTables'][0]['RouteTableId']
46+
47+
48+
def create_vpc_with_name(ec2_client, name):
49+
vpc_id = ec2_client.create_vpc(CidrBlock='10.0.0.0/16')['Vpc']['VpcId']
50+
51+
subnet_id = ec2_client.create_subnet(CidrBlock='10.0.0.0/24', VpcId=vpc_id)['Subnet']['SubnetId']
52+
53+
s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
54+
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
55+
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
56+
57+
security_group_id = ec2_client.create_security_group(GroupName='TrainingJobTestGroup', Description='Testing',
58+
VpcId=vpc_id)['GroupId']
59+
60+
ec2_client.create_tags(Resources=[vpc_id, subnet_id, security_group_id], Tags=[{'Key': 'Name', 'Value': name}])
61+
62+
return subnet_id, security_group_id
63+
64+
65+
def get_or_create_subnet_and_security_group(ec2_client, name):
66+
if _vpc_exists(ec2_client, name):
67+
return _get_subnet_id_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name)
68+
else:
69+
return create_vpc_with_name(ec2_client, name)

0 commit comments

Comments
 (0)