
Commit 8efffe3

Merge pull request aws#371 from awslabs/arpin_random_hpo
Added: MXNet Gluon CIFAR-10 Automatic Model Tuning vs random search
2 parents 822f4bc + fbbb9b6 commit 8efffe3
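
The tuning notebook itself is not shown in this view. As a point of reference, a minimal sketch of how the training script below might be launched for Automatic Model Tuning with the SageMaker Python SDK could look like this; the entry-point file name, IAM role, and S3 path are assumptions for illustration, not taken from this commit:

from sagemaker.mxnet import MXNet
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

# hypothetical estimator wrapping the training script added in this commit
estimator = MXNet(entry_point='cifar10.py',         # assumed file name
                  role='SageMakerRole',             # hypothetical IAM role
                  train_instance_count=1,
                  train_instance_type='ml.p3.2xlarge',
                  hyperparameters={'epochs': 50})

# the regex matches the '[Epoch N] validation: accuracy=...' lines the script logs
tuner = HyperparameterTuner(estimator,
                            objective_metric_name='validation-accuracy',
                            hyperparameter_ranges={'learning_rate': ContinuousParameter(0.001, 0.5)},
                            metric_definitions=[{'Name': 'validation-accuracy',
                                                 'Regex': 'validation: accuracy=([0-9\\.]+)'}],
                            max_jobs=10,
                            max_parallel_jobs=2)

# tuner.fit({'training': 's3://my-bucket/cifar10'})  # hypothetical S3 location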

File tree

4 files changed: +2234, -0 lines

@@ -0,0 +1,186 @@
from __future__ import print_function

import json
import logging
import os
import time

import mxnet as mx
from mxnet import autograd as ag
from mxnet import gluon
from mxnet.gluon.model_zoo import vision as models


# ------------------------------------------------------------ #
# Training methods                                              #
# ------------------------------------------------------------ #

def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir, hyperparameters, **kwargs):
    # retrieve the hyperparameters we set in the notebook (with some defaults)
    batch_size = hyperparameters.get('batch_size', 128)
    epochs = hyperparameters.get('epochs', 100)
    learning_rate = hyperparameters.get('learning_rate', 0.1)
    momentum = hyperparameters.get('momentum', 0.9)
    log_interval = hyperparameters.get('log_interval', 1)
    wd = hyperparameters.get('wd', 0.0001)

    if len(hosts) == 1:
        kvstore = 'device' if num_gpus > 0 else 'local'
    else:
        kvstore = 'dist_device_sync'

    ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    net = models.get_model('resnet34_v2', ctx=ctx, pretrained=False, classes=10)
    batch_size *= max(1, len(ctx))

    # load training and validation data
    # we read the CIFAR-10 RecordIO files with mx.io.ImageRecordIter (see get_data below),
    # pointed at the location where SageMaker placed them, so nothing is downloaded again.

    part_index = 0
    for i, host in enumerate(hosts):
        if host == current_host:
            part_index = i
            break

    data_dir = channel_input_dirs['training']
    train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32),
                                num_parts=len(hosts), part_index=part_index)
    test_data = get_test_data(num_cpus, data_dir, batch_size, (3, 32, 32))

    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx)
    # Trainer is for updating parameters with gradient.
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            optimizer_params={'learning_rate': learning_rate, 'momentum': momentum, 'wd': wd},
                            kvstore=kvstore)
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    best_accuracy = 0.0
    for epoch in range(epochs):
        # reset data iterator and metric at the beginning of each epoch.
        train_data.reset()
        tic = time.time()
        metric.reset()
        btic = time.time()

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                for L in Ls:
                    L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                logging.info('Epoch [%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' %
                             (epoch, i, batch_size / (time.time() - btic), name, acc))
                btic = time.time()

        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))

        name, val_acc = test(ctx, net, test_data)
        logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

        # only save params on the primary host, and only when validation accuracy improves
        if current_host == hosts[0]:
            if val_acc > best_accuracy:
                net.save_params('{}/model-{:0>4}.params'.format(model_dir, epoch))
                best_accuracy = val_acc

    return net


def save(net, model_dir):
    # model_dir will be empty except on the primary container
    files = os.listdir(model_dir)
    if files:
        # checkpoints are only written when validation accuracy improves, so the
        # highest-numbered checkpoint is the best one; promote it to model.params
        best = sorted(files)[-1]
        os.rename(os.path.join(model_dir, best), os.path.join(model_dir, 'model.params'))


def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    return mx.io.ImageRecordIter(
        path_imgrec=path,
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=augment,
        rand_mirror=augment,
        preprocess_threads=num_cpus,
        num_parts=num_parts,
        part_index=part_index)


def get_test_data(num_cpus, data_dir, batch_size, data_shape, resize=-1):
    return get_data(os.path.join(data_dir, "test.rec"), False, num_cpus, batch_size, data_shape, resize, 1, 0)


def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    return get_data(os.path.join(data_dir, "train.rec"), True, num_cpus, batch_size, data_shape, resize, num_parts,
                    part_index)


def test(ctx, net, test_data):
    test_data.reset()
    metric = mx.metric.Accuracy()

    for i, batch in enumerate(test_data):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()


# ------------------------------------------------------------ #
# Hosting methods                                               #
# ------------------------------------------------------------ #

def model_fn(model_dir):
    """
    Load the gluon model. Called once when the hosting service starts.

    :param model_dir: The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """

    net = models.get_model('resnet34_v2', ctx=mx.cpu(), pretrained=False, classes=10)
    net.load_params('%s/model.params' % model_dir, ctx=mx.cpu())
    return net


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.

    :param net: The Gluon model.
    :param data: The request payload.
    :param input_content_type: The request content type.
    :param output_content_type: The (desired) response content type.
    :return: response payload and content type.
    """
    # we could use the content types to vary input/output handling,
    # but here we just assume JSON for both
    parsed = json.loads(data)
    nda = mx.nd.array(parsed)
    output = net(nda)
    prediction = mx.nd.argmax(output, axis=1)
    response_body = json.dumps(prediction.asnumpy().tolist()[0])
    return response_body, output_content_type
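
A quick way to sanity-check the hosting entry points above is to call them directly. A minimal sketch, assuming the script is importable as cifar10 (file name assumed) and a trained model.params already sits in ./model:

import json

import mxnet as mx

from cifar10 import model_fn, transform_fn  # assumed module name

net = model_fn('./model')  # assumes ./model/model.params exists

# a single all-zeros CIFAR-10-shaped input: batch of 1, 3 channels, 32x32
payload = json.dumps(mx.nd.zeros((1, 3, 32, 32)).asnumpy().tolist())
body, content_type = transform_fn(net, payload, 'application/json', 'application/json')
print(body, content_type)  # e.g. 3.0 application/json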
@@ -0,0 +1,32 @@
import os
import zipfile

import numpy as np
from mxnet.test_utils import download
from skimage import io


def download_training_data():
    print('downloading training data...')
    if not os.path.isdir("data"):
        os.makedirs('data')
    if (not os.path.exists('data/train.rec')) or \
            (not os.path.exists('data/test.rec')) or \
            (not os.path.exists('data/train.lst')) or \
            (not os.path.exists('data/test.lst')):
        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip')
        with zipfile.ZipFile(zip_file_path) as zf:
            # the archive extracts into a 'cifar' directory
            zf.extractall()
        os.rename('cifar', 'data')
    print('done')


def read_image(filename):
    # read an HWC image from disk and convert it to the NCHW
    # batch-of-one layout the network expects
    img = io.imread(filename)
    img = np.array(img).transpose(2, 0, 1)
    img = np.expand_dims(img, axis=0)

    return img


def read_images(filenames):
    return [read_image(f) for f in filenames]
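
A minimal sketch of how these helpers fit together, with hypothetical image file names:

download_training_data()  # creates ./data with the train/test .rec and .lst files

# load a couple of raw images into the (1, 3, H, W) layout the network expects
batches = read_images(['plane.png', 'dog.png'])  # hypothetical file names
print(batches[0].shape)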
