diff --git a/tests/data/cifar_10/source/keras_cnn_cifar_10.py b/tests/data/cifar_10/source/keras_cnn_cifar_10.py
index c73a1a5f90..b5f2ab7528 100644
--- a/tests/data/cifar_10/source/keras_cnn_cifar_10.py
+++ b/tests/data/cifar_10/source/keras_cnn_cifar_10.py
@@ -29,6 +29,7 @@
 NUM_DATA_BATCHES = 5
 NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 10000 * NUM_DATA_BATCHES
 BATCH_SIZE = 128
+INPUT_TENSOR_NAME = PREDICT_INPUTS
 
 
 def keras_model_fn(hyperparameters):
@@ -77,122 +78,71 @@ def keras_model_fn(hyperparameters):
     return _model
 
 
-def serving_input_fn(params):
-    # Notice that the input placeholder has the same input shape as the Keras model input
-    tensor = tf.placeholder(tf.float32, shape=[None, HEIGHT, WIDTH, DEPTH])
-
-    # The inputs key PREDICT_INPUTS matches the Keras InputLayer name
-    inputs = {PREDICT_INPUTS: tensor}
+def serving_input_fn(hyperparameters):
+    """Build the serving input receiver for the exported model.
+
+    Args:
+        hyperparameters: Training hyperparameters; not used here.
+
+    Returns:
+        A ``tf.estimator.export.ServingInputReceiver`` whose features and
+        receiver tensors are the same float32 image placeholder, keyed by
+        ``PREDICT_INPUTS``.
+    """
+    # Use the shared shape constants so the placeholder cannot drift from the
+    # shape of the synthetic training images.
+    image_placeholder = tf.placeholder(tf.float32, [None, HEIGHT, WIDTH, DEPTH])
+    inputs = {PREDICT_INPUTS: image_placeholder}
     return tf.estimator.export.ServingInputReceiver(inputs, inputs)
 
 
-def train_input_fn(training_dir, params):
-    return _input(tf.estimator.ModeKeys.TRAIN,
-                  batch_size=BATCH_SIZE, data_dir=training_dir)
-
-
-def eval_input_fn(training_dir, params):
-    return _input(tf.estimator.ModeKeys.EVAL,
-                  batch_size=BATCH_SIZE, data_dir=training_dir)
-
-
-def _input(mode, batch_size, data_dir):
-    """Uses the tf.data input pipeline for CIFAR-10 dataset.
-    Args:
-        mode: Standard names for model modes (tf.estimators.ModeKeys).
-        batch_size: The number of samples per batch of input requested.
-    """
-    dataset = _record_dataset(_filenames(mode, data_dir))
-
-    # For training repeat forever.
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        dataset = dataset.repeat()
-
-    dataset = dataset.map(_dataset_parser)
-    dataset.prefetch(2 * batch_size)
-
-    # For training, preprocess the image and shuffle.
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        dataset = dataset.map(_train_preprocess_fn)
-        dataset.prefetch(2 * batch_size)
-
-        # Ensure that the capacity is sufficiently large to provide good random
-        # shuffling.
-        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
-        dataset = dataset.shuffle(buffer_size=buffer_size)
-
-    # Subtract off the mean and divide by the variance of the pixels.
-    dataset = dataset.map(
-        lambda image, label: (tf.image.per_image_standardization(image), label))
-    dataset.prefetch(2 * batch_size)
-
-    # Batch results by up to batch_size, and then fetch the tuple from the
-    # iterator.
-    iterator = dataset.batch(batch_size).make_one_shot_iterator()
-    images, labels = iterator.get_next()
-
-    return {PREDICT_INPUTS: images}, labels
-
-
-def _train_preprocess_fn(image, label):
-    """Preprocess a single training image of layout [height, width, depth]."""
-    # Resize the image to add four extra pixels on each side.
-    image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
-
-    # Randomly crop a [HEIGHT, WIDTH] section of the image.
-    image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
-
-    # Randomly flip the image horizontally.
-    image = tf.image.random_flip_left_right(image)
-
-    return image, label
-
-
-def _dataset_parser(value):
-    """Parse a CIFAR-10 record from value."""
-    # Every record consists of a label followed by the image, with a fixed number
-    # of bytes for each.
-    label_bytes = 1
-    image_bytes = HEIGHT * WIDTH * DEPTH
-    record_bytes = label_bytes + image_bytes
-
-    # Convert from a string to a vector of uint8 that is record_bytes long.
-    raw_record = tf.decode_raw(value, tf.uint8)
-
-    # The first byte represents the label, which we convert from uint8 to int32.
-    label = tf.cast(raw_record[0], tf.int32)
-
-    # The remaining bytes after the label represent the image, which we reshape
-    # from [depth * height * width] to [depth, height, width].
-    depth_major = tf.reshape(raw_record[label_bytes:record_bytes],
-                             [DEPTH, HEIGHT, WIDTH])
-
-    # Convert from [depth, height, width] to [height, width, depth], and cast as
-    # float32.
-    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
-
-    return image, tf.one_hot(label, NUM_CLASSES)
+def train_input_fn(training_dir, hyperparameters):
+    """Return one synthetic (features, labels) batch for training.
+
+    ``training_dir`` and ``hyperparameters`` are not used: the data is
+    generated in-graph rather than read from disk.
+    """
+    return _generate_synthetic_data(tf.estimator.ModeKeys.TRAIN, batch_size=BATCH_SIZE)
 
 
-def _record_dataset(filenames):
-    """Returns an input pipeline Dataset from `filenames`."""
-    record_bytes = HEIGHT * WIDTH * DEPTH + 1
-    return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
+def eval_input_fn(training_dir, hyperparameters):
+    """Return one synthetic (features, labels) batch for evaluation.
+
+    ``training_dir`` and ``hyperparameters`` are not used: the data is
+    generated in-graph rather than read from disk.
+    """
+    return _generate_synthetic_data(tf.estimator.ModeKeys.EVAL, batch_size=BATCH_SIZE)
 
 
-def _filenames(mode, data_dir):
-    """Returns a list of filenames based on 'mode'."""
-    data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
+def _generate_synthetic_data(mode, batch_size):
+    """Create a batch of random images and labels held in local variables.
+
+    Args:
+        mode: A ``tf.estimator.ModeKeys`` value; currently unused, so every
+            mode gets the same kind of data.
+        batch_size: Number of examples in the generated batch.
+
+    Returns:
+        A ``({INPUT_TENSOR_NAME: images}, labels)`` tuple, where ``images`` has
+        shape ``[batch_size, HEIGHT, WIDTH, DEPTH]`` and ``labels`` has shape
+        ``[batch_size, NUM_CLASSES]``, both float32.
+    """
+    input_shape = [batch_size, HEIGHT, WIDTH, DEPTH]
+    images = tf.truncated_normal(
+        input_shape,
+        dtype=tf.float32,
+        stddev=1e-1,
+        name='synthetic_images')
+    labels = tf.random_uniform(
+        [batch_size, NUM_CLASSES],
+        minval=0,
+        maxval=NUM_CLASSES - 1,
+        dtype=tf.float32,
+        name='synthetic_labels')
 
-    assert os.path.exists(data_dir), ('Run cifar10_download_and_extract.py first '
-                                      'to download and extract the CIFAR-10 data.')
+    # Hold the samples in local variables; presumably so the batch is drawn
+    # once at initialization rather than on every step (TODO confirm).
+    images = tf.contrib.framework.local_variable(images, name='images')
+    labels = tf.contrib.framework.local_variable(labels, name='labels')
 
-    if mode == tf.estimator.ModeKeys.TRAIN:
-        return [
-            os.path.join(data_dir, 'data_batch_%d.bin' % i)
-            for i in range(1, NUM_DATA_BATCHES + 1)
-        ]
-    elif mode == tf.estimator.ModeKeys.EVAL:
-        return [os.path.join(data_dir, 'test_batch.bin')]
-    else:
-        raise ValueError('Invalid mode: %s' % mode)
\ No newline at end of file
+    return {INPUT_TENSOR_NAME: images}, labels