Commit 9bd5fa3

Merge pull request aws#214 from winstonaws/fix_tf_cifar_predict

Fix predictions for tensorflow cifar example

2 parents a240855 + 7724496

File tree

2 files changed: 42 additions & 24 deletions

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/source_dir/resnet_cifar_10.py

Lines changed: 13 additions & 13 deletions
@@ -106,39 +106,39 @@ def model_fn(features, labels, mode, params):
 
 
 def serving_input_fn(params):
-    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=(32, 32, 3))}
-    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()
+    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
+    return tf.estimator.export.ServingInputReceiver(inputs, inputs)
 
 
 def train_input_fn(training_dir, params):
-    return input_fn(tf.estimator.ModeKeys.TRAIN,
+    return _input_from_files(tf.estimator.ModeKeys.TRAIN,
                     batch_size=BATCH_SIZE, data_dir=training_dir)
 
 
 def eval_input_fn(training_dir, params):
-    return input_fn(tf.estimator.ModeKeys.EVAL,
+    return _input_from_files(tf.estimator.ModeKeys.EVAL,
                     batch_size=BATCH_SIZE, data_dir=training_dir)
 
 
-def input_fn(mode, batch_size, data_dir):
-    """Input_fn using the contrib.data input pipeline for CIFAR-10 dataset.
+def _input_from_files(mode, batch_size, data_dir):
+    """Uses the contrib.data input pipeline for CIFAR-10 dataset.
 
     Args:
         mode: Standard names for model modes (tf.estimators.ModeKeys).
         batch_size: The number of samples per batch of input requested.
     """
-    dataset = record_dataset(filenames(mode, data_dir))
+    dataset = _record_dataset(_filenames(mode, data_dir))
 
     # For training repeat forever.
     if mode == tf.estimator.ModeKeys.TRAIN:
         dataset = dataset.repeat()
 
-    dataset = dataset.map(dataset_parser, num_threads=1,
+    dataset = dataset.map(_dataset_parser, num_threads=1,
                           output_buffer_size=2 * batch_size)
 
     # For training, preprocess the image and shuffle.
     if mode == tf.estimator.ModeKeys.TRAIN:
-        dataset = dataset.map(train_preprocess_fn, num_threads=1,
+        dataset = dataset.map(_train_preprocess_fn, num_threads=1,
                               output_buffer_size=2 * batch_size)
 
     # Ensure that the capacity is sufficiently large to provide good random
@@ -160,7 +160,7 @@ def input_fn(mode, batch_size, data_dir):
     return {INPUT_TENSOR_NAME: images}, labels
 
 
-def train_preprocess_fn(image, label):
+def _train_preprocess_fn(image, label):
     """Preprocess a single training image of layout [height, width, depth]."""
     # Resize the image to add four extra pixels on each side.
     image = tf.image.resize_image_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
@@ -174,7 +174,7 @@ def train_preprocess_fn(image, label):
     return image, label
 
 
-def dataset_parser(value):
+def _dataset_parser(value):
     """Parse a CIFAR-10 record from value."""
     # Every record consists of a label followed by the image, with a fixed number
     # of bytes for each.
@@ -200,13 +200,13 @@ def dataset_parser(value):
     return image, tf.one_hot(label, NUM_CLASSES)
 
 
-def record_dataset(filenames):
+def _record_dataset(filenames):
     """Returns an input pipeline Dataset from `filenames`."""
     record_bytes = HEIGHT * WIDTH * DEPTH + 1
     return tf.contrib.data.FixedLengthRecordDataset(filenames, record_bytes)
 
 
-def filenames(mode, data_dir):
+def _filenames(mode, data_dir):
     """Returns a list of filenames based on 'mode'."""
     data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
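The substantive change in this file is serving_input_fn: the old parsing receiver expected serialized tf.Example protos on the wire, so raw arrays sent via predictor.predict() could not be deserialized; the new receiver feeds a raw float32 placeholder straight to the model. The remaining edits rename the input-pipeline helpers (input_fn, dataset_parser, train_preprocess_fn, record_dataset, filenames) to underscore-prefixed private names, presumably so the SageMaker TensorFlow container does not pick up the training-side input_fn as one of the serving hooks it resolves by name in the entry-point script. A minimal sketch contrasting the two receiver styles (illustrative only, not part of the commit; INPUT_TENSOR_NAME stands in for the constant defined earlier in resnet_cifar_10.py):

import tensorflow as tf  # TF 1.x API, as used by this example

INPUT_TENSOR_NAME = 'inputs'  # assumed value; the real constant is defined in resnet_cifar_10.py

def parsing_serving_input_fn():
    # Old style: requests must arrive as serialized tf.Example protos, which
    # are parsed against feature_spec before reaching the model.
    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32,
                                                          shape=(32, 32, 3))}
    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()

def raw_serving_input_fn():
    # New style: requests are raw float32 batches fed directly into the
    # graph, which matches what predictor.predict(numpy_array) sends.
    inputs = {INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, 32, 32, 3])}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)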

sagemaker-python-sdk/tensorflow_resnet_cifar10_with_tensorboard/tensorflow_resnet_cifar10_with_tensorboard.ipynb

Lines changed: 29 additions & 11 deletions
@@ -138,8 +138,7 @@
    "\n",
    "It takes a few minutes to provision containers and start the training job.**TensorBoard** will start to display metrics shortly after that.\n",
    "\n",
-   "You can access **TensorBoard** locally at [http://localhost:6006](http://localhost:6006) or using your SageMaker notebook instance [proxy/6006/](/proxy/6006/)(TensorBoard will not work if forget to put the slash, '/', in end of the url). If TensorBoard started on a different port, adjust these URLs to match.",
-   "This example uses the optional hyperparameter **```min_eval_frequency```** to generate training evaluations more often, allowing to visualize **TensorBoard** scalar data faster. You can find the available optional hyperparameters [here](https://github.com/aws/sagemaker-python-sdk#optional-hyperparameters)**."
+   "You can access **TensorBoard** locally at [http://localhost:6006](http://localhost:6006) or using your SageMaker notebook instance [proxy/6006/](/proxy/6006/)(TensorBoard will not work if forget to put the slash, '/', in end of the url). If TensorBoard started on a different port, adjust these URLs to match.This example uses the optional hyperparameter **```min_eval_frequency```** to generate training evaluations more often, allowing to visualize **TensorBoard** scalar data faster. You can find the available optional hyperparameters [here](https://github.com/aws/sagemaker-python-sdk#optional-hyperparameters)**."
    ]
   },
   {
@@ -156,14 +155,33 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Make a prediction with fake data to verify the endpoint is up\n",
+    "\n",
+    "Prediction is not the focus of this notebook, so to verify the endpoint's functionality, we'll simply generate random data in the correct shape and make a prediction."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "random_image_data = np.random.rand(32, 32, 3)\n",
+    "predictor.predict(random_image_data)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -185,24 +203,24 @@
    }
   ],
  "metadata": {
-  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.",
   "kernelspec": {
-   "display_name": "Environment (conda_tensorflow_p27)",
+   "display_name": "conda_tensorflow_p27",
    "language": "python",
    "name": "conda_tensorflow_p27"
   },
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 3
+    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "2.7.13"
-  }
+   "pygments_lexer": "ipython2",
+   "version": "2.7.11"
+  },
+  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
  },
 "nbformat": 4,
 "nbformat_minor": 2

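Read together, the two files give the notebook a deploy-then-verify flow against the fixed serving signature. A condensed sketch of that flow (estimator comes from the notebook's earlier training cells; the delete_endpoint cleanup is an assumption added for completeness and is not part of this commit):

import numpy as np
import sagemaker

# Deploy the trained estimator behind a real-time endpoint (as in the notebook).
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

# The exported model now accepts raw float32 tensors shaped [None, 32, 32, 3],
# so one random image of the right shape is enough to confirm the endpoint
# responds; prediction quality is beside the point here.
random_image_data = np.random.rand(32, 32, 3)
print(predictor.predict(random_image_data))

# Assumed cleanup step, not part of the commit: tear the endpoint down once verified.
sagemaker.Session().delete_endpoint(predictor.endpoint)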