
Commit ae4df9a

Author: Jonathan Esterhazy (committed)
add mxnet_gluon_mnist example
1 parent 30a869e commit ae4df9a

File tree

4 files changed: +451 -0 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
import os

os.environ['AWS_ACCESS_KEY_ID'] = 'type your aws access key id here'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'type your aws secret access key here'
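
These placeholder keys are read from the environment by the AWS SDK when the notebook creates a session; hard-coding them like this is only for the example. A minimal sketch, assuming boto3 is available in the notebook environment, of checking that the credentials are actually picked up:

import boto3

# boto3 falls back to AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY from the
# environment when no other credential source is configured.
creds = boto3.Session().get_credentials()
print('using access key:', creds.access_key)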
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
<script type="text/javascript">
var pixels = [];
for (var i = 0; i < 28*28; i++) pixels[i] = 0;
var click = 0;

var canvas = document.querySelector("canvas");
canvas.addEventListener("mousemove", function(e){
    if (e.buttons == 1) {
        click = 1;
        canvas.getContext("2d").fillStyle = "rgb(0,0,0)";
        canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 8, 8);
        // map the 140x140 canvas coordinates onto the 28x28 pixel grid
        x = Math.floor(e.offsetY * 0.2);
        y = Math.floor(e.offsetX * 0.2) + 1;
        for (var dy = 0; dy < 2; dy++){
            for (var dx = 0; dx < 2; dx++){
                if ((x + dx < 28) && (y + dy < 28)){
                    pixels[(y+dy)+(x+dx)*28] = 1;
                }
            }
        }
    } else {
        if (click == 1) set_value();
        click = 0;
    }
});

// reset the canvas and the pixel grid
function clear_value(){
    canvas.getContext("2d").fillStyle = "rgb(255,255,255)";
    canvas.getContext("2d").fillRect(0, 0, 140, 140);
    for (var i = 0; i < 28*28; i++) pixels[i] = 0;
}

// serialize the pixel grid as a nested list and assign it to the
// Python variable `data` in the notebook kernel
function set_value(){
    var result = "[[";
    for (var i = 0; i < 28; i++) {
        result += "[";
        for (var j = 0; j < 28; j++) {
            result += pixels[i * 28 + j];
            if (j < 27) {
                result += ", ";
            }
        }
        result += "]";
        if (i < 27) {
            result += ", ";
        }
    }
    result += "]]";
    var kernel = IPython.notebook.kernel;
    kernel.execute("data = " + result);
}
</script>
<table>
<td style="border-style: none;">
<div style="border: solid 2px #666; width: 143px; height: 144px;">
<canvas width="140" height="140"></canvas>
</div></td>
<td style="border-style: none;">
<button onclick="clear_value()">Clear</button>
</td>
</table>
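
The script above writes the drawn 28x28 grid back into the notebook kernel as a Python list named data (via IPython.notebook.kernel.execute). A hypothetical usage sketch for the accompanying notebook, assuming an MXNet model has already been deployed with the SageMaker Python SDK and the returned predictor object is named predictor:

# 'data' is the nested list produced by set_value() in the cell above.
# The endpoint's transform_fn (in the script below) expects a JSON body,
# which the SDK predictor serializes and deserializes for us.
response = predictor.predict(data)
print('predicted digit:', int(response))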
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
from __future__ import print_function

import logging
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import nn
import numpy as np
import json


logging.basicConfig(level=logging.DEBUG)

# ------------------------------------------------------------ #
# Training methods                                              #
# ------------------------------------------------------------ #


def train(channel_input_dirs, hyperparameters, **kwargs):
    # IM passes num_cpus, num_gpus and other args we can use to tailor training to
    # the current container environment, but here we just use simple cpu context.
    ctx = mx.cpu()

    # retrieve the hyperparameters we set in notebook (with some defaults)
    batch_size = hyperparameters.get('batch_size', 100)
    epochs = hyperparameters.get('epochs', 10)
    learning_rate = hyperparameters.get('learning_rate', 0.1)
    momentum = hyperparameters.get('momentum', 0.9)
    log_interval = hyperparameters.get('log_interval', 100)

    # load training and validation data
    # we use the gluon.data.vision.MNIST class because of its built-in mnist pre-processing logic,
    # but point it at the location where IM placed the data files, so it doesn't download them again.
    training_dir = channel_input_dirs['training']
    train_data = get_train_data(training_dir + '/train', batch_size)
    val_data = get_val_data(training_dir + '/test', batch_size)

    # define the network
    net = define_network()

    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    # Trainer is for updating parameters with gradient.
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': learning_rate, 'momentum': momentum})
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        metric.reset()
        for i, (data, label) in enumerate(train_data):
            # Copy data to ctx if necessary
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            with autograd.record():
                output = net(data)
                L = loss(output, label)
                L.backward()
            # take a gradient step with batch_size equal to data.shape[0]
            trainer.step(data.shape[0])
            # update metric at last.
            metric.update([label], [output])

            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f' % (epoch, i, name, acc))

        name, acc = metric.get()
        print('[Epoch %d] Training: %s=%f' % (epoch, name, acc))

        name, val_acc = test(ctx, net, val_data)
        print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))

    return net


def save(net, model_dir):
    # save the model
    y = net(mx.sym.var('data'))
    y.save('%s/model.json' % model_dir)
    net.collect_params().save('%s/model.params' % model_dir)


def define_network():
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Dense(128, activation='relu'))
        net.add(nn.Dense(64, activation='relu'))
        net.add(nn.Dense(10))
    return net


def input_transformer(data, label):
    data = data.reshape((-1,)).astype(np.float32) / 255
    return data, label


def get_train_data(data_dir, batch_size):
    return gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=True, transform=input_transformer),
        batch_size=batch_size, shuffle=True, last_batch='discard')


def get_val_data(data_dir, batch_size):
    return gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=False, transform=input_transformer),
        batch_size=batch_size, shuffle=False)


def test(ctx, net, val_data):
    metric = mx.metric.Accuracy()
    for data, label in val_data:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        metric.update([label], [output])
    return metric.get()


# ------------------------------------------------------------ #
# Hosting methods                                               #
# ------------------------------------------------------------ #

def model_fn(model_dir):
    """
    Load the gluon model. Called once when hosting service starts.

    :param model_dir: The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """
    symbol = mx.sym.load('%s/model.json' % model_dir)
    outputs = mx.symbol.softmax(data=symbol, name='softmax_label')
    inputs = mx.sym.var('data')
    param_dict = gluon.ParameterDict('model_')
    net = gluon.SymbolBlock(outputs, inputs, param_dict)
    net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
    return net


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.

    :param net: The Gluon model.
    :param data: The request payload.
    :param input_content_type: The request content type.
    :param output_content_type: The (desired) response content type.
    :return: response payload and content type.
    """
    # we can use content types to vary input/output handling, but
    # here we just assume json for both
    parsed = json.loads(data)
    nda = mx.nd.array(parsed)
    output = net(nda)
    prediction = mx.nd.argmax(output, axis=1)
    response_body = json.dumps(prediction.asnumpy().tolist()[0])
    return response_body, output_content_type
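
The same functions can be exercised outside of SageMaker as a quick sanity check. A rough local sketch using only the functions defined above; the /tmp paths and the single-epoch setting are illustrative placeholders:

import json

# Train briefly on MNIST data under /tmp/mnist (gluon downloads it if missing),
# then write model.json / model.params the way the training job would.
net = train(channel_input_dirs={'training': '/tmp/mnist'},
            hyperparameters={'epochs': 1})
save(net, '/tmp/model')

# Reload the artifacts as the hosting container would, then push one blank
# 28x28 image through the request handler.
net = model_fn('/tmp/model')
payload = json.dumps([[[0.0] * 28] * 28])
body, content_type = transform_fn(net, payload, 'application/json', 'application/json')
print(body, content_type)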

0 commit comments
