|
| 1 | +from __future__ import print_function |
| 2 | + |
| 3 | +import json |
| 4 | +import logging |
| 5 | +import os |
| 6 | +import time |
| 7 | + |
| 8 | +import mxnet as mx |
| 9 | +from mxnet import autograd as ag |
| 10 | +from mxnet import gluon |
| 11 | +from mxnet.gluon.model_zoo import vision as models |
| 12 | + |
| 13 | + |
| 14 | +# ------------------------------------------------------------ # |
| 15 | +# Training methods # |
| 16 | +# ------------------------------------------------------------ # |
| 17 | + |
| 18 | +def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir, hyperparameters, **kwargs): |
| 19 | + # retrieve the hyperparameters we set in notebook (with some defaults) |
| 20 | + batch_size = hyperparameters.get('batch_size', 128) |
| 21 | + epochs = hyperparameters.get('epochs', 100) |
| 22 | + learning_rate = hyperparameters.get('learning_rate', 0.1) |
| 23 | + momentum = hyperparameters.get('momentum', 0.9) |
| 24 | + log_interval = hyperparameters.get('log_interval', 1) |
| 25 | + wd = hyperparameters.get('wd', 0.0001) |
| 26 | + |
| 27 | + if len(hosts) == 1: |
| 28 | + kvstore = 'device' if num_gpus > 0 else 'local' |
| 29 | + else: |
| 30 | + kvstore = 'dist_device_sync' |
| 31 | + |
| 32 | + ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] |
| 33 | + net = models.get_model('resnet34_v2', ctx=ctx, pretrained=False, classes=10) |
| 34 | + batch_size *= max(1, len(ctx)) |
| 35 | + |
| 36 | + # load training and validation data |
| 37 | + # we use the gluon.data.vision.CIFAR10 class because of its built in pre-processing logic, |
| 38 | + # but point it at the location where SageMaker placed the data files, so it doesn't download them again. |
| 39 | + |
| 40 | + part_index = 0 |
| 41 | + for i, host in enumerate(hosts): |
| 42 | + if host == current_host: |
| 43 | + part_index = i |
| 44 | + break |
| 45 | + |
| 46 | + |
| 47 | + data_dir = channel_input_dirs['training'] |
| 48 | + train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32), |
| 49 | + num_parts=len(hosts), part_index=part_index) |
| 50 | + test_data = get_test_data(num_cpus, data_dir, batch_size, (3, 32, 32)) |
| 51 | + |
| 52 | + # Collect all parameters from net and its children, then initialize them. |
| 53 | + net.initialize(mx.init.Xavier(magnitude=2), ctx=ctx) |
| 54 | + # Trainer is for updating parameters with gradient. |
| 55 | + trainer = gluon.Trainer(net.collect_params(), 'sgd', |
| 56 | + optimizer_params={'learning_rate': learning_rate, 'momentum': momentum, 'wd': wd}, |
| 57 | + kvstore=kvstore) |
| 58 | + metric = mx.metric.Accuracy() |
| 59 | + loss = gluon.loss.SoftmaxCrossEntropyLoss() |
| 60 | + |
| 61 | + best_accuracy = 0.0 |
| 62 | + for epoch in range(epochs): |
| 63 | + # reset data iterator and metric at begining of epoch. |
| 64 | + train_data.reset() |
| 65 | + tic = time.time() |
| 66 | + metric.reset() |
| 67 | + btic = time.time() |
| 68 | + |
| 69 | + for i, batch in enumerate(train_data): |
| 70 | + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) |
| 71 | + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) |
| 72 | + outputs = [] |
| 73 | + Ls = [] |
| 74 | + with ag.record(): |
| 75 | + for x, y in zip(data, label): |
| 76 | + z = net(x) |
| 77 | + L = loss(z, y) |
| 78 | + # store the loss and do backward after we have done forward |
| 79 | + # on all GPUs for better speed on multiple GPUs. |
| 80 | + Ls.append(L) |
| 81 | + outputs.append(z) |
| 82 | + for L in Ls: |
| 83 | + L.backward() |
| 84 | + trainer.step(batch.data[0].shape[0]) |
| 85 | + metric.update(label, outputs) |
| 86 | + if i % log_interval == 0 and i > 0: |
| 87 | + name, acc = metric.get() |
| 88 | + logging.info('Epoch [%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' % |
| 89 | + (epoch, i, batch_size / (time.time() - btic), name, acc)) |
| 90 | + btic = time.time() |
| 91 | + |
| 92 | + name, acc = metric.get() |
| 93 | + logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc)) |
| 94 | + logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic)) |
| 95 | + |
| 96 | + name, val_acc = test(ctx, net, test_data) |
| 97 | + logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc)) |
| 98 | + |
| 99 | + # only save params on primary host |
| 100 | + if current_host == hosts[0]: |
| 101 | + if val_acc > best_accuracy: |
| 102 | + net.save_params('{}/model-{:0>4}.params'.format(model_dir, epoch)) |
| 103 | + best_accuracy = val_acc |
| 104 | + |
| 105 | + return net |
| 106 | + |
| 107 | + |
| 108 | +def save(net, model_dir): |
| 109 | + # model_dir will be empty except on primary container |
| 110 | + files = os.listdir(model_dir) |
| 111 | + if files: |
| 112 | + best = sorted(os.listdir(model_dir))[-1] |
| 113 | + os.rename(os.path.join(model_dir, best), os.path.join(model_dir, 'model.params')) |
| 114 | + |
| 115 | + |
| 116 | +def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1, num_parts=1, part_index=0): |
| 117 | + return mx.io.ImageRecordIter( |
| 118 | + path_imgrec=path, |
| 119 | + resize=resize, |
| 120 | + data_shape=data_shape, |
| 121 | + batch_size=batch_size, |
| 122 | + rand_crop=augment, |
| 123 | + rand_mirror=augment, |
| 124 | + preprocess_threads=num_cpus, |
| 125 | + num_parts=num_parts, |
| 126 | + part_index=part_index) |
| 127 | + |
| 128 | + |
| 129 | +def get_test_data(num_cpus, data_dir, batch_size, data_shape, resize=-1): |
| 130 | + return get_data(os.path.join(data_dir, "test.rec"), False, num_cpus, batch_size, data_shape, resize, 1, 0) |
| 131 | + |
| 132 | + |
| 133 | +def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1, num_parts=1, part_index=0): |
| 134 | + return get_data(os.path.join(data_dir, "train.rec"), True, num_cpus, batch_size, data_shape, resize, num_parts, |
| 135 | + part_index) |
| 136 | + |
| 137 | + |
| 138 | +def test(ctx, net, test_data): |
| 139 | + test_data.reset() |
| 140 | + metric = mx.metric.Accuracy() |
| 141 | + |
| 142 | + for i, batch in enumerate(test_data): |
| 143 | + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) |
| 144 | + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) |
| 145 | + outputs = [] |
| 146 | + for x in data: |
| 147 | + outputs.append(net(x)) |
| 148 | + metric.update(label, outputs) |
| 149 | + return metric.get() |
| 150 | + |
| 151 | + |
| 152 | +# ------------------------------------------------------------ # |
| 153 | +# Hosting methods # |
| 154 | +# ------------------------------------------------------------ # |
| 155 | + |
| 156 | +def model_fn(model_dir): |
| 157 | + """ |
| 158 | + Load the gluon model. Called once when hosting service starts. |
| 159 | +
|
| 160 | + :param: model_dir The directory where model files are stored. |
| 161 | + :return: a model (in this case a Gluon network) |
| 162 | + """ |
| 163 | + |
| 164 | + net = models.get_model('resnet34_v2', ctx=mx.cpu(), pretrained=False, classes=10) |
| 165 | + net.load_params('%s/model.params' % model_dir, ctx=mx.cpu()) |
| 166 | + return net |
| 167 | + |
| 168 | + |
| 169 | +def transform_fn(net, data, input_content_type, output_content_type): |
| 170 | + """ |
| 171 | + Transform a request using the Gluon model. Called once per request. |
| 172 | +
|
| 173 | + :param net: The Gluon model. |
| 174 | + :param data: The request payload. |
| 175 | + :param input_content_type: The request content type. |
| 176 | + :param output_content_type: The (desired) response content type. |
| 177 | + :return: response payload and content type. |
| 178 | + """ |
| 179 | + # we can use content types to vary input/output handling, but |
| 180 | + # here we just assume json for both |
| 181 | + parsed = json.loads(data) |
| 182 | + nda = mx.nd.array(parsed) |
| 183 | + output = net(nda) |
| 184 | + prediction = mx.nd.argmax(output, axis=1) |
| 185 | + response_body = json.dumps(prediction.asnumpy().tolist()[0]) |
| 186 | + return response_body, output_content_type |
0 commit comments