Skip to content

Commit a110341

Browse files
yangaws authored and Piali Das committed
Use synthetic data in keras integ test (aws#367)
* Use synthetic data in keras integ test
* Add new line at end of file
* Use PREDICT_INPUTS as prediction input tensor name
1 parent 0d04909 commit a110341

File tree

1 file changed

+23
-110
lines changed

1 file changed

+23
-110
lines changed

tests/data/cifar_10/source/keras_cnn_cifar_10.py

Lines changed: 23 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
NUM_DATA_BATCHES = 5
3030
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
3131
BATCH_SIZE = 128
32+
INPUT_TENSOR_NAME = PREDICT_INPUTS
3233

3334

3435
def keras_model_fn(hyperparameters):
@@ -77,122 +78,34 @@ def keras_model_fn(hyperparameters):
7778
return _model
7879

7980

80-
def serving_input_fn(params):
81-
# Notice that the input placeholder has the same input shape as the Keras model input
82-
tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
83-
84-
# The inputs key PREDICT_INPUTS matches the Keras InputLayer name
85-
inputs = {PREDICT_INPUTS: tensor}
81+
def serving_input_fn(hyperparameters):
    """Build the ServingInputReceiver used to export the model for inference.

    Args:
        hyperparameters: dict of hyperparameters. Unused here, but the
            parameter is required by the SageMaker TensorFlow container's
            serving_input_fn interface. (Fixed spelling: was `hyperpameters`,
            inconsistent with train_input_fn/eval_input_fn.)

    Returns:
        A tf.estimator.export.ServingInputReceiver whose single input is
        keyed by PREDICT_INPUTS, matching the Keras InputLayer name.
    """
    # Use the file's image-dimension constants instead of hard-coded
    # [None, 32, 32, 3] so the placeholder stays in sync with
    # _generate_synthetic_data and the model definition.
    tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
    inputs = {PREDICT_INPUTS: tensor}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
8784

8885

89-
def train_input_fn(training_dir, params):
90-
return _input(tf.estimator.ModeKeys.TRAIN,
91-
batch_size=BATCH_SIZE, data_dir=training_dir)
92-
93-
94-
def eval_input_fn(training_dir, params):
95-
return _input(tf.estimator.ModeKeys.EVAL,
96-
batch_size=BATCH_SIZE, data_dir=training_dir)
97-
98-
99-
def _input(mode, batch_size, data_dir):
100-
"""Uses the tf.data input pipeline for CIFAR-10 dataset.
101-
Args:
102-
mode: Standard names for model modes (tf.estimators.ModeKeys).
103-
batch_size: The number of samples per batch of input requested.
104-
"""
105-
dataset = _record_dataset(_filenames(mode, data_dir))
106-
107-
# For training repeat forever.
108-
if mode == tf.estimator.ModeKeys.TRAIN:
109-
dataset = dataset.repeat()
110-
111-
dataset = dataset.map(_dataset_parser)
112-
dataset.prefetch(2 * batch_size)
113-
114-
# For training, preprocess the image and shuffle.
115-
if mode == tf.estimator.ModeKeys.TRAIN:
116-
dataset = dataset.map(_train_preprocess_fn)
117-
dataset.prefetch(2 * batch_size)
118-
119-
# Ensure that the capacity is sufficiently large to provide good random
120-
# shuffling.
121-
buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
122-
dataset = dataset.shuffle(buffer_size=buffer_size)
123-
124-
# Subtract off the mean and divide by the variance of the pixels.
125-
dataset = dataset.map(
126-
lambda image, label: (tf.image.per_image_standardization(image), label))
127-
dataset.prefetch(2 * batch_size)
128-
129-
# Batch results by up to batch_size, and then fetch the tuple from the
130-
# iterator.
131-
iterator = dataset.batch(batch_size).make_one_shot_iterator()
132-
images, labels = iterator.get_next()
133-
134-
return {PREDICT_INPUTS: images}, labels
135-
136-
137-
def _train_preprocess_fn(image, label):
138-
"""Preprocess a single training image of layout [height, width, depth]."""
139-
# Resize the image to add four extra pixels on each side.
140-
image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
141-
142-
# Randomly crop a [HEIGHT, WIDTH] section of the image.
143-
image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
144-
145-
# Randomly flip the image horizontally.
146-
image = tf.image.random_flip_left_right(image)
147-
148-
return image, label
149-
150-
151-
def _dataset_parser(value):
152-
"""Parse a CIFAR-10 record from value."""
153-
# Every record consists of a label followed by the image, with a fixed number
154-
# of bytes for each.
155-
label_bytes = 1
156-
image_bytes = HEIGHT * WIDTH * DEPTH
157-
record_bytes = label_bytes + image_bytes
158-
159-
# Convert from a string to a vector of uint8 that is record_bytes long.
160-
raw_record = tf.decode_raw(value, tf.uint8)
161-
162-
# The first byte represents the label, which we convert from uint8 to int32.
163-
label = tf.cast(raw_record[0], tf.int32)
164-
165-
# The remaining bytes after the label represent the image, which we reshape
166-
# from [depth * height * width] to [depth, height, width].
167-
depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
168-
[DEPTH, HEIGHT, WIDTH])
169-
170-
# Convert from [depth, height, width] to [height, width, depth], and cast as
171-
# float32.
172-
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
173-
174-
return image, tf.one_hot(label, NUM_CLASSES)
86+
def train_input_fn(training_dir, hyperparameters):
    """Return one synthetic training batch as (features, labels).

    `training_dir` and `hyperparameters` are accepted only to satisfy the
    SageMaker input-fn signature; the data is generated, not read from disk.
    """
    mode = tf.estimator.ModeKeys.TRAIN
    return _generate_synthetic_data(mode, batch_size=BATCH_SIZE)
17588

17689

177-
def _record_dataset(filenames):
178-
"""Returns an input pipeline Dataset from `filenames`."""
179-
record_bytes = HEIGHT * WIDTH * DEPTH + 1
180-
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
90+
def eval_input_fn(training_dir, hyperparameters):
    """Return one synthetic evaluation batch as (features, labels).

    `training_dir` and `hyperparameters` are accepted only to satisfy the
    SageMaker input-fn signature; the data is generated, not read from disk.
    """
    mode = tf.estimator.ModeKeys.EVAL
    return _generate_synthetic_data(mode, batch_size=BATCH_SIZE)
18192

18293

183-
def _filenames(mode, data_dir):
184-
"""Returns a list of filenames based on 'mode'."""
185-
data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
94+
def _generate_synthetic_data(mode, batch_size):
    """Generate a random batch of CIFAR-10-shaped images and labels.

    Args:
        mode: a tf.estimator.ModeKeys value. Accepted for interface parity
            with a real input fn but not otherwise used — TRAIN and EVAL
            produce the same kind of synthetic data.
        batch_size: number of synthetic examples in the batch.

    Returns:
        A (features, labels) pair where features maps INPUT_TENSOR_NAME to
        an image tensor of shape [batch_size, HEIGHT, WIDTH, DEPTH] and
        labels is a float tensor of shape [batch_size, NUM_CLASSES].
    """
    image_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
    images = tf.truncated_normal(
        image_shape,
        dtype=tf.float32,
        stddev=1e-1,
        name='synthetic_images')
    labels = tf.random_uniform(
        [batch_size, NUM_CLASSES],
        minval=0,
        maxval=NUM_CLASSES - 1,
        dtype=tf.float32,
        name='synthetic_labels')

    # NOTE(review): wrapping the random tensors in local variables appears
    # intended to freeze one sampled batch per session rather than re-sample
    # on every graph run — confirm the caller runs the local-variable
    # initializers.
    images = tf.contrib.framework.local_variable(images, name='images')
    labels = tf.contrib.framework.local_variable(labels, name='labels')

    return {INPUT_TENSOR_NAME: images}, labels

0 commit comments

Comments
 (0)