
Commit cb184cb

Document upcoming MXNet training script format (#390)
The next major release of MXNet will change the training script format. This README change documents the changes needed by the user to adjust to the new format. This is currently just a warning as the new format is not out yet. The warning is meant to help users plan for the upcoming change.
1 parent 52e4d76 commit cb184cb

File tree

1 file changed: 84 additions, 1 deletion

src/sagemaker/mxnet/README.rst

@@ -1,4 +1,3 @@
=====================================
MXNet SageMaker Estimators and Models
=====================================
@@ -31,6 +30,14 @@ In the following sections, we'll discuss how to prepare a training script for ex
Preparing the MXNet training script
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+-------------------------------------------------------------------------------------------------------------------------------+
| WARNING                                                                                                                       |
+===============================================================================================================================+
| This required structure for training scripts will be deprecated with the next major release of MXNet images.                  |
| The ``train`` function will no longer be required; instead the training script must be able to be run as a standalone script. |
| For more information, see `"Updating your MXNet training script" <#updating-your-mxnet-training-script>`__.                   |
+-------------------------------------------------------------------------------------------------------------------------------+

Your MXNet training script must be a Python 2.7 or 3.5 compatible source file, and it must contain a ``train`` function, which SageMaker invokes to run training. You can include other functions in the script as well.
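
As a rough sketch of this current format, a minimal ``train`` function might look like the following. The keyword-argument names shown here are assumed for illustration, and the function body is a placeholder rather than real training code:

```python
def train(hyperparameters, channel_input_dirs, num_gpus, **kwargs):
    # SageMaker passes training-environment information as keyword
    # arguments; any arguments not listed explicitly are collected
    # in **kwargs.
    epochs = hyperparameters.get('epochs', 10)
    train_dir = channel_input_dirs['train']

    # A real script would build and fit an MXNet model here and
    # return it; this sketch just returns the values it read.
    return {'epochs': epochs, 'train_dir': train_dir}
```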
When you run your script on SageMaker via the ``MXNet`` Estimator, SageMaker injects information about the training environment into your training function via Python keyword arguments. You can take advantage of these by including them as keyword arguments in your ``train`` function. The full list of arguments is:
@@ -574,6 +581,82 @@ https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-pytho
These are also available in SageMaker Notebook Instance hosted Jupyter notebooks under the "sample notebooks" folder.

Updating your MXNet training script
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The required structure for training scripts will be deprecated with the next major release of MXNet images.
The ``train`` function will no longer be required; instead, the training script must be able to run as a standalone script.
In this way, the training script will become similar to a training script you might run outside of SageMaker.

There are a few steps needed to make a training script in the old format compatible with the new format.
You don't need to do this yet, but the steps are documented here for future reference, since the change is coming soon.

First, add a `main guard <https://docs.python.org/3/library/__main__.html>`__ (``if __name__ == '__main__':``).
The code executed from your main guard needs to:

1. Set hyperparameters and directory locations
2. Initiate training
3. Save the model

Hyperparameters will be passed as command-line arguments to your training script.
In addition, the container will define the locations of input data and where to save the model artifacts and output data as environment variables, rather than passing that information as arguments to the ``train`` function.
You can find the full list of available environment variables in the `SageMaker Containers README <https://github.com/aws/sagemaker-containers#list-of-provided-environment-variables-by-sagemaker-containers>`__.

We recommend using `an argument parser <https://docs.python.org/3.5/howto/argparse.html>`__ for this part.
Using the ``argparse`` library as an example, the code would look something like this:

.. code:: python

    import argparse
    import os

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()

        # hyperparameters sent by the client are passed as command-line arguments to the script
        parser.add_argument('--epochs', type=int, default=10)
        parser.add_argument('--batch-size', type=int, default=100)
        parser.add_argument('--learning-rate', type=float, default=0.1)

        # input data and model directories
        parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
        parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
        parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

        args, _ = parser.parse_known_args()
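
A note on ``parse_known_args``: unlike ``parse_args``, it does not raise an error on arguments the script doesn't declare, which keeps the script robust if extra options are passed on the command line. A small stand-alone sketch (the simulated command line below is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=100)

# parse_known_args returns the parsed namespace plus a list of any
# arguments the parser didn't recognize, instead of erroring on them.
args, unknown = parser.parse_known_args(['--epochs', '20', '--extra-option', '1'])

print(args.epochs)      # 20
print(args.batch_size)  # 100
print(unknown)          # ['--extra-option', '1']
```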

The code in the main guard should also take care of training and saving the model.
This can be as simple as calling the ``train`` and ``save`` functions used in the previous training script format:

.. code:: python

    if __name__ == '__main__':
        # argument parsing (shown above) goes here

        model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
        save(args.model_dir, model)

Note that saving the model will no longer be done by default; this must be done by the training script.
If you were previously relying on the default save method, here is one you can copy into your code:

.. code:: python

    import json
    import os

    def save(model_dir, model):
        model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
        model.save_params(os.path.join(model_dir, 'model-0000.params'))

        signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
                     for data_desc in model.data_shapes]
        with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
            json.dump(signature, f)
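
To make the shape signature concrete, here is a small stand-alone sketch of what ends up in ``model-shapes.json``. The ``DataDesc`` namedtuple is a simplified stand-in for MXNet's data descriptors, and the shape values are made up for illustration:

```python
import json
from collections import namedtuple

# Simplified stand-in for an MXNet data descriptor (name and shape only).
DataDesc = namedtuple('DataDesc', ['name', 'shape'])

data_shapes = [DataDesc(name='data', shape=(100, 1, 28, 28))]
signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
             for data_desc in data_shapes]

print(json.dumps(signature))  # [{"name": "data", "shape": [100, 1, 28, 28]}]
```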

These changes will make training with MXNet similar to training with Chainer or PyTorch on SageMaker.
For more information about those experiences, see `"Preparing the Chainer training script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/chainer#preparing-the-chainer-training-script>`__ and `"Preparing the PyTorch Training Script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#preparing-the-pytorch-training-script>`__.

SageMaker MXNet Containers
~~~~~~~~~~~~~~~~~~~~~~~~~~
