Document upcoming MXNet training script format #390

Merged (20 commits) on Sep 18, 2018
82 changes: 81 additions & 1 deletion src/sagemaker/mxnet/README.rst

=====================================
MXNet SageMaker Estimators and Models
=====================================
Preparing the MXNet training script
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+-------------------------------------------------------------------------------------------------------------------------------+
| WARNING |
+===============================================================================================================================+
| This required structure for training scripts will be deprecated with the next major release of MXNet images. |
| The ``train`` function will no longer be required; instead the training script must be able to be run as a standalone script. |
| For more information, see `"Updating your MXNet training script" <#updating-your-mxnet-training-script>`__. |
+-------------------------------------------------------------------------------------------------------------------------------+

Your MXNet training script must be a Python 2.7 or 3.5 compatible source file and must contain a function named ``train``, which SageMaker invokes to run training. You can include other functions in the script as well.

When you run your script on SageMaker via the ``MXNet`` estimator, SageMaker injects information about the training environment into your ``train`` function as Python keyword arguments. You can take advantage of these by declaring them as keyword arguments in your ``train`` function. The full list of arguments is:
https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-pytho
These are also available in SageMaker Notebook Instance hosted Jupyter notebooks under the "sample notebooks" folder.


Updating your MXNet training script
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The required structure for training scripts will be deprecated with the next major release of MXNet images.
The ``train`` function will no longer be required; instead the training script must be able to be run as a standalone script.
In this way, the training script will become similar to a training script you might run outside of SageMaker.

There are a few steps needed to make a training script with the old format compatible with the new format.
You don't need to do this yet, but it's documented here for future reference, as this change is coming soon.

First, add a `main guard <https://docs.python.org/3/library/__main__.html>`__ (``if __name__ == '__main__':``).
The code executed from your main guard needs to:

1. Set hyperparameters and directory locations
2. Initiate training
3. Save the model
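
To make the overall shape concrete, here is a minimal, self-contained sketch of the three steps above. The ``train`` and ``save`` helpers and the argument defaults are placeholders for illustration, not part of the SageMaker API:

```python
import argparse


def train(epochs):
    # placeholder for your actual training logic
    return {'epochs_trained': epochs}


def save(model_dir, model):
    # placeholder for your actual model-saving logic
    print('saving model to', model_dir)


if __name__ == '__main__':
    # 1. set hyperparameters and directory locations
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--model-dir', type=str, default='/opt/ml/model')
    args, _ = parser.parse_known_args()

    # 2. initiate training
    model = train(args.epochs)

    # 3. save the model
    save(args.model_dir, model)
```

The sections below walk through each of these pieces in more detail.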

Hyperparameters will be passed as command-line arguments to your training script.
In addition, you need to define the locations of the input data and of where to save the model artifacts and output data. SageMaker provides these locations to the container as environment variables, and your script is responsible for reading them.
We recommend using `an argument parser <https://docs.python.org/3.5/howto/argparse.html>`__ for this part.
Using the ``argparse`` library as an example, the code would look something like this:

.. code:: python

    import argparse
    import os

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()

        # hyperparameters sent by the client are passed as command-line arguments to the script
        parser.add_argument('--epochs', type=int, default=10)
        parser.add_argument('--batch-size', type=int, default=100)
        parser.add_argument('--learning-rate', type=float, default=0.1)

        # input data and model directories, read from environment variables set by SageMaker
        parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
        parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
        parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

        args, _ = parser.parse_known_args()
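
As a quick illustration of how this parsing behaves (the hyperparameter names here are only examples): a hyperparameter such as ``epochs: 5`` arrives as the command-line argument ``--epochs 5``, and ``parse_known_args`` tolerates any arguments the parser does not define rather than raising an error:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=100)

# simulate the command line that would be built from
# hyperparameters={'epochs': 5, 'batch-size': 64}, plus an unrecognized flag
args, unknown = parser.parse_known_args(
    ['--epochs', '5', '--batch-size', '64', '--extra', 'ignored'])

print(args.epochs, args.batch_size)  # 5 64
print(unknown)                       # ['--extra', 'ignored']
```

Using ``parse_known_args`` instead of ``parse_args`` keeps the script from failing if SageMaker passes arguments your parser does not declare.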

The code in the main guard should also take care of training and saving the model.
This can be as simple as calling the ``train`` and ``save`` methods used in the previous training script format:

.. code:: python

    if __name__ == '__main__':
        # arg parsing (shown above) goes here

        model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
        save(args.model_dir, model)

Note that saving the model will no longer be done by default; this must be done by the training script.
If you were previously relying on the default save method, here is one you can copy into your code:

.. code:: python

    import json
    import os

    def save(model_dir, model):
        model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
        model.save_params(os.path.join(model_dir, 'model-0000.params'))

        signature = [{'name': data_desc.name, 'shape': [dim for dim in data_desc.shape]}
                     for data_desc in model.data_shapes]
        with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
            json.dump(signature, f)
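
For reference, here is a minimal sketch of the signature that ends up in ``model-shapes.json``. A ``namedtuple`` stands in for a real MXNet module's ``data_shapes`` entries (the stub is an assumption for illustration only; a real module provides ``DataDesc`` objects with the same ``name`` and ``shape`` attributes):

```python
import json
from collections import namedtuple

# stand-in for mxnet.io.DataDesc: just a name and a shape
DataDesc = namedtuple('DataDesc', ['name', 'shape'])

# pretend the trained module exposes one input of shape (batch, features)
data_shapes = [DataDesc(name='data', shape=(100, 784))]

signature = [{'name': d.name, 'shape': [dim for dim in d.shape]}
             for d in data_shapes]

print(json.dumps(signature))
# [{"name": "data", "shape": [100, 784]}]
```

The hosting side uses this file to know the input names and shapes the model was trained with.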

These changes will make training with MXNet similar to training with Chainer or PyTorch on SageMaker.
For more information about those experiences, see `"Preparing the Chainer training script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/chainer#preparing-the-chainer-training-script>`__ and `"Preparing the PyTorch Training Script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#preparing-the-pytorch-training-script>`__.


SageMaker MXNet Containers
~~~~~~~~~~~~~~~~~~~~~~~~~~
