Commit 7a2618a

Update notebooks for MXNet 1.3 (aws#451)

1 parent 13abcfb commit 7a2618a

File tree

5 files changed: +135 -77 lines

sagemaker-python-sdk/mxnet_gluon_sentiment/mxnet_sentiment_analysis_with_gluon.ipynb (+17 -27)
@@ -13,9 +13,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "import os\n",
@@ -68,9 +66,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-sentiment')"
@@ -90,9 +86,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "!cat 'sentiment.py'"
@@ -110,21 +104,21 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "m = MXNet(\"sentiment.py\",\n",
+    "m = MXNet('sentiment.py',\n",
     "          role=role,\n",
     "          train_instance_count=1,\n",
-    "          train_instance_type=\"ml.c4.2xlarge\",\n",
-    "          framework_version=\"1.2.1\",\n",
-    "          hyperparameters={'batch_size': 8,\n",
-    "                           'epochs': 2,\n",
-    "                           'learning_rate': 0.01,\n",
-    "                           'embedding_size': 50, \n",
-    "                           'log_interval': 1000})"
+    "          train_instance_type='ml.c4.2xlarge',\n",
+    "          framework_version='1.3.0',\n",
+    "          py_version='py2',\n",
+    "          launch_parameter_server=True,\n",
+    "          hyperparameters={'batch-size': 8,\n",
+    "                           'epochs': 2,\n",
+    "                           'learning-rate': 0.01,\n",
+    "                           'embedding-size': 50, \n",
+    "                           'log-interval': 1000})"
    ]
   },
   {
@@ -137,9 +131,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "m.fit(inputs)"
@@ -189,7 +181,7 @@
     " \"the movie was so enthralling !\"]\n",
     "\n",
     "response = predictor.predict(data)\n",
-    "print response"
+    "print(response)"
    ]
   },
   {
@@ -204,9 +196,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "sagemaker.Session().delete_endpoint(predictor.endpoint)"

sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py (+45 -16)
@@ -1,17 +1,21 @@
 from __future__ import print_function

+import argparse
+import bisect
+from collections import Counter
+from itertools import chain, islice
+import json
 import logging
+import time
+import random
+import os
+
 import mxnet as mx
 from mxnet import gluon, autograd, nd
-from mxnet.gluon import nn
-import numpy as np
-import json
-import time
-import re
 from mxnet.io import DataIter, DataBatch, DataDesc
-import bisect, random
-from collections import Counter
-from itertools import chain, islice
+import numpy as np
+
+from sagemaker_mxnet_container.training_utils import scheduler_host


 logging.basicConfig(level=logging.DEBUG)
@@ -20,22 +24,16 @@
 # Training methods #
 # ------------------------------------------------------------ #

-def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir, hyperparameters, **kwargs):
-    # retrieve the hyperparameters we set in notebook (with some defaults)
-    batch_size = hyperparameters.get('batch_size', 8)
-    epochs = hyperparameters.get('epochs', 2)
-    learning_rate = hyperparameters.get('learning_rate', 0.01)
-    log_interval = hyperparameters.get('log_interval', 1000)
-    embedding_size = hyperparameters.get('embedding_size', 50)

+def train(current_host, hosts, num_cpus, num_gpus, training_dir, model_dir,
+          batch_size, epochs, learning_rate, log_interval, embedding_size):
     if len(hosts) == 1:
         kvstore = 'device' if num_gpus > 0 else 'local'
     else:
         kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'

     ctx = mx.gpu() if num_gpus > 0 else mx.cpu()

-    training_dir = channel_input_dirs['training']
     train_sentences, train_labels, _ = get_dataset(training_dir + '/train')
     val_sentences, val_labels, _ = get_dataset(training_dir + '/test')

@@ -312,6 +310,37 @@ def test(ctx, net, val_data):
     return metric.get()


+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    # retrieve the hyperparameters we set in notebook (with some defaults)
+    parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--epochs', type=int, default=2)
+    parser.add_argument('--learning-rate', type=float, default=0.01)
+    parser.add_argument('--log-interval', type=int, default=1000)
+    parser.add_argument('--embedding-size', type=int, default=50)
+
+    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
+    parser.add_argument('--training_channel', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
+
+    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
+    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    num_cpus = int(os.environ['SM_NUM_CPUS'])
+    num_gpus = int(os.environ['SM_NUM_GPUS'])
+
+    model = train(args.current_host, args.hosts, num_cpus, num_gpus, args.training_channel, args.model_dir,
+                  args.batch_size, args.epochs, args.learning_rate, args.log_interval, args.embedding_size)
+
+    if args.current_host == scheduler_host(args.hosts):
+        save(model, args.model_dir)
+
+
 # ------------------------------------------------------------ #
 # Hosting methods #
 # ------------------------------------------------------------ #
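Because the script now reads its configuration from the SM_* environment variables that the SageMaker container sets, it can be exercised outside SageMaker by stubbing those variables. A hedged sketch of such a local smoke test (the paths and host name are made up, and it assumes sagemaker_mxnet_container and the notebook's data/ directory are available locally):

import json
import os
import subprocess

# Stub the environment the SageMaker MXNet container would provide.
os.makedirs('/tmp/model', exist_ok=True)
env = dict(os.environ,
           SM_MODEL_DIR='/tmp/model',
           SM_CHANNEL_TRAINING='data',   # the script expects data/train and data/test
           SM_CURRENT_HOST='algo-1',
           SM_HOSTS=json.dumps(['algo-1']),
           SM_NUM_CPUS='4',
           SM_NUM_GPUS='0')

# One short epoch is enough to verify the argument and environment plumbing.
subprocess.check_call(['python', 'sentiment.py', '--epochs', '1'], env=env)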

sagemaker-python-sdk/mxnet_mnist/mnist.py (+65 -20)
@@ -1,10 +1,14 @@
+import argparse
+import gzip
+import json
 import logging
+import os
+import struct

-import gzip
 import mxnet as mx
 import numpy as np
-import os
-import struct
+
+from sagemaker_mxnet_container.training_utils import scheduler_host


 def load_data(path):
@@ -35,39 +39,80 @@ def build_graph():
     return mx.sym.SoftmaxOutput(data=fc3, name='softmax')


-def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, num_gpus):
-    (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
-    (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
+def get_training_context(num_gpus):
+    if num_gpus:
+        return [mx.gpu(i) for i in range(num_gpus)]
+    else:
+        return mx.cpu()
+
+
+def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
+          hosts, current_host, model_dir):
+    (train_labels, train_images) = load_data(training_channel)
+    (test_labels, test_images) = load_data(testing_channel)

-    # Alternatively to splitting in memory, the data could be pre-split in S3 and use ShardedByS3Key
-    # to do parallel training.
+    # Data parallel training - shard the data so each host
+    # only trains on a subset of the total data.
     shard_size = len(train_images) // len(hosts)
     for i, host in enumerate(hosts):
         if host == current_host:
             start = shard_size * i
             end = start + shard_size
             break

-    batch_size = 100
-    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size, shuffle=True)
+    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
+                                   shuffle=True)
     val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)
+
     logging.getLogger().setLevel(logging.DEBUG)
+
     kvstore = 'local' if len(hosts) == 1 else 'dist_sync'
-    mlp_model = mx.mod.Module(
-        symbol=build_graph(),
-        context=get_train_context(num_cpus, num_gpus))
+
+    mlp_model = mx.mod.Module(symbol=build_graph(),
+                              context=get_training_context(num_gpus))
     mlp_model.fit(train_iter,
                   eval_data=val_iter,
                   kvstore=kvstore,
                   optimizer='sgd',
-                  optimizer_params={'learning_rate': float(hyperparameters.get("learning_rate", 0.1))},
+                  optimizer_params={'learning_rate': learning_rate},
                   eval_metric='acc',
                   batch_end_callback=mx.callback.Speedometer(batch_size, 100),
-                  num_epoch=25)
-    return mlp_model
+                  num_epoch=epochs)
+
+    if current_host == scheduler_host(hosts):
+        save(model_dir, mlp_model)
+
+
+def save(model_dir, model):
+    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
+    model.save_params(os.path.join(model_dir, 'model-0000.params'))
+
+    signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
+                 for data_desc in model.data_shapes]
+    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
+        json.dump(signature, f)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--batch-size', type=int, default=100)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--learning-rate', type=float, default=0.1)
+
+    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
+    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
+    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])
+
+    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
+    parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
+
+    return parser.parse_args()
+

+if __name__ == '__main__':
+    args = parse_args()
+    num_gpus = int(os.environ['SM_NUM_GPUS'])

-def get_train_context(num_cpus, num_gpus):
-    if num_gpus > 0:
-        return mx.gpu()
-    return mx.cpu()
+    train(args.batch_size, args.epochs, args.learning_rate, num_gpus, args.train, args.test,
+          args.hosts, args.current_host, args.model_dir)
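The new save() writes MXNet's standard checkpoint layout (model-symbol.json plus model-0000.params) along with a model-shapes.json input signature. As a rough sketch of how those artifacts can be reloaded for inference, assuming a placeholder model_dir (this mirrors MXNet's Module.load checkpoint convention and is not code from this commit):

import json
import os

import mxnet as mx

model_dir = '/tmp/model'  # placeholder path

# The file names follow MXNet's checkpoint convention, so Module.load
# can restore the symbol and parameters with prefix 'model', epoch 0.
mod = mx.mod.Module.load(os.path.join(model_dir, 'model'), 0)

with open(os.path.join(model_dir, 'model-shapes.json')) as f:
    signature = json.load(f)

# Bind using the shapes recorded at training time before calling predict().
mod.bind(for_training=False,
         data_shapes=[(s['name'], tuple(s['shape'])) for s in signature])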

sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb (+6 -12)
@@ -17,7 +17,6 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true,
     "isConfigCell": true
    },
    "outputs": [],
@@ -79,9 +78,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "from sagemaker.mxnet import MXNet\n",
@@ -92,8 +89,9 @@
     "          code_location=custom_code_upload_location,\n",
     "          train_instance_count=1,\n",
     "          train_instance_type='ml.m4.xlarge',\n",
-    "          framework_version='1.2.1',\n",
-    "          hyperparameters={'learning_rate': 0.1})"
+    "          framework_version='1.3.0',\n",
+    "          launch_parameter_server=True,\n",
+    "          hyperparameters={'learning-rate': 0.1})"
    ]
   },
   {
@@ -219,9 +217,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "print(\"Endpoint name: \" + predictor.endpoint)"
@@ -230,9 +226,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "collapsed": true
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "import sagemaker\n",

sagemaker-python-sdk/mxnet_mnist/mxnet_mnist_with_batch_transform.ipynb (+2 -2)
@@ -92,8 +92,8 @@
     "          code_location=custom_code_upload_location,\n",
     "          train_instance_count=1,\n",
     "          train_instance_type='ml.m4.xlarge',\n",
-    "          framework_version='1.2.1',\n",
-    "          hyperparameters={'learning_rate': 0.1})"
+    "          framework_version='1.3.0',\n",
+    "          hyperparameters={'learning-rate': 0.1})"
    ]
   },
   {
