|
| 1 | +# Standard Library |
| 2 | +import argparse |
| 3 | +import random |
| 4 | + |
| 5 | +# Third Party |
| 6 | +import mxnet as mx |
| 7 | +import numpy as np |
| 8 | +from mxnet import autograd, gluon |
| 9 | +from mxnet.gluon import nn |
| 10 | + |
| 11 | + |
def _str_to_bool(value):
    """Convert a command-line string into a boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` -- turns into ``True``.
    This converter recognizes the usual spellings explicitly instead.

    Args:
        value: Raw option value (a string), or an already-converted bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False: it was the only value that produced
    # False under the previous ``type=bool`` behavior.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % (value,))


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the attributes ``batch_size`` (int),
        ``epochs`` (int), ``learning_rate`` (float), ``context`` (str) and
        ``validate`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Train an MXNet Gluon model on the MNIST dataset"
    )
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    parser.add_argument("--epochs", type=int, default=1, help="Number of Epochs")
    # Both spellings are accepted: the dashed one matches --batch-size, the
    # underscored one is kept so existing invocations continue to work.
    parser.add_argument(
        "--learning-rate",
        "--learning_rate",
        dest="learning_rate",
        type=float,
        default=0.1,
        help="Learning rate",
    )
    parser.add_argument(
        "--context", type=str, default="cpu", help="Context can be either cpu or gpu"
    )
    # type=bool would treat "--validate False" as True, hence the converter.
    parser.add_argument(
        "--validate",
        type=_str_to_bool,
        default=True,
        help="Run validation if running with smdebug",
    )

    return parser.parse_args()
| 28 | + |
| 29 | + |
def test(ctx, net, val_data):
    """Evaluate ``net`` on ``val_data`` and return the accumulated accuracy.

    Args:
        ctx: Context that each batch is copied to before the forward pass.
        net: Callable model; invoked as ``net(data)`` on every batch.
        val_data: Iterable yielding ``(data, label)`` batches.

    Returns:
        Whatever ``mx.metric.Accuracy.get()`` returns after all batches have
        been folded in (the caller in this file unpacks it as a
        ``(name, value)`` pair).
    """
    accuracy = mx.metric.Accuracy()
    for data, label in val_data:
        batch = data.as_in_context(ctx)
        targets = label.as_in_context(ctx)
        predictions = net(batch)
        accuracy.update([targets], [predictions])

    return accuracy.get()
| 39 | + |
| 40 | + |
def train_model(net, epochs, ctx, learning_rate, momentum, train_data, val_data):
    """Initialize ``net``, train it with SGD and print accuracy per epoch.

    For every epoch the training metric is reset, each batch of
    ``train_data`` drives one optimizer step, and afterwards ``test`` is run
    on ``val_data``. Progress is reported with ``print``; nothing is returned.

    Args:
        net: Gluon block to initialize and train.
        epochs: Number of passes over ``train_data``.
        ctx: Context used for parameter initialization and for each batch.
        learning_rate: Learning rate handed to the SGD trainer.
        momentum: Momentum handed to the SGD trainer.
        train_data: Iterable yielding ``(data, label)`` training batches.
        val_data: Iterable of ``(data, label)`` batches forwarded to ``test``.
    """
    # Initialization covers the parameters of net and of all its children.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    # The trainer applies the gradient updates to the collected parameters.
    optimizer_params = {"learning_rate": learning_rate, "momentum": momentum}
    trainer = gluon.Trainer(net.collect_params(), "sgd", optimizer_params)
    train_metric = mx.metric.Accuracy()
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    for epoch in range(epochs):
        # Accuracy is tracked per epoch, so start each one from scratch.
        train_metric.reset()
        for batch_idx, (data, label) in enumerate(train_data):
            # Move the batch to the target context if it is not there yet.
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            # Only operations executed under record() are tracked for
            # differentiation by the subsequent backward() call.
            with autograd.record():
                output = net(data)
                batch_loss = loss_fn(output, label)
            batch_loss.backward()
            # The step is normalized by the number of samples in the batch.
            trainer.step(data.shape[0])
            train_metric.update([label], [output])

            # Report running accuracy every 100 batches, skipping batch 0.
            if batch_idx > 0 and batch_idx % 100 == 0:
                name, acc = train_metric.get()
                print("[Epoch %d Batch %d] Training: %s=%f" % (epoch, batch_idx, name, acc))

        name, acc = train_metric.get()
        print("[Epoch %d] Training: %s=%f" % (epoch, name, acc))
        name, val_acc = test(ctx, net, val_data)
        print("[Epoch %d] Validation: %s=%f" % (epoch, name, val_acc))
| 77 | + |
| 78 | + |
def transformer(data, label):
    """Flatten a sample and rescale it; the label passes through untouched.

    Args:
        data: Array-like object offering ``reshape`` and ``astype``.
        label: Label belonging to ``data``; returned unchanged.

    Returns:
        Tuple ``(data, label)`` where ``data`` has been reshaped to one
        dimension, cast to float32 and divided by 255.
    """
    flattened = data.reshape((-1,))
    scaled = flattened.astype(np.float32) / 255
    return scaled, label
| 82 | + |
| 83 | + |
def prepare_data(batch_size, data_dir="/tmp"):
    """Build the MNIST training and validation data loaders.

    Both datasets are created with ``transformer`` as their transform. The
    training loader shuffles and is configured with ``last_batch="discard"``;
    the validation loader keeps the dataset order.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Root directory handed to ``gluon.data.vision.MNIST``.
            Defaults to ``"/tmp"``, the location that used to be hard-coded.

    Returns:
        Tuple ``(train_data, val_data)`` of ``gluon.data.DataLoader`` objects.
    """
    train_data = gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=True, transform=transformer),
        batch_size=batch_size,
        shuffle=True,
        last_batch="discard",
    )

    val_data = gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=False, transform=transformer),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_data, val_data
| 98 | + |
| 99 | + |
# Create a model using the Gluon API. The hook currently
# supports MXNet Gluon models only.
def create_gluon_model():
    """Return a three-layer dense network built with ``nn.Sequential``.

    The layers are, in order: Dense(128, relu), Dense(64, relu) and Dense(10)
    without an activation. All of them are created inside the network's
    name scope.
    """
    # (units, activation) for each layer, in the order they are stacked.
    layer_specs = ((128, "relu"), (64, "relu"), (10, None))

    net = nn.Sequential()
    with net.name_scope():
        for units, activation in layer_specs:
            net.add(nn.Dense(units, activation=activation))
    return net
| 109 | + |
| 110 | + |
def main():
    """Run the script: parse options, seed RNGs, build data and model, train."""
    opt = parse_args()

    # Seed the MXNet, Python and NumPy random generators with fixed values.
    mx.random.seed(128)
    random.seed(12)
    np.random.seed(2)

    # Any value other than "cpu" (compared case-insensitively) selects mx.gpu().
    if opt.context.lower() == "cpu":
        context = mx.cpu()
    else:
        context = mx.gpu()

    # Build the Gluon model and the data loaders.
    net = create_gluon_model()
    train_data, val_data = prepare_data(opt.batch_size)

    # NOTE(review): opt.validate is parsed but not consulted here; validation
    # runs unconditionally inside train_model.
    train_model(
        net=net,
        epochs=opt.epochs,
        ctx=context,
        learning_rate=opt.learning_rate,
        momentum=0.9,
        train_data=train_data,
        val_data=val_data,
    )


if __name__ == "__main__":
    main()
0 commit comments