@@ -1,10 +1,14 @@
+import argparse
+import gzip
+import json
 import logging
+import os
+import struct
 
-import gzip
 import mxnet as mx
 import numpy as np
-import os
-import struct
+
+from sagemaker_mxnet_container.training_utils import scheduler_host
 
 
 def load_data(path):
@@ -35,39 +39,80 @@ def build_graph():
     return mx.sym.SoftmaxOutput(data=fc3, name='softmax')
 
 
-def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
-    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
-    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
+def get_training_context(num_gpus):
+    if num_gpus:
+        return [mx.gpu(i) for i in range(num_gpus)]
+    else:
+        return mx.cpu()
+
+
+def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
+          hosts, current_host, model_dir):
+    (train_labels, train_images) = load_data(training_channel)
+    (test_labels, test_images) = load_data(testing_channel)
 
-    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
-    # to do parallel training.
+    # Data parallel training - shard the data so each host
+    # only trains on a subset of the total data.
     shard_size = len(train_images) // len(hosts)
    for i, host in enumerate(hosts):
         if host == current_host:
             start = shard_size * i
             end = start + shard_size
             break
 
-    batch_size = 100
-    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
+                                   shuffle=True)
     val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
+
     logging.getLogger().setLevel(logging.DEBUG)
+
     kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
-    mlp_model = mx.mod.Module(
-        symbol=build_graph(),
-        context=get_train_context(num_cpus, num_gpus))
+
+    mlp_model = mx.mod.Module(symbol=build_graph(),
+                              context=get_training_context(num_gpus))
     mlp_model.fit(train_iter,
                   eval_data=val_iter,
                   kvstore=kvstore,
                   optimizer='sgd',
-                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
+                  optimizer_params={'learning_rate': learning_rate},
                   eval_metric='acc',
                   batch_end_callback=mx.callback.Speedometer(batch_size, 100),
-                  num_epoch=25)
-    return mlp_model
+                  num_epoch=epochs)
+
+    if current_host == scheduler_host(hosts):
+        save(model_dir, mlp_model)
+
+
+def save(model_dir, model):
+    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
+    model.save_params(os.path.join(model_dir, 'model-0000.params'))
+
+    signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
+                 for data_desc in model.data_shapes]
+    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
+        json.dump(signature, f)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--batch-size', type=int, default=100)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--learning-rate', type=float, default=0.1)
+
+    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
+    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
+    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
+
+    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
+    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
+
+    return parser.parse_args()
+
 
+if __name__ == '__main__':
+    args = parse_args()
+    num_gpus = int(os.environ['SM_NUM_GPUS'])
 
-def get_train_context(num_cpus, num_gpus):
-    if num_gpus > 0:
-        return mx.gpu()
-    return mx.cpu()
+    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
+          args.hosts, args.current_host, args.model_dir)
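For context on how the new CLI arguments and SM_* environment variables get populated: in SageMaker script mode the training container sets SM_MODEL_DIR, SM_CHANNEL_TRAIN, SM_CHANNEL_TEST, SM_CURRENT_HOST, SM_HOSTS, and SM_NUM_GPUS, and forwards estimator hyperparameters to the script as "--key value" arguments. Below is a minimal launch sketch using the SageMaker Python SDK's MXNet estimator; it is not part of this change, and the entry-point filename, IAM role, S3 paths, and framework/Python versions are placeholder assumptions.

from sagemaker.mxnet import MXNet

# Hypothetical launch of the entry point in this diff (filename, role, bucket,
# and versions are illustrative only; parameter names assume SDK v2).
estimator = MXNet(entry_point='mnist.py',
                  role='arn:aws:iam::111122223333:role/SageMakerRole',
                  instance_count=2,                  # two hosts -> kvstore='dist_sync'
                  instance_type='ml.m5.xlarge',
                  framework_version='1.6.0',
                  py_version='py3',
                  hyperparameters={'batch-size': 100,
                                   'epochs': 10,
                                   'learning-rate': 0.1})

# The 'train' and 'test' channels become SM_CHANNEL_TRAIN / SM_CHANNEL_TEST inside
# the container; hyperparameters arrive as --batch-size, --epochs, --learning-rate.
estimator.fit({'train': 's3://my-bucket/mnist/train',
               'test': 's3://my-bucket/mnist/test'})

With two instances, SM_HOSTS lists two hostnames, so the script picks the 'dist_sync' kvstore, shards the training data per host, and only the scheduler host writes the model artifacts to SM_MODEL_DIR.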