NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
BATCH_SIZE = 128
+ INPUT_TENSOR_NAME = PREDICT_INPUTS


def keras_model_fn(hyperparameters):
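The new INPUT_TENSOR_NAME constant is just an alias for PREDICT_INPUTS, the key that (per the comment removed further down) has to match the name of the Keras model's InputLayer. A minimal sketch of that contract, assuming CIFAR-10 shapes and an illustrative layer name that is not taken from this file:

import tensorflow as tf

HEIGHT, WIDTH, DEPTH, NUM_CLASSES = 32, 32, 3, 10   # CIFAR-10 dimensions
PREDICT_INPUTS = 'inputs'                           # assumed InputLayer name
INPUT_TENSOR_NAME = PREDICT_INPUTS

# The Keras input layer is named PREDICT_INPUTS...
image = tf.keras.layers.Input(shape=(HEIGHT, WIDTH, DEPTH), name=PREDICT_INPUTS)
logits = tf.keras.layers.Dense(NUM_CLASSES)(tf.keras.layers.Flatten()(image))
model = tf.keras.models.Model(inputs=image, outputs=logits)

# ...so the features dict produced by the input and serving functions must use
# the same key for the Estimator to route the tensor into that layer.
features = {INPUT_TENSOR_NAME: tf.zeros([1, HEIGHT, WIDTH, DEPTH])}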
@@ -77,122 +78,34 @@ def keras_model_fn(hyperparameters):
    return _model


- def serving_input_fn(params):
-     # Notice that the input placeholder has the same input shape as the Keras model input
-     tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
-
-     # The inputs key PREDICT_INPUTS matches the Keras InputLayer name
-     inputs = {PREDICT_INPUTS: tensor}
+ def serving_input_fn(hyperparameters):
+     inputs = {PREDICT_INPUTS: tf.placeholder(tf.float32, [None, 32, 32, 3])}
      return tf.estimator.export.ServingInputReceiver(inputs, inputs)


- def train_input_fn(training_dir, params):
-     return _input(tf.estimator.ModeKeys.TRAIN,
-                   batch_size=BATCH_SIZE, data_dir=training_dir)
-
-
- def eval_input_fn(training_dir, params):
-     return _input(tf.estimator.ModeKeys.EVAL,
-                   batch_size=BATCH_SIZE, data_dir=training_dir)
-
-
- def _input(mode, batch_size, data_dir):
-     """Uses the tf.data input pipeline for CIFAR-10 dataset.
-     Args:
-         mode: Standard names for model modes (tf.estimators.ModeKeys).
-         batch_size: The number of samples per batch of input requested.
-     """
-     dataset = _record_dataset(_filenames(mode, data_dir))
-
-     # For training repeat forever.
-     if mode == tf.estimator.ModeKeys.TRAIN:
-         dataset = dataset.repeat()
-
-     dataset = dataset.map(_dataset_parser)
-     dataset.prefetch(2 * batch_size)
-
-     # For training, preprocess the image and shuffle.
-     if mode == tf.estimator.ModeKeys.TRAIN:
-         dataset = dataset.map(_train_preprocess_fn)
-         dataset.prefetch(2 * batch_size)
-
-         # Ensure that the capacity is sufficiently large to provide good random
-         # shuffling.
-         buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
-         dataset = dataset.shuffle(buffer_size=buffer_size)
-
-     # Subtract off the mean and divide by the variance of the pixels.
-     dataset = dataset.map(
-         lambda image, label: (tf.image.per_image_standardization(image), label))
-     dataset.prefetch(2 * batch_size)
-
-     # Batch results by up to batch_size, and then fetch the tuple from the
-     # iterator.
-     iterator = dataset.batch(batch_size).make_one_shot_iterator()
-     images, labels = iterator.get_next()
-
-     return {PREDICT_INPUTS: images}, labels
-
-
- def _train_preprocess_fn(image, label):
-     """Preprocess a single training image of layout [height, width, depth]."""
-     # Resize the image to add four extra pixels on each side.
-     image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
-
-     # Randomly crop a [HEIGHT, WIDTH] section of the image.
-     image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
-
-     # Randomly flip the image horizontally.
-     image = tf.image.random_flip_left_right(image)
-
-     return image, label
-
-
- def _dataset_parser(value):
-     """Parse a CIFAR-10 record from value."""
-     # Every record consists of a label followed by the image, with a fixed number
-     # of bytes for each.
-     label_bytes = 1
-     image_bytes = HEIGHT * WIDTH * DEPTH
-     record_bytes = label_bytes + image_bytes
-
-     # Convert from a string to a vector of uint8 that is record_bytes long.
-     raw_record = tf.decode_raw(value, tf.uint8)
-
-     # The first byte represents the label, which we convert from uint8 to int32.
-     label = tf.cast(raw_record[0], tf.int32)
-
-     # The remaining bytes after the label represent the image, which we reshape
-     # from [depth * height * width] to [depth, height, width].
-     depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
-                              [DEPTH, HEIGHT, WIDTH])
-
-     # Convert from [depth, height, width] to [height, width, depth], and cast as
-     # float32.
-     image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
-
-     return image, tf.one_hot(label, NUM_CLASSES)
+ def train_input_fn(training_dir, hyperparameters):
+     return _generate_synthetic_data(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE)


- def _record_dataset(filenames):
-     """Returns an input pipeline Dataset from `filenames`."""
-     record_bytes = HEIGHT * WIDTH * DEPTH + 1
-     return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
+ def eval_input_fn(training_dir, hyperparameters):
+     return _generate_synthetic_data(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE)


- def _filenames(mode, data_dir):
-     """Returns a list of filenames based on 'mode'."""
-     data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
+ def _generate_synthetic_data(mode, batch_size):
+     # Random tensors shaped like a batch of CIFAR-10 images and one-hot labels;
+     # the values are meaningless and only exercise the training plumbing.
+     input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
+     images = tf.truncated_normal(
+         input_shape,
+         dtype=tf.float32,
+         stddev=1e-1,
+         name='synthetic_images')
+     labels = tf.random_uniform(
+         [batch_size, NUM_CLASSES],
+         minval=0,
+         maxval=NUM_CLASSES - 1,
+         dtype=tf.float32,
+         name='synthetic_labels')

-     assert os.path.exists(data_dir), ('Run cifar10_download_and_extract.py first '
-                                       'to download and extract the CIFAR-10 data.')
+     images = tf.contrib.framework.local_variable(images, name='images')
+     labels = tf.contrib.framework.local_variable(labels, name='labels')

-     if mode == tf.estimator.ModeKeys.TRAIN:
-         return [
-             os.path.join(data_dir, 'data_batch_%d.bin' % i)
-             for i in range(1, NUM_DATA_BATCHES + 1)
-         ]
-     elif mode == tf.estimator.ModeKeys.EVAL:
-         return [os.path.join(data_dir, 'test_batch.bin')]
-     else:
-         raise ValueError('Invalid mode: %s' % mode)
+     return {INPUT_TENSOR_NAME: images}, labels
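For context only (not part of the commit), a rough sketch of how entry-point functions like these could be exercised locally. It assumes the SageMaker TensorFlow container builds an Estimator from the compiled Keras model returned by keras_model_fn and then calls train_input_fn, eval_input_fn, and serving_input_fn; the module name, hyperparameter values, and directories below are hypothetical:

import tensorflow as tf
import resnet_cifar_10 as script   # hypothetical module name for the file changed above

hyperparameters = {'learning_rate': 0.001}   # illustrative values only

# One plausible way to consume keras_model_fn: wrap the compiled Keras model as an Estimator.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=script.keras_model_fn(hyperparameters),
    model_dir='/tmp/cifar10-synthetic')

# The synthetic input functions ignore training_dir, so None is fine for a smoke test.
estimator.train(
    input_fn=lambda: script.train_input_fn(None, hyperparameters),
    steps=10)
estimator.evaluate(
    input_fn=lambda: script.eval_input_fn(None, hyperparameters),
    steps=5)

# Export for serving; export_savedmodel expects a no-argument receiver fn.
estimator.export_savedmodel(
    '/tmp/cifar10-export',
    lambda: script.serving_input_fn(hyperparameters))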