Commit 10b8abc
Author: Dan Choi
Parent: 8211a60

    wip tf byoc example

7 files changed, +721 -0 lines changed
Lines changed: 14 additions & 0 deletions
# https://docs.docker.com/compose/gettingstarted/#step-2-create-a-dockerfile
FROM tensorflow/tensorflow:1.8.0-py3

RUN apt-get update && apt-get install -y --no-install-recommends nginx curl

# Download TensorFlow Serving
RUN echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list
RUN curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -
RUN apt-get update && apt-get install -y tensorflow-model-server

ENV PATH="/opt/ml/code:${PATH}"

COPY src /opt/ml/code
WORKDIR /opt/ml/code
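
SageMaker runs the training image with the argument `train`, so the `src` directory placed on the PATH is expected to provide `train` (and, for hosting, `serve`) executables. Once built and pushed to ECR, the image can be launched as a training job. A minimal sketch, assuming the SageMaker Python SDK of this era; the ECR image URI, IAM role, and S3 path are placeholders, not part of this commit:

# Minimal sketch (not part of this commit): running the image as a SageMaker
# training job with the SageMaker Python SDK (1.x-era API). The ECR image URI,
# IAM role, and S3 data path below are placeholders.
import sagemaker
from sagemaker.estimator import Estimator

estimator = Estimator(
    image_name='<account>.dkr.ecr.<region>.amazonaws.com/tf-byoc-cifar10:latest',  # placeholder
    role='<sagemaker-execution-role-arn>',                                         # placeholder
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',
    sagemaker_session=sagemaker.Session())

# The 'training' channel is mounted at /opt/ml/input/data/training inside the
# container, which matches the training script's --data-dir default.
estimator.fit('s3://<bucket>/cifar10-tfrecords')  # placeholder S3 path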

advanced_functionality/tensorflow_bring_your_own/src/__init__.py

Whitespace-only changes.
Lines changed: 213 additions & 0 deletions
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import functools
import logging

import os

import resnet_model
import tensorflow as tf

INPUT_TENSOR_NAME = "inputs"
SIGNATURE_NAME = "serving_default"

HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
RESNET_SIZE = 32
BATCH_SIZE = 1

# Scale the learning rate linearly with the batch size. When the batch size is
# 128, the learning rate should be 0.05.
_INITIAL_LEARNING_RATE = 0.05 * BATCH_SIZE / 128
_MOMENTUM = 0.9

# We use a weight decay of 0.0002, which performs better than the 0.0001 that
# was originally suggested.
_WEIGHT_DECAY = 2e-4

_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE

logging.basicConfig(format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', level=logging.DEBUG)


def model_fn(features, labels, mode):
    """Model function for CIFAR-10."""
    inputs = features[INPUT_TENSOR_NAME]
    tf.summary.image('images', inputs, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)

    inputs = tf.reshape(inputs, [-1, HEIGHT, WIDTH, DEPTH])

    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=tf.one_hot(labels, 10))

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Create a tensor named learning_rate for logging purposes.
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)


def serving_input_fn():
    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)


def make_batch(data_dir, batch_size=2):
    dataset = tf.data.TFRecordDataset(data_dir).repeat()

    dataset = dataset.map(parser, num_parallel_calls=batch_size)

    min_queue_examples = int(45000 * 0.4)
    # Ensure that the capacity is sufficiently large to provide good random
    # shuffling.
    dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)

    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()

    image, label = iterator.get_next()

    return image, label


def parser(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    image.set_shape([DEPTH * HEIGHT * WIDTH])

    image = tf.cast(tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]), tf.float32)
    label = tf.cast(features['label'], tf.int32)

    return image, label


def train_input_fn(data_dir):
    with tf.device('/cpu:0'):
        train_data = os.path.join(data_dir, 'train.tfrecords')
        image_batch, label_batch = make_batch(train_data, BATCH_SIZE)
        return {INPUT_TENSOR_NAME: image_batch}, label_batch


def eval_input_fn(data_dir):
    with tf.device('/cpu:0'):
        eval_data = os.path.join(data_dir, 'eval.tfrecords')
        image_batch, label_batch = make_batch(eval_data, BATCH_SIZE)

        return {INPUT_TENSOR_NAME: image_batch}, label_batch


def train(model_dir, data_dir, train_steps):
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

    temp_input_fn = functools.partial(train_input_fn, data_dir)

    train_spec = tf.estimator.TrainSpec(temp_input_fn, max_steps=train_steps)

    exporter = tf.estimator.LatestExporter('Servo', serving_input_receiver_fn=serving_input_fn)
    temp_eval_fn = functools.partial(eval_input_fn, data_dir)
    eval_spec = tf.estimator.EvalSpec(temp_eval_fn, steps=1, exporters=exporter)

    tf.estimator.train_and_evaluate(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)


def main(model_dir, data_dir, train_steps):
    tf.logging.set_verbosity(tf.logging.INFO)
    train(model_dir, data_dir, train_steps)
    print('Training Done!')


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument(
        '--data-dir',
        default='/opt/ml/input/data/training',
        type=str,
        # required=True,
        help='The directory where the CIFAR-10 input data is stored.')
    args_parser.add_argument(
        '--model-dir',
        default='/opt/ml/model',
        type=str,
        # required=True,
        help='The directory where the model will be stored.')
    args_parser.add_argument(
        '--train-steps',
        type=int,
        default=100,
        # required=True,
        help='The number of steps to use for training.')
    args = args_parser.parse_args()
    main(**vars(args))
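
The Dockerfile installs both nginx and tensorflow-model-server, and SageMaker starts a hosting container by running the image with the argument `serve`. This wip commit does not yet include that entrypoint; a minimal sketch of what it could look like, assuming the SavedModel exported by the 'Servo' exporter is unpacked under /opt/ml/model and the model name matches the nginx configuration below:

#!/usr/bin/env python
# Hypothetical `serve` entrypoint (not part of this commit): starts
# tensorflow_model_server on port 8501 and nginx on port 8080. The model
# name and base path are assumptions based on the 'Servo' exporter above
# and the nginx configuration that follows.
from __future__ import print_function

import os
import subprocess
import sys


def main():
    # Serve the SavedModel that SageMaker unpacks into /opt/ml/model.
    tf_serving = subprocess.Popen([
        'tensorflow_model_server',
        '--rest_api_port=8501',
        '--model_name=cifar10_model',  # must match the proxy_pass URL in nginx.conf
        '--model_base_path=/opt/ml/model/export/Servo'])

    # Run nginx in the foreground so its lifetime is tied to this process.
    nginx = subprocess.Popen(
        ['nginx', '-c', '/opt/ml/code/nginx.conf', '-g', 'daemon off;'])

    # If either process exits, shut the other down and fail the container.
    pids = {tf_serving.pid, nginx.pid}
    while True:
        pid, _ = os.wait()
        if pid in pids:
            break

    tf_serving.terminate()
    nginx.terminate()
    print('A serving process exited; shutting down.')
    sys.exit(1)


if __name__ == '__main__':
    main()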
Lines changed: 23 additions & 0 deletions
events {
    # determines how many requests can be served simultaneously
    # https://www.digitalocean.com/community/tutorials/how-to-optimize-nginx-configuration
    # for more information
    worker_connections 2048;
}

http {
    server {
        # configures the server to listen on port 8080
        listen 8080 deferred;

        # forwards requests from SageMaker to TF Serving
        location /invocations {
            proxy_pass http://localhost:8501/v1/models/cifar10_model:predict;
        }

        # Used by SageMaker to confirm that the server is alive.
        location /ping {
            return 200 "OK";
        }
    }
}
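
With nginx proxying /invocations to the TensorFlow Serving REST API, a prediction request against the running container is a standard TF Serving predict call. A minimal sketch, assuming the container's port 8080 is published locally; the random input is purely illustrative:

# Minimal sketch (not part of this commit): invoking the container through
# the nginx /invocations route. Assumes the container is running with
# port 8080 published to localhost.
import json

import numpy as np
import requests

# One 32x32x3 image, matching the shape declared in serving_input_fn().
instance = np.random.rand(32, 32, 3).tolist()

# TF Serving's REST predict API expects a JSON body with an "instances" list.
response = requests.post(
    'http://localhost:8080/invocations',
    data=json.dumps({'instances': [instance]}),
    headers={'Content-Type': 'application/json'})

print(response.json())  # e.g. {'predictions': [{'classes': ..., 'probabilities': [...]}]}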
