1
+ # Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ from __future__ import absolute_import
14
+ from __future__ import division
15
+ from __future__ import print_function
16
+
17
+ import os
18
+
19
+ import tensorflow as tf
20
+ from tensorflow .python .keras .layers import InputLayer , Conv2D , Activation , MaxPooling2D , Dropout , Flatten , Dense
21
+ from tensorflow .python .keras .models import Sequential
22
+ from tensorflow .python .keras .optimizers import RMSprop
23
+ from tensorflow .python .saved_model .signature_constants import PREDICT_INPUTS
24
+
25
# Geometry of one input image: HEIGHT x WIDTH pixels with DEPTH channels.
HEIGHT = 32
WIDTH = 32
DEPTH = 3

# Number of label classes (size of the one-hot label and of the softmax output).
NUM_CLASSES = 10

# Training reads data_batch_1.bin .. data_batch_<NUM_DATA_BATCHES>.bin.
NUM_DATA_BATCHES = 5
# Presumably 10000 examples per data batch file -- confirm against the dataset.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = NUM_DATA_BATCHES * 10000

# Batch size used by both the training and the evaluation input functions.
BATCH_SIZE = 128
32
+
33
+
34
def keras_model_fn(hyperparameters):
    """Builds and compiles the convolutional Keras model used for training.

    The model will be transformed into a TensorFlow Estimator before training and it will be saved
    in a TensorFlow Serving SavedModel at the end of training.

    Args:
        hyperparameters: The hyperparameters passed to the SageMaker TrainingJob that runs your
            TensorFlow training script. The 'learning_rate' and 'decay' entries are read here to
            configure the RMSprop optimizer.

    Returns: A compiled Keras model
    """
    network = Sequential()

    # TensorFlow Serving default prediction input tensor name is PREDICT_INPUTS.
    # We must conform to this naming scheme.
    network.add(InputLayer(input_shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS))

    # Two convolutional stages that differ only in the number of filters:
    # conv('same') -> relu -> conv -> relu -> 2x2 max-pool -> dropout.
    for filters in (32, 64):
        network.add(Conv2D(filters, (3, 3), padding='same'))
        network.add(Activation('relu'))
        network.add(Conv2D(filters, (3, 3)))
        network.add(Activation('relu'))
        network.add(MaxPooling2D(pool_size=(2, 2)))
        network.add(Dropout(0.25))

    # Classifier head: one hidden dense layer, then a softmax over the classes.
    network.add(Flatten())
    network.add(Dense(512))
    network.add(Activation('relu'))
    network.add(Dropout(0.5))
    network.add(Dense(NUM_CLASSES))
    network.add(Activation('softmax'))

    # Re-wrap the Sequential graph as a functional tf.keras.Model; this is the object returned.
    compiled_model = tf.keras.Model(inputs=network.input, outputs=network.output)

    optimizer = RMSprop(lr=hyperparameters['learning_rate'],
                        decay=hyperparameters['decay'])
    compiled_model.compile(loss='categorical_crossentropy',
                           optimizer=optimizer,
                           metrics=['accuracy'])
    return compiled_model
78
+
79
+
80
def serving_input_fn(params):
    """Builds the input receiver used when the model is exported for serving.

    Args:
        params: Not used by this function.

    Returns:
        A tf.estimator.export.ServingInputReceiver whose features and receiver tensors are the
        same dict, mapping PREDICT_INPUTS to a float32 placeholder of shape
        [None, HEIGHT, WIDTH, DEPTH].
    """
    # The placeholder has the same per-example shape as the Keras model input.
    image_placeholder = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])

    # The key PREDICT_INPUTS matches the name given to the Keras InputLayer.
    receiver_tensors = {PREDICT_INPUTS: image_placeholder}
    return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors)
87
+
88
+
89
def train_input_fn(training_dir, params):
    """Returns the (features, labels) tensors for training.

    Args:
        training_dir: Directory passed through to `_input` as `data_dir`.
        params: Not used by this function.
    """
    train_mode = tf.estimator.ModeKeys.TRAIN
    return _input(train_mode, batch_size=BATCH_SIZE, data_dir=training_dir)
92
+
93
+
94
def eval_input_fn(training_dir, params):
    """Returns the (features, labels) tensors for evaluation.

    Args:
        training_dir: Directory passed through to `_input` as `data_dir`.
        params: Not used by this function.
    """
    eval_mode = tf.estimator.ModeKeys.EVAL
    return _input(eval_mode, batch_size=BATCH_SIZE, data_dir=training_dir)
97
+
98
+
99
def _input(mode, batch_size, data_dir):
    """Uses the tf.data input pipeline for CIFAR-10 dataset.

    Args:
        mode: Standard names for model modes (tf.estimator.ModeKeys).
        batch_size: The number of samples per batch of input requested.
        data_dir: Directory passed to `_filenames` to locate the record files.

    Returns:
        A tuple `({PREDICT_INPUTS: images}, labels)` holding the tensors produced by the
        dataset iterator for one batch.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    dataset = _record_dataset(_filenames(mode, data_dir))

    # For training repeat forever.
    if is_training:
        dataset = dataset.repeat()

    # Dataset transformations return a new dataset instead of modifying the
    # receiver, so the result of every prefetch() must be assigned back;
    # otherwise the call has no effect on the pipeline.
    dataset = dataset.map(_dataset_parser)
    dataset = dataset.prefetch(2 * batch_size)

    # For training, preprocess the image and shuffle.
    if is_training:
        dataset = dataset.map(_train_preprocess_fn)
        dataset = dataset.prefetch(2 * batch_size)

        # Ensure that the capacity is sufficiently large to provide good random
        # shuffling.
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Subtract off the mean and divide by the standard deviation of the pixels.
    dataset = dataset.map(
        lambda image, label: (tf.image.per_image_standardization(image), label))
    dataset = dataset.prefetch(2 * batch_size)

    # Batch results by up to batch_size, and then fetch the tuple from the
    # iterator.
    iterator = dataset.batch(batch_size).make_one_shot_iterator()
    images, labels = iterator.get_next()

    return {PREDICT_INPUTS: images}, labels
135
+
136
+
137
def _train_preprocess_fn(image, label):
    """Applies training-time augmentation to one image of layout [height, width, depth].

    The label is returned unchanged alongside the augmented image.
    """
    # Pad to (HEIGHT + 8) x (WIDTH + 8), i.e. four extra pixels on each side.
    padded = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)

    # Take a random [HEIGHT, WIDTH, DEPTH] crop out of the padded image.
    cropped = tf.random_crop(padded, [HEIGHT, WIDTH, DEPTH])

    # Randomly mirror the crop horizontally.
    flipped = tf.image.random_flip_left_right(cropped)

    return flipped, label
149
+
150
+
151
def _dataset_parser(value):
    """Parses one fixed-length CIFAR-10 record into an (image, one-hot label) pair."""
    # Record layout: a single label byte followed by the image bytes.
    label_bytes = 1
    image_bytes = HEIGHT * WIDTH * DEPTH
    record_bytes = label_bytes + image_bytes

    # Decode the raw string into a vector of uint8 values.
    record = tf.decode_raw(value, tf.uint8)

    # First byte is the label; widen it from uint8 to int32.
    label = tf.cast(record[0], tf.int32)

    # The rest of the record is the image, stored as [depth, height, width].
    image_flat = record[label_bytes:record_bytes]
    depth_major = tf.reshape(image_flat, [DEPTH, HEIGHT, WIDTH])

    # Move the channel axis last ([height, width, depth]) and cast to float32.
    height_width_depth = tf.transpose(depth_major, [1, 2, 0])
    image = tf.cast(height_width_depth, tf.float32)

    one_hot_label = tf.one_hot(label, NUM_CLASSES)
    return image, one_hot_label
175
+
176
+
177
def _record_dataset(filenames):
    """Creates a fixed-length-record Dataset that reads from `filenames`."""
    # Each record is HEIGHT * WIDTH * DEPTH image bytes plus one extra byte.
    bytes_per_record = HEIGHT * WIDTH * DEPTH + 1
    dataset = tf.data.FixedLengthRecordDataset(filenames, bytes_per_record)
    return dataset
181
+
182
+
183
def _filenames(mode, data_dir):
    """Returns a list of filenames based on 'mode'.

    Args:
        mode: tf.estimator.ModeKeys.TRAIN or tf.estimator.ModeKeys.EVAL.
        data_dir: Directory that is expected to contain the 'cifar-10-batches-bin' folder.

    Returns:
        For TRAIN, the paths data_batch_1.bin .. data_batch_<NUM_DATA_BATCHES>.bin;
        for EVAL, a single-element list with the path of test_batch.bin.

    Raises:
        AssertionError: If the 'cifar-10-batches-bin' folder does not exist under `data_dir`.
        ValueError: If `mode` is neither TRAIN nor EVAL.
    """
    data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed when Python runs with -O, which would silently skip this check.
    # AssertionError is kept so the exception type seen by callers is unchanged.
    if not os.path.exists(data_dir):
        raise AssertionError('Run cifar10_download_and_extract.py first '
                             'to download and extract the CIFAR-10 data.')

    if mode == tf.estimator.ModeKeys.TRAIN:
        return [
            os.path.join(data_dir, 'data_batch_%d.bin' % i)
            for i in range(1, NUM_DATA_BATCHES + 1)
        ]
    if mode == tf.estimator.ModeKeys.EVAL:
        return [os.path.join(data_dir, 'test_batch.bin')]
    raise ValueError('Invalid mode: %s' % mode)
0 commit comments