|
| 1 | +# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +from __future__ import absolute_import |
| 14 | +from __future__ import division |
| 15 | +from __future__ import print_function |
| 16 | + |
| 17 | +import argparse |
| 18 | +import functools |
| 19 | +import os |
| 20 | + |
| 21 | +import tensorflow as tf |
| 22 | + |
| 23 | +import resnet_model |
| 24 | + |
# Key of the input tensor inside the features dict, and the name under which
# the prediction signature is exported.
INPUT_TENSOR_NAME = "inputs"
SIGNATURE_NAME = "serving_default"

# Image geometry, label space and dataset sizing.
HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_BATCHES = 5
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = NUM_DATA_BATCHES * 10000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
RESNET_SIZE = 32
BATCH_SIZE = 1

# The learning rate scales linearly with the batch size: a batch size of 128
# corresponds to a learning rate of 0.05.
_INITIAL_LEARNING_RATE = 0.05 * BATCH_SIZE / 128
_MOMENTUM = 0.9

# A weight decay of 0.0002 is used because it performs better than the 0.0001
# that was originally suggested.
_WEIGHT_DECAY = 0.0002

# True division (see the __future__ import), so this is a float; it is
# truncated with int() where the learning-rate boundaries are computed.
_BATCHES_PER_EPOCH = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE
| 48 | + |
| 49 | + |
def model_fn(features, labels, mode):
    """
    Model function for CIFAR-10.
    For more information: https://www.tensorflow.org/guide/custom_estimators#write_a_model_function

    Args:
        features: dict holding the image tensor under INPUT_TENSOR_NAME.
        labels: integer class ids; only read outside PREDICT mode, where they
            are one-hot encoded with depth NUM_CLASSES.
        mode: a tf.estimator.ModeKeys value selecting predict/train/eval.

    Returns:
        A tf.estimator.EstimatorSpec. In PREDICT mode it carries the
        predictions and an export signature named SIGNATURE_NAME; otherwise it
        carries predictions, the total loss and, in TRAIN mode only, the
        train_op.
    """
    # Reshape before summarising so the image summary always receives a
    # [batch, HEIGHT, WIDTH, DEPTH] tensor.
    inputs = tf.reshape(features[INPUT_TENSOR_NAME], [-1, HEIGHT, WIDTH, DEPTH])
    tf.summary.image('images', inputs, max_outputs=6)

    network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)

    # The second argument is the is-training flag.
    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    # The one-hot depth follows NUM_CLASSES so it always matches the width of
    # the logits produced by the network above.
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=tf.one_hot(labels, NUM_CLASSES))

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
        boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
        values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), boundaries, values)

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=learning_rate,
            momentum=_MOMENTUM)

        # Batch norm requires update ops to be added as a dependency to the train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        # Evaluation only needs the loss; there is nothing to optimise.
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op)
| 116 | + |
| 117 | + |
def serving_input_fn():
    """
    Serving input function for CIFAR-10. Specifies the input format the caller of predict() will have to provide.
    For more information: https://www.tensorflow.org/guide/saved_model#build_and_load_a_savedmodel

    Returns:
        A tf.estimator.export.ServingInputReceiver whose features and receiver
        tensors are the same dict: a float32 placeholder of shape
        [None, HEIGHT, WIDTH, DEPTH] stored under INPUT_TENSOR_NAME.
    """
    # Take the image geometry from the module constants so the serving
    # signature stays in step with the reshape done in model_fn.
    image_batch = tf.placeholder(tf.float32, [None, HEIGHT, WIDTH, DEPTH])
    inputs = {INPUT_TENSOR_NAME: image_batch}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
| 125 | + |
| 126 | + |
def make_batch(data_dir, batch_size=2):
    """
    Build a repeating, shuffled and batched input pipeline over TFRecords.

    Args:
        data_dir: handed unchanged to tf.data.TFRecordDataset (the input
            functions in this file pass the path of a .tfrecords file).
        batch_size: number of examples per batch; also used as the number of
            parallel calls when mapping `parser` over the records.

    Returns:
        The (image, label) pair produced by a one-shot iterator over the
        pipeline.
    """
    # Ensure that the capacity is sufficiently large to provide good random
    # shuffling: 40% of 45000 examples plus three batches.
    shuffle_capacity = int(45000 * 0.4) + 3 * batch_size

    records = tf.data.TFRecordDataset(data_dir)
    pipeline = (
        records.repeat()
        .map(parser, num_parallel_calls=batch_size)
        .shuffle(buffer_size=shuffle_capacity)
        .batch(batch_size)
    )

    image, label = pipeline.make_one_shot_iterator().get_next()
    return image, label
| 143 | + |
| 144 | + |
def parser(serialized_example):
    """
    Decode one serialized example into an image tensor and a label tensor.

    Args:
        serialized_example: a serialized example with a string feature
            'image' and an int64 feature 'label'.

    Returns:
        (image, label): image is float32 with shape [HEIGHT, WIDTH, DEPTH],
        label is an int32 scalar.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The raw bytes are a flat uint8 vector of DEPTH * HEIGHT * WIDTH values.
    flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])

    # Interpret the vector as [DEPTH, HEIGHT, WIDTH], then move depth last.
    depth_major = tf.reshape(flat_pixels, [DEPTH, HEIGHT, WIDTH])
    depth_last = tf.transpose(depth_major, [1, 2, 0])
    image = tf.cast(depth_last, tf.float32)

    label = tf.cast(parsed['label'], tf.int32)
    return image, label
| 159 | + |
| 160 | + |
def train_input_fn(data_dir):
    """
    Input function for training.

    Reads 'train.tfrecords' from data_dir and returns a
    ({INPUT_TENSOR_NAME: images}, labels) pair with BATCH_SIZE examples per
    batch. The pipeline ops are placed on the CPU.
    """
    records_path = os.path.join(data_dir, 'train.tfrecords')
    with tf.device('/cpu:0'):
        images, labels = make_batch(records_path, BATCH_SIZE)
    features = {INPUT_TENSOR_NAME: images}
    return features, labels
| 166 | + |
| 167 | + |
def eval_input_fn(data_dir):
    """
    Input function for evaluation.

    Reads 'eval.tfrecords' from data_dir and returns a
    ({INPUT_TENSOR_NAME: images}, labels) pair with BATCH_SIZE examples per
    batch. The pipeline ops are placed on the CPU.
    """
    records_path = os.path.join(data_dir, 'eval.tfrecords')
    with tf.device('/cpu:0'):
        images, labels = make_batch(records_path, BATCH_SIZE)
    features = {INPUT_TENSOR_NAME: images}
    return features, labels
| 174 | + |
| 175 | + |
def train(model_dir, data_dir, train_steps):
    """
    Train and evaluate the CIFAR-10 estimator, exporting a servable model.

    Args:
        model_dir: directory handed to the Estimator as its model_dir.
        data_dir: directory bound into the train and eval input functions.
        train_steps: value passed as max_steps of the TrainSpec.
    """
    # Input functions take data_dir; bind it now so the specs receive
    # zero-argument callables.
    bound_train_input = functools.partial(train_input_fn, data_dir)
    bound_eval_input = functools.partial(eval_input_fn, data_dir)

    cifar_estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

    training = tf.estimator.TrainSpec(input_fn=bound_train_input, max_steps=train_steps)

    # Exports the latest model under the name 'Servo' using serving_input_fn.
    servo_exporter = tf.estimator.LatestExporter(
        name='Servo', serving_input_receiver_fn=serving_input_fn)
    evaluation = tf.estimator.EvalSpec(
        input_fn=bound_eval_input, steps=1, exporters=servo_exporter)

    tf.estimator.train_and_evaluate(
        estimator=cifar_estimator, train_spec=training, eval_spec=evaluation)
| 188 | + |
| 189 | + |
def main(model_dir, data_dir, train_steps):
    """
    Script entry point: set TensorFlow logging to INFO, then run train().

    Args:
        model_dir: forwarded to train().
        data_dir: forwarded to train().
        train_steps: forwarded to train().
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    train(model_dir=model_dir, data_dir=data_dir, train_steps=train_steps)
| 193 | + |
| 194 | + |
if __name__ == '__main__':
    # Command-line interface: each flag maps onto one parameter of main().
    args_parser = argparse.ArgumentParser()

    # For more information:
    # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo.html
    args_parser.add_argument(
        '--data-dir',
        type=str,
        default='/opt/ml/input/data/training',
        help=('The directory where the CIFAR-10 input data is stored. Default: /opt/ml/input/data/training. This '
              'directory corresponds to the SageMaker channel named \'training\', which was specified when creating '
              'our training job on SageMaker'))

    # For more information:
    # https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
    args_parser.add_argument(
        '--model-dir',
        type=str,
        default='/opt/ml/model',
        help=('The directory where the model will be stored. Default: /opt/ml/model. This directory should contain all '
              'final model artifacts as Amazon SageMaker copies all data within this directory as a single object in '
              'compressed tar format.'))

    args_parser.add_argument(
        '--train-steps',
        type=int,
        default=100,
        help='The number of steps to use for training.')

    args = args_parser.parse_args()
    main(model_dir=args.model_dir, data_dir=args.data_dir, train_steps=args.train_steps)
0 commit comments