Skip to content

Commit 8b4403f

Browse files
authored
Add notebook for ONNX export with MXNet 1.3 (aws#454)
1 parent 47e2b03 commit 8b4403f

File tree

2 files changed

+303
-0
lines changed

2 files changed

+303
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import argparse
2+
import gzip
3+
import json
4+
import logging
5+
import os
6+
import tempfile
7+
import shutil
8+
import struct
9+
10+
import mxnet as mx
11+
from mxnet.contrib import onnx as onnx_mxnet
12+
import numpy as np
13+
14+
from sagemaker_mxnet_container.training_utils import scheduler_host
15+
16+
17+
def load_data(path):
    """Load labels and images from the gzipped archives found under ``path``.

    ``path`` is searched recursively (via ``find_file``) for ``labels.gz`` and
    ``images.gz``.

    Args:
        path: Root directory that contains the two archives somewhere below it.

    Returns:
        A ``(labels, images)`` tuple. ``labels`` is a 1-D ``int8`` array and
        ``images`` is a ``float32`` array of shape ``(N, 1, rows, cols)`` whose
        values are the raw bytes divided by 255.

    Raises:
        IOError: If either archive cannot be found under ``path``.
    """
    labels_file = find_file(path, "labels.gz")
    images_file = find_file(path, "images.gz")
    for archive_name, located in (("labels.gz", labels_file), ("images.gz", images_file)):
        if located is None:
            # Fail with a readable message instead of letting gzip.open(None) blow up.
            raise IOError("Could not find {} under {}".format(archive_name, path))

    with gzip.open(labels_file) as flbl:
        # Skip the 8-byte header (two big-endian uint32 values; presumably the
        # IDX magic number and item count -- confirm against the data format).
        struct.unpack(">II", flbl.read(8))
        # np.frombuffer replaces the deprecated binary mode of np.fromstring;
        # copy() keeps the result writable, as it was before.
        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    with gzip.open(images_file) as fimg:
        # The first two header fields are ignored; the last two give the image size.
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        images = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(labels), rows, cols)
        # Add a single channel axis and scale the bytes to [0, 1]. Use the size
        # read from the header rather than a hard-coded 28x28.
        images = images.reshape(images.shape[0], 1, rows, cols).astype(np.float32) / 255
    return labels, images
26+
27+
28+
def find_file(root_path, file_name):
    """Return the path of the first ``file_name`` found below ``root_path``.

    The tree is traversed with ``os.walk``; the first directory that contains
    a file with the requested name wins.

    Args:
        root_path: Directory at which the recursive search starts.
        file_name: Exact base name of the file to look for.

    Returns:
        The joined path of the first match, or ``None`` when nothing matches.
    """
    matches = (
        os.path.join(current_dir, file_name)
        for current_dir, _subdirs, names in os.walk(root_path)
        if file_name in names
    )
    # next() stops at the first hit, mirroring an early return from a loop.
    return next(matches, None)
32+
33+
34+
def build_graph():
    """Build the symbolic network used for training.

    The input variable ``data`` is flattened, passed through two
    fully connected + ReLU stages (128 and 64 hidden units), then through a
    10-unit fully connected layer.

    Returns:
        The ``SoftmaxOutput`` symbol (named ``softmax``) that terminates the graph.
    """
    net = mx.sym.flatten(data=mx.sym.var('data'))
    # Hidden stages are created in the same order as a hand-unrolled version
    # (FC, ReLU, FC, ReLU), so the layers are constructed identically.
    for hidden_units in (128, 64):
        net = mx.sym.FullyConnected(data=net, num_hidden=hidden_units)
        net = mx.sym.Activation(data=net, act_type="relu")
    logits = mx.sym.FullyConnected(data=net, num_hidden=10)
    return mx.sym.SoftmaxOutput(data=logits, name='softmax')
43+
44+
45+
def get_training_context(num_gpus):
    """Pick the MXNet context(s) to train on.

    Args:
        num_gpus: Number of GPUs to use; a falsy value selects the CPU.

    Returns:
        ``mx.cpu()`` when ``num_gpus`` is falsy, otherwise a list with one
        ``mx.gpu(i)`` context per index in ``range(num_gpus)``.
    """
    if not num_gpus:
        return mx.cpu()
    return [mx.gpu(device_id) for device_id in range(num_gpus)]
50+
51+
52+
def train(batch_size, epochs, learning_rate, num_gpus, training_channel, testing_channel,
          hosts, current_host, model_dir):
    """Train the network on this host's shard of the data and save the result.

    Args:
        batch_size: Batch size for the training and validation iterators.
        epochs: Number of epochs passed to ``Module.fit``.
        learning_rate: Learning rate for the SGD optimizer.
        num_gpus: Number of GPUs to train on (falsy selects the CPU).
        training_channel: Directory holding the training archives.
        testing_channel: Directory holding the validation archives.
        hosts: List of all hosts taking part in the training job.
        current_host: Name of the host this process runs on; must be in ``hosts``.
        model_dir: Directory in which the exported model is written.

    Raises:
        ValueError: If ``current_host`` is not one of ``hosts``.
    """
    # Fail fast: without this check the shard bounds below would never be
    # defined and the function would die later with an unrelated NameError.
    if current_host not in hosts:
        raise ValueError(
            "current_host {!r} is not in hosts {!r}".format(current_host, hosts))

    (train_labels, train_images) = load_data(training_channel)
    (test_labels, test_images) = load_data(testing_channel)

    # Data parallel training - shard the data so each host
    # only trains on a subset of the total data.
    # Integer division means any remainder samples at the end are not used.
    shard_size = len(train_images) // len(hosts)
    start = shard_size * hosts.index(current_host)
    end = start + shard_size

    train_iter = mx.io.NDArrayIter(train_images[start:end], train_labels[start:end], batch_size,
                                   shuffle=True)
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)

    logging.getLogger().setLevel(logging.DEBUG)

    # A single host trains with a local kvstore; several hosts synchronize
    # through the distributed one.
    kvstore = 'local' if len(hosts) == 1 else 'dist_sync'

    mlp_model = mx.mod.Module(symbol=build_graph(),
                              context=get_training_context(num_gpus))
    mlp_model.fit(train_iter,
                  eval_data=val_iter,
                  kvstore=kvstore,
                  optimizer='sgd',
                  optimizer_params={'learning_rate': learning_rate},
                  eval_metric='acc',
                  batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                  num_epoch=epochs)

    # Only the host returned by scheduler_host(hosts) writes the model.
    if current_host == scheduler_host(hosts):
        save(model_dir, mlp_model)
87+
88+
89+
def save(model_dir, model):
    """Export ``model`` as ``model.onnx`` inside ``model_dir``.

    The symbol and parameters are first written to a temporary directory,
    because ``onnx_mxnet.export_model`` is given file paths as input; only the
    resulting ONNX file is placed in ``model_dir``.

    Args:
        model_dir: Directory that receives ``model.onnx``.
        model: Trained module exposing ``symbol.save``, ``save_params`` and
            ``data_shapes``.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        symbol_file = os.path.join(tmp_dir, 'model-symbol.json')
        params_file = os.path.join(tmp_dir, 'model-0000.params')

        model.symbol.save(symbol_file)
        model.save_params(params_file)

        # One shape (as a plain list of dimensions) per model input.
        data_shapes = [list(data_desc.shape) for data_desc in model.data_shapes]
        output_path = os.path.join(model_dir, 'model.onnx')

        onnx_mxnet.export_model(symbol_file, params_file, data_shapes, np.float32, output_path)
    finally:
        # Remove the intermediate files even when saving or exporting fails.
        shutil.rmtree(tmp_dir)
104+
105+
106+
def parse_args():
    """Parse the training script's command-line arguments.

    Hyperparameters have literal defaults; locations and host information
    default to the ``SM_*`` environment variables, which must therefore be set.

    Returns:
        The ``argparse.Namespace`` with ``batch_size``, ``epochs``,
        ``learning_rate``, ``model_dir``, ``train``, ``test``,
        ``current_host`` and ``hosts``.

    Raises:
        KeyError: If one of the ``SM_*`` environment variables is missing.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)

    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
    # Parse --hosts as JSON, like the SM_HOSTS default. type=list would split
    # the command-line string into single characters.
    parser.add_argument('--hosts', type=json.loads, default=json.loads(os.environ['SM_HOSTS']))

    return parser.parse_args()
121+
122+
123+
# Script entry point: read the CLI arguments, take the GPU count from the
# SM_NUM_GPUS environment variable, and start training.
if __name__ == '__main__':
    cli_args = parse_args()
    gpu_count = int(os.environ['SM_NUM_GPUS'])

    train(
        batch_size=cli_args.batch_size,
        epochs=cli_args.epochs,
        learning_rate=cli_args.learning_rate,
        num_gpus=gpu_count,
        training_channel=cli_args.train,
        testing_channel=cli_args.test,
        hosts=cli_args.hosts,
        current_host=cli_args.current_host,
        model_dir=cli_args.model_dir,
    )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Exporting ONNX Models with MXNet\n",
8+
"\n",
9+
"The [Open Neural Network Exchange](https://onnx.ai/) (ONNX) is an open format for representing deep learning models with its extensible computation graph model and definitions of built-in operators and standard data types. Starting with MXNet 1.3, models trained using MXNet can now be saved as ONNX models.\n",
10+
"\n",
11+
"In this example, we will show how to train a model on Amazon SageMaker and save it as an ONNX model. This notebook is based on the [MXNet MNIST notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb) and the [MXNet example for exporting to ONNX](https://mxnet.incubator.apache.org/tutorials/onnx/export_mxnet_to_onnx.html)."
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"### Setup\n",
19+
"\n",
20+
"First we need to define a few variables that will be needed later in the example."
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"import boto3\n",
30+
"\n",
31+
"from sagemaker import get_execution_role\n",
32+
"from sagemaker.session import Session\n",
33+
"\n",
34+
"# AWS region\n",
35+
"region = boto3.Session().region_name\n",
36+
"\n",
37+
"# S3 bucket for saving code and model artifacts.\n",
38+
"# Feel free to specify a different bucket here if you wish.\n",
39+
"bucket = Session().default_bucket()\n",
40+
"\n",
41+
"# Location to save your custom code in tar.gz format.\n",
42+
"custom_code_upload_location = 's3://{}/customcode/mxnet'.format(bucket)\n",
43+
"\n",
44+
"# Location where results of model training are saved.\n",
45+
"model_artifacts_location = 's3://{}/artifacts'.format(bucket)\n",
46+
"\n",
47+
"# IAM execution role that gives SageMaker access to resources in your AWS account.\n",
48+
"# We can use the SageMaker Python SDK to get the role from our notebook environment. \n",
49+
"role = get_execution_role()"
50+
]
51+
},
52+
{
53+
"cell_type": "markdown",
54+
"metadata": {},
55+
"source": [
56+
"### The training script\n",
57+
"\n",
58+
"The ``mnist.py`` script provides all the code we need for training and hosting a SageMaker model. The script we will use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html)."
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": null,
64+
"metadata": {},
65+
"outputs": [],
66+
"source": [
67+
"!pygmentize mnist.py"
68+
]
69+
},
70+
{
71+
"cell_type": "markdown",
72+
"metadata": {},
73+
"source": [
74+
"### Exporting to ONNX\n",
75+
"\n",
76+
"The important part of this script can be found in the `save` method. This is where the ONNX model is exported:\n",
77+
"\n",
78+
"```python\n",
79+
"import os\n",
80+
"\n",
81+
"from mxnet.contrib import onnx as onnx_mxnet\n",
82+
"import numpy as np\n",
83+
"\n",
84+
"def save(model_dir, model):\n",
85+
" symbol_file = os.path.join(model_dir, 'model-symbol.json')\n",
86+
" params_file = os.path.join(model_dir, 'model-0000.params')\n",
87+
"\n",
88+
" model.symbol.save(symbol_file)\n",
89+
" model.save_params(params_file)\n",
90+
"\n",
91+
" data_shapes = [[dim for dim in data_desc.shape] for data_desc in model.data_shapes]\n",
92+
" output_path = os.path.join(model_dir, 'model.onnx')\n",
93+
" \n",
94+
" onnx_mxnet.export_model(symbol_file, params_file, data_shapes, np.float32, output_path)\n",
95+
"```\n",
96+
"\n",
97+
"The last line in that method, `onnx_mxnet.export_model`, is what saves the model in the ONNX format. You can see that we pass the following arguments:\n",
98+
"\n",
99+
"* `symbol_file`: path to the saved input symbol file\n",
100+
"* `params_file`: path to the saved input params file\n",
101+
"* `data_shapes`: list of the input shapes\n",
102+
"* `np.float32`: input data type\n",
103+
"* `output_path`: path to save the generated ONNX file\n",
104+
"\n",
105+
"For more information, see the [MXNet Documentation](https://mxnet.incubator.apache.org/api/python/contrib/onnx.html#mxnet.contrib.onnx.mx2onnx.export_model.export_model)."
106+
]
107+
},
108+
{
109+
"cell_type": "markdown",
110+
"metadata": {},
111+
"source": [
112+
"### Training the model\n",
113+
"\n",
114+
"With the training script written to export an ONNX model, the rest of the training process looks like any other Amazon SageMaker training job using MXNet. For a more in-depth explanation of these steps, see the [MXNet MNIST notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb)."
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"metadata": {},
121+
"outputs": [],
122+
"source": [
123+
"from sagemaker.mxnet import MXNet\n",
124+
"\n",
125+
"mnist_estimator = MXNet(entry_point='mnist.py',\n",
126+
" role=role,\n",
127+
" output_path=model_artifacts_location,\n",
128+
" code_location=custom_code_upload_location,\n",
129+
" train_instance_count=1,\n",
130+
" train_instance_type='ml.m4.xlarge',\n",
131+
" framework_version='1.3.0',\n",
132+
" hyperparameters={'learning-rate': 0.1})\n",
133+
"\n",
134+
"train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region)\n",
135+
"test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region)\n",
136+
"\n",
137+
"mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})"
138+
]
139+
},
140+
{
141+
"cell_type": "markdown",
142+
"metadata": {},
143+
"source": [
144+
"### Next steps\n",
145+
"\n",
146+
"Now that we have an ONNX model, we can deploy it to an endpoint in the same way we do in the [MXNet MNIST notebook](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb).\n",
147+
"\n",
148+
"For examples on how to write a `model_fn` to load the ONNX model, please refer to:\n",
149+
"* the [MXNet ONNX Super Resolution notebook](https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk/mxnet_onnx_superresolution)\n",
150+
"* the [MXNet documentation](https://mxnet.incubator.apache.org/api/python/contrib/onnx.html#mxnet.contrib.onnx.onnx2mx.import_model.import_model)"
151+
]
152+
}
153+
],
154+
"metadata": {
155+
"kernelspec": {
156+
"display_name": "conda_mxnet_p36",
157+
"language": "python",
158+
"name": "conda_mxnet_p36"
159+
},
160+
"language_info": {
161+
"codemirror_mode": {
162+
"name": "ipython",
163+
"version": 3
164+
},
165+
"file_extension": ".py",
166+
"mimetype": "text/x-python",
167+
"name": "python",
168+
"nbconvert_exporter": "python",
169+
"pygments_lexer": "ipython3",
170+
"version": "3.6.5"
171+
}
172+
},
173+
"nbformat": 4,
174+
"nbformat_minor": 2
175+
}

0 commit comments

Comments
 (0)