Skip to content

Commit 63cc7fa

Browse files
authored
Merge pull request aws#352 from ChoiByungWook/tf_byoc
Tensorflow Bring Your Own Container
2 parents 94342d7 + f54cccc commit 63cc7fa

File tree

13 files changed

+1699
-0
lines changed

13 files changed

+1699
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
# For more information on creating a Dockerfile
15+
# https://docs.docker.com/compose/gettingstarted/#step-2-create-a-dockerfile
16+
FROM tensorflow/tensorflow:1.8.0-py3

# nginx listens on the SageMaker port and forwards to TensorFlow Serving;
# curl is needed below to fetch the TensorFlow Serving apt signing key.
RUN apt-get update && apt-get install -y --no-install-recommends nginx curl

# Download TensorFlow Serving
# https://www.tensorflow.org/serving/setup#installing_the_modelserver
RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list
# -f makes curl fail on an HTTP error instead of piping an error page into apt-key.
RUN curl -fsSL https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
# -y is required: docker build is non-interactive, so a confirmation prompt
# (shown whenever extra dependencies must be installed) would abort the install.
RUN apt-get update && apt-get install -y tensorflow-model-server

# Make the train and serve executables in /opt/ml/code resolvable by name.
ENV PATH="/opt/ml/code:${PATH}"

# /opt/ml and all subdirectories are utilized by SageMaker, we use the /code subdirectory to store our user code.
COPY /cifar10 /opt/ml/code
WORKDIR /opt/ml/code
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#!/usr/bin/env bash

# This script shows how to build the Docker image and push it to ECR to be ready for use
# by SageMaker.

# The argument to this script is the image name. This will be used as the image on the local
# machine and combined with the account and region to form the repository name for ECR.
image=$1

if [ -z "${image}" ]
then
    echo "Usage: $0 <image-name>"
    exit 1
fi

# The container entry points must be executable inside the image.
chmod +x cifar10/train
chmod +x cifar10/serve

# Get the account number associated with the current IAM credentials
account=$(aws sts get-caller-identity --query Account --output text)

if [ $? -ne 0 ]
then
    exit 255
fi

# Get the region defined in the current configuration (default to us-west-2 if none defined)
region=$(aws configure get region)
region=${region:-us-west-2}

registry="${account}.dkr.ecr.${region}.amazonaws.com"
fullname="${registry}/${image}:latest"

# If the repository doesn't exist in ECR, create it.
# The region is passed explicitly so the repository is looked up and created in
# the same region the image is pushed to, even when the us-west-2 fallback is used.
if ! aws ecr describe-repositories --region "${region}" --repository-names "${image}" > /dev/null 2>&1
then
    aws ecr create-repository --region "${region}" --repository-name "${image}" > /dev/null
fi

# Log Docker in to ECR. `get-login-password` is the supported command (AWS CLI v2
# dropped `get-login`); fall back to the legacy command on CLI versions that
# predate `get-login-password`.
if password=$(aws ecr get-login-password --region "${region}" 2> /dev/null)
then
    printf '%s' "${password}" | docker login --username AWS --password-stdin "${registry}"
else
    $(aws ecr get-login --region "${region}" --no-include-email)
fi

# Build the docker image locally with the image name and then push it to ECR
# with the full name. The steps are chained so that a failed build is never
# followed by tagging and pushing a stale image.
docker build -t "${image}" . \
    && docker tag "${image}" "${fullname}" \
    && docker push "${fullname}"

advanced_functionality/tensorflow_bring_your_own/container/cifar10/__init__.py

Whitespace-only changes.
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
from __future__ import division
15+
from __future__ import print_function
16+
17+
import argparse
18+
import functools
19+
import os
20+
21+
import tensorflow as tf
22+
23+
import resnet_model
24+
25+
# Key of the image tensor in the features dict, and name of the exported
# serving signature.
INPUT_TENSOR_NAME = "inputs"
SIGNATURE_NAME = "serving_default"

# CIFAR-10 image geometry and label space.
HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10

# Dataset sizes: the training set is NUM_DATA_BATCHES batches of 10000 examples.
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000

# Network size and the batch size used by both input functions.
RESNET_SIZE = 32
BATCH_SIZE = 1

# The learning rate scales linearly with the batch size, anchored at 0.05 for
# a batch size of 128 (so with BATCH_SIZE = 1 it is 0.05 / 128).
_INITIAL_LEARNING_RATE = 0.05 * BATCH_SIZE / 128
_MOMENTUM = 0.9

# A weight decay of 0.0002 is used because it performs better than the
# originally suggested 0.0001.
_WEIGHT_DECAY = 2e-4

# Number of optimizer steps in one pass over the training set.
_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
48+
49+
50+
def model_fn(features, labels, mode):
    """
    Model function for CIFAR-10.
    For more information: https://www.tensorflow.org/guide/custom_estimators#write_a_model_function

    Args:
        features: dict holding the image batch under the INPUT_TENSOR_NAME key.
        labels: integer class indices for the batch; not read in PREDICT mode.
        mode: one of the tf.estimator.ModeKeys values.

    Returns:
        A tf.estimator.EstimatorSpec. In PREDICT mode it carries the predictions
        and the export signature; otherwise it carries predictions, the loss and,
        in TRAIN mode only, the train op.
    """
    inputs = features[INPUT_TENSOR_NAME]
    tf.summary.image('images', inputs, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)

    inputs = tf.reshape(inputs, [-1, HEIGHT, WIDTH, DEPTH])

    # The second argument tells the network whether it is being trained.
    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    # The one-hot depth comes from NUM_CLASSES so it always matches the width of
    # the logits produced by the network above.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=tf.one_hot(labels, NUM_CLASSES))

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        # Evaluation only needs the loss; there is nothing to optimize.
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
116+
117+
118+
def serving_input_fn():
    """
    Serving input function for CIFAR-10. Specifies the input format the caller of predict() will have to provide.
    For more information: https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel

    Returns:
        A tf.estimator.export.ServingInputReceiver whose features and receiver
        tensors are the same float32 placeholder, keyed by INPUT_TENSOR_NAME.
    """
    # The image dimensions come from the module constants so the exported
    # signature stays in step with the reshape performed in model_fn.
    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, HEIGHT, WIDTH, DEPTH])}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
125+
126+
127+
def make_batch(data_dir, batch_size=2):
    """
    Build an endlessly repeating, shuffled, batched input pipeline.

    Args:
        data_dir: passed unchanged to tf.data.TFRecordDataset; the callers in
            this file pass the path of a .tfrecords file.
        batch_size: number of examples per batch; also used as the number of
            parallel calls when parsing records.

    Returns:
        The (image, label) pair produced by the dataset's one-shot iterator.
    """
    # Keep the shuffle buffer large enough to provide good random shuffling.
    shuffle_buffer = int(45000 * 0.4) + 3 * batch_size

    records = tf.data.TFRecordDataset(data_dir).repeat()
    batches = (
        records
        .map(parser, num_parallel_calls=batch_size)
        .shuffle(buffer_size=shuffle_buffer)
        .batch(batch_size)
    )

    image, label = batches.make_one_shot_iterator().get_next()
    return image, label
143+
144+
145+
def parser(serialized_example):
    """
    Decode one serialized example into an image tensor and its label.

    The 'image' bytes are decoded as uint8, reshaped to
    [DEPTH, HEIGHT, WIDTH], transposed to [HEIGHT, WIDTH, DEPTH] and cast to
    float32. The 'label' value is cast from int64 to int32.

    Returns:
        An (image, label) tuple of tensors.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    example = tf.parse_single_example(serialized_example, features=feature_spec)

    flat_pixels = tf.decode_raw(example['image'], tf.uint8)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])

    # Stored channel-first; move the channel axis last for the network.
    channels_first = tf.reshape(flat_pixels, [DEPTH, HEIGHT, WIDTH])
    channels_last = tf.transpose(channels_first, [1, 2, 0])
    image = tf.cast(channels_last, tf.float32)

    label = tf.cast(example['label'], tf.int32)
    return image, label
159+
160+
161+
def train_input_fn(data_dir):
    """
    Input function for training: batches read from 'train.tfrecords' in data_dir.

    Returns:
        A (features, labels) pair, where features maps INPUT_TENSOR_NAME to the
        image batch.
    """
    records_path = os.path.join(data_dir, 'train.tfrecords')
    # Pin the input pipeline ops to the CPU.
    with tf.device('/cpu:0'):
        images, labels = make_batch(records_path, BATCH_SIZE)
    return {INPUT_TENSOR_NAME: images}, labels
166+
167+
168+
def eval_input_fn(data_dir):
    """
    Input function for evaluation: batches read from 'eval.tfrecords' in data_dir.

    Returns:
        A (features, labels) pair, where features maps INPUT_TENSOR_NAME to the
        image batch.
    """
    records_path = os.path.join(data_dir, 'eval.tfrecords')
    # Pin the input pipeline ops to the CPU.
    with tf.device('/cpu:0'):
        images, labels = make_batch(records_path, BATCH_SIZE)
    return {INPUT_TENSOR_NAME: images}, labels
174+
175+
176+
def train(model_dir, data_dir, train_steps):
    """
    Train and evaluate the CIFAR-10 estimator, exporting it for serving.

    Args:
        model_dir: directory handed to the Estimator as its model_dir.
        data_dir: directory the train and eval input functions read from.
        train_steps: value used as max_steps of the TrainSpec.
    """
    classifier = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

    # Training side: the input function bound to the data directory.
    train_spec = tf.estimator.TrainSpec(
        functools.partial(train_input_fn, data_dir),
        max_steps=train_steps)

    # Evaluation side: a single eval step, followed by an export named 'Servo'.
    servo_exporter = tf.estimator.LatestExporter('Servo', serving_input_receiver_fn=serving_input_fn)
    eval_spec = tf.estimator.EvalSpec(
        functools.partial(eval_input_fn, data_dir),
        steps=1,
        exporters=servo_exporter)

    tf.estimator.train_and_evaluate(estimator=classifier, train_spec=train_spec, eval_spec=eval_spec)
188+
189+
190+
def main(model_dir, data_dir, train_steps):
    """Turn on INFO-level TensorFlow logging, then run train() with the given arguments."""
    tf.logging.set_verbosity(tf.logging.INFO)
    train(model_dir=model_dir, data_dir=data_dir, train_steps=train_steps)
193+
194+
195+
if __name__ == '__main__':
    # Command-line entry point: every option has a default, so the script can
    # be launched with no arguments at all.
    cli = argparse.ArgumentParser()

    # For more information:
    # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    cli.add_argument(
        '--data-dir',
        type=str,
        default='/opt/ml/input/data/training',
        help='The directory where the CIFAR-10 input data is stored. Default: /opt/ml/input/data/training. This '
             'directory corresponds to the SageMaker channel named \'training\', which was specified when creating '
             'our training job on SageMaker')

    # For more information:
    # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
    cli.add_argument(
        '--model-dir',
        type=str,
        default='/opt/ml/model',
        help='The directory where the model will be stored. Default: /opt/ml/model. This directory should contain all '
             'final model artifacts as Amazon SageMaker copies all data within this directory as a single object in '
             'compressed tar format.')

    cli.add_argument(
        '--train-steps',
        type=int,
        default=100,
        help='The number of steps to use for training.')

    parsed = cli.parse_args()
    main(
        model_dir=parsed.model_dir,
        data_dir=parsed.data_dir,
        train_steps=parsed.train_steps)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
events {
    # Upper bound on the connections a worker process keeps open at once,
    # which in turn limits how many requests can be served simultaneously.
    # Background reading:
    # https://www.digitalocean.com/community/tutorials/how-to-optimize-nginx-configuration
    worker_connections 2048;
}

http {
    server {
        # Amazon SageMaker sends inference requests to port 8080, so that is
        # the port this server listens on.
        # For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
        listen 8080 deferred;

        # Health check used by SageMaker to confirm the server is alive.
        # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
        location /ping {
            return 200 "OK";
        }

        # Forward SageMaker inference requests to the TF Serving predict
        # endpoint for cifar10_model on local port 8501.
        location /invocations {
            proxy_pass http://localhost:8501/v1/models/cifar10_model:predict;
        }
    }
}

0 commit comments

Comments
 (0)