Commit a240855

Merge pull request aws#191 from iquintero/mxnet_distributed
Ensure MXNet notebooks run in distributed mode.
2 parents 11f3a63 + 4196324 · commit a240855
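The notebooks only exercise the distributed code paths below when the training job runs on more than one host. As a rough sketch of how a notebook requests that (not part of this commit; the IAM role, S3 location, and hyperparameter names are placeholders, and the pre-2.x SageMaker Python SDK estimator arguments are assumed):

```python
# Hypothetical launcher cell: with train_instance_count > 1, SageMaker starts several
# hosts and passes `hosts` and `current_host` into the train() functions changed below.
from sagemaker.mxnet import MXNet

estimator = MXNet(entry_point='cifar10.py',            # one of the scripts touched by this PR
                  role='SageMakerRole',                # placeholder IAM role
                  train_instance_count=2,              # two hosts -> distributed code paths run
                  train_instance_type='ml.m4.xlarge',
                  hyperparameters={'batch_size': 128, 'epochs': 10})

estimator.fit('s3://my-bucket/cifar10')                # placeholder S3 input location
```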

4 files changed: +54 −10 lines changed

sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py (+18 −6)

@@ -36,8 +36,17 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
     # load training and validation data
     # we use the gluon.data.vision.CIFAR10 class because of its built in pre-processing logic,
     # but point it at the location where SageMaker placed the data files, so it doesn't download them again.
+
+    part_index = 0
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            part_index = i
+            break
+
+
     data_dir = channel_input_dirs['training']
-    train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32))
+    train_data = get_train_data(num_cpus, data_dir, batch_size, (3, 32, 32),
+                                num_parts=len(hosts), part_index=part_index)
     test_data = get_test_data(num_cpus, data_dir, batch_size, (3, 32, 32))

     # Collect all parameters from net and its children, then initialize them.

@@ -104,23 +113,26 @@ def save(net, model_dir):
     os.rename(os.path.join(model_dir, best), os.path.join(model_dir, 'model.params'))


-def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1):
+def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
     return mx.io.ImageRecordIter(
         path_imgrec=path,
         resize=resize,
         data_shape=data_shape,
         batch_size=batch_size,
         rand_crop=augment,
         rand_mirror=augment,
-        preprocess_threads=num_cpus)
+        preprocess_threads=num_cpus,
+        num_parts=num_parts,
+        part_index=part_index)


 def get_test_data(num_cpus, data_dir, batch_size, data_shape, resize=-1):
-    return get_data(os.path.join(data_dir, "test.rec"), False, num_cpus, batch_size, data_shape, resize)
+    return get_data(os.path.join(data_dir, "test.rec"), False, num_cpus, batch_size, data_shape, resize, 1, 0)


-def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1):
-    return get_data(os.path.join(data_dir, "train.rec"), True, num_cpus, batch_size, data_shape, resize)
+def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
+    return get_data(os.path.join(data_dir, "train.rec"), True, num_cpus, batch_size, data_shape, resize, num_parts,
+                    part_index)


 def test(ctx, net, test_data):
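The CIFAR-10 change pushes sharding down into MXNet's RecordIO reader: each host builds its ImageRecordIter with num_parts=len(hosts) and its own part_index, so it only decodes its slice of train.rec and never loads the full dataset. A standalone sketch of the same idea, assuming a local train.rec file exists (the host names are placeholders for what SageMaker passes to train()):

```python
import mxnet as mx

hosts = ['algo-1', 'algo-2']               # placeholder for the host list SageMaker passes in
current_host = 'algo-2'                    # placeholder for this container's host name
part_index = hosts.index(current_host)     # equivalent to the enumerate/break loop in the diff

# Each host iterates over only its 1/len(hosts) share of the RecordIO file.
train_data = mx.io.ImageRecordIter(
    path_imgrec='train.rec',               # assumed local path to the CIFAR-10 record file
    data_shape=(3, 32, 32),
    batch_size=128,
    num_parts=len(hosts),                  # split the record file into this many parts
    part_index=part_index)                 # ...and read this host's part only
```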

sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py (+14 −1)

@@ -16,7 +16,7 @@
 # ------------------------------------------------------------ #


-def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
+def train(current_host, channel_input_dirs, hyperparameters, hosts, num_gpus):
     # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
     # the current container environment, but here we just use simple cpu context.
     ctx = mx.cpu()

@@ -53,6 +53,19 @@ def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
     metric = mx.metric.Accuracy()
     loss = gluon.loss.SoftmaxCrossEntropyLoss()

+    # shard the training data in case we are doing distributed training. Alternatively to splitting in memory,
+    # the data could be pre-split in S3 and use ShardedByS3Key to do distributed training.
+    if len(hosts) > 1:
+        train_data = [x for x in train_data]
+        shard_size = len(train_data) // len(hosts)
+        for i, host in enumerate(hosts):
+            if host == current_host:
+                start = shard_size * i
+                end = start + shard_size
+                break
+
+        train_data = train_data[start:end]
+
     for epoch in range(epochs):
         # reset data iterator and metric at begining of epoch.
         metric.reset()
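The Gluon MNIST script shards in memory instead: it materializes the DataLoader as a list of batches and keeps one contiguous slice per host. A toy sketch of the same slice arithmetic (plain integers stand in for batches so it runs without the MNIST data):

```python
hosts = ['algo-1', 'algo-2', 'algo-3']       # placeholder host list
current_host = 'algo-2'

train_data = list(range(10))                 # stand-in for the list of Gluon batches
shard_size = len(train_data) // len(hosts)   # 3 batches per host with 10 batches and 3 hosts
start = shard_size * hosts.index(current_host)
end = start + shard_size

print(train_data[start:end])                 # algo-2 keeps batches [3, 4, 5]
```

Note that with the integer division, up to len(hosts) − 1 trailing batches (batch 9 in this toy case) are not assigned to any host; for balanced data that loss is negligible, but it is a property of this slicing scheme.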

sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py (+10 −1)

@@ -46,7 +46,16 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
     train_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in train_sentences]
     val_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in val_sentences]

-    train_iterator = BucketSentenceIter(train_sentences, train_labels, batch_size)
+    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
+    # to do parallel training.
+    shard_size = len(train_sentences) // len(hosts)
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            start = shard_size * i
+            end = start + shard_size
+            break
+
+    train_iterator = BucketSentenceIter(train_sentences[start:end], train_labels[start:end], batch_size)
     val_iterator = BucketSentenceIter(val_sentences, val_labels, batch_size)

     # define the network
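The comment added here points at the alternative to in-memory slicing: pre-split the training objects in S3 and let SageMaker hand each host a disjoint subset. A sketch of how such a channel could be declared with the pre-2.x SDK (the S3 prefix is a placeholder; in SDK 2.x the equivalent class is sagemaker.inputs.TrainingInput):

```python
# Sketch of the ShardedByS3Key alternative mentioned in the comment above (not part of this diff).
from sagemaker.session import s3_input

train_channel = s3_input('s3://my-bucket/sentiment/train/',   # placeholder prefix holding pre-split files
                         distribution='ShardedByS3Key')       # each host downloads a disjoint set of objects

# estimator.fit({'training': train_channel})  # the script would then need no start/end slicing
```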

sagemaker-python-sdk/mxnet_mnist/mnist.py (+12 −2)

@@ -35,11 +35,21 @@ def build_graph():
     return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


-def train(channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus, **kwargs):
+def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
     (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
     (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
+
+    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
+    # to do parallel training.
+    shard_size = len(train_images) // len(hosts)
+    for i, host in enumerate(hosts):
+        if host == current_host:
+            start = shard_size * i
+            end = start + shard_size
+            break
+
     batch_size = 100
-    train_iter = mx.io.NDArrayIter(train_images, train_labels, batch_size, shuffle=True)
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
     val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
     logging.getLogger().setLevel(logging.DEBUG)
     kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
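The symbolic MNIST script pairs the same slicing with MXNet's distributed key-value store: the last context line switches kvstore from 'local' to 'dist_sync' as soon as there is more than one host, which is what keeps the per-host gradients in sync. A minimal sketch of where that string ends up (toy random data replaces the sharded MNIST arrays, and a single host is used so the example runs locally; 'dist_sync' only works under a distributed launcher such as the one SageMaker sets up):

```python
import numpy as np
import mxnet as mx

hosts = ['algo-1']                                       # single host here so this runs without a cluster
kvstore = 'local' if len(hosts) == 1 else 'dist_sync'    # 'dist_sync' averages gradients across hosts

data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=10)
net = mx.sym.SoftmaxOutput(fc, name='softmax')

x = np.random.rand(100, 784).astype('float32')           # toy stand-in for train_images[start:end]
y = np.random.randint(0, 10, 100).astype('float32')      # toy stand-in for train_labels[start:end]
train_iter = mx.io.NDArrayIter(x, y, batch_size=10, shuffle=True)

mod = mx.mod.Module(symbol=net, context=mx.cpu())
mod.fit(train_iter, kvstore=kvstore, optimizer='sgd', num_epoch=1)
```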
