Commit bb0b3f4

Continue writing
1 parent d790884 commit bb0b3f4

1 file changed: +21 -12 lines changed

src/sagemaker/mxnet/README.rst

Lines changed: 21 additions & 12 deletions
@@ -35,7 +35,7 @@ Preparing the MXNet training script
 +===============================================================================================================================+
 | This required structure for training scripts will be deprecated with the next major release of MXNet images. |
 | The ``train`` function will no longer be required; instead the training script must be able to be run as a standalone script. |
-| For more information, see `Updating your MXNet training script <#updating-your-mxnet-training-script>`__. |
+| For more information, see `"Updating your MXNet training script" <#updating-your-mxnet-training-script>`__. |
 +-------------------------------------------------------------------------------------------------------------------------------+
 
 Your MXNet training script must be a Python 2.7 or 3.5 compatible source file. The MXNet training script must contain a function ``train``, which SageMaker invokes to run training. You can include other functions as well, but it must contain a ``train`` function.
@@ -604,20 +604,23 @@ Using the ``argparse`` library as an example, this part of the code would look s
 
 .. code:: python
 
-    parser = argparse.ArgumentParser()
+    import argparse
 
-    # hyperparameters sent by the client are passed as command-line arguments to the script.
-    parser.add_argument('--epochs', type=int, default=10)
-    parser.add_argument('--batch-size', type=int, default=100)
-    parser.add_argument('--learning-rate', type=float, default=0.1)
+    if __name__ == '__main__':
+        parser = argparse.ArgumentParser()
 
-    # data, model, and output directories
-    parser.add_argument('--output-data-dir', type=str, default='opt/ml/output/data')
-    parser.add_argument('--model-dir', type=str, default='opt/ml/model')
-    parser.add_argument('--train', type=str, default='opt/ml/input/data/train')
-    parser.add_argument('--test', type=str, default='opt/ml/input/data/test')
+        # hyperparameters sent by the client are passed as command-line arguments to the script.
+        parser.add_argument('--epochs', type=int, default=10)
+        parser.add_argument('--batch-size', type=int, default=100)
+        parser.add_argument('--learning-rate', type=float, default=0.1)
 
-    args, _ = parser.parse_known_args()
+        # data, model, and output directories
+        parser.add_argument('--output-data-dir', type=str, default='opt/ml/output/data')
+        parser.add_argument('--model-dir', type=str, default='opt/ml/model')
+        parser.add_argument('--train', type=str, default='opt/ml/input/data/train')
+        parser.add_argument('--test', type=str, default='opt/ml/input/data/test')
+
+        args, _ = parser.parse_known_args()
 
 The code in the main guard should also take care of training and saving the model.
 (This can be as simple as just calling the methods used with the previous training script format.)
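
The argument handling added in this hunk could be exercised on its own roughly as follows. This is a sketch rather than the README's code: `parse_args` is a hypothetical helper name introduced so the parser can be called with an explicit argument list. The defaults in the diff (`'opt/ml/model'` and so on) have no leading slash, so they would resolve relative to the working directory; the sketch assumes the absolute `/opt/ml/...` locations are what is intended.

```python
import argparse


def parse_args(argv=None):
    """Parse hyperparameters and data locations from command-line arguments.

    `parse_args` is a hypothetical helper for illustration only.
    """
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client arrive as command-line arguments.
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--learning-rate', type=float, default=0.1)

    # Data, model, and output directories.
    # Assumption: absolute paths, unlike the relative defaults shown in the diff.
    parser.add_argument('--output-data-dir', type=str, default='/opt/ml/output/data')
    parser.add_argument('--model-dir', type=str, default='/opt/ml/model')
    parser.add_argument('--train', type=str, default='/opt/ml/input/data/train')
    parser.add_argument('--test', type=str, default='/opt/ml/input/data/test')

    # parse_known_args returns (namespace, leftovers) and does not fail on
    # arguments the script has not declared.
    args, _ = parser.parse_known_args(argv)
    return args


if __name__ == '__main__':
    args = parse_args()
    # argparse converts dashes to underscores: --batch-size becomes args.batch_size.
    print(args.epochs, args.batch_size, args.learning_rate, args.model_dir)
```

Passing `argv=None` makes `argparse` read `sys.argv[1:]`, which matches how the script would be run as a standalone program.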
@@ -626,6 +629,9 @@ If you were previously relying on the default save method, here is one you can c
 
 .. code:: python
 
+    import json
+    import os
+
     def save(model_dir, model):
         model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
         model.save_params(os.path.join(model_dir, 'model-0000.params'))
@@ -635,6 +641,9 @@ If you were previously relying on the default save method, here is one you can c
         with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
             json.dump(signature, f)
 
+These changes will make training with MXNet similar to training with Chainer or PyTorch on SageMaker.
+For more information about those experiences, see `"Preparing the Chainer training script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/chainer#preparing-the-chainer-training-script>`__ and `"Preparing the PyTorch Training Script" <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#preparing-the-pytorch-training-script>`__.
+
 
 SageMaker MXNet Containers
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
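
The README text in this diff says the main guard should also take care of training and saving the model, possibly by just calling the functions from the previous script format. A rough sketch of that wiring follows. It makes several assumptions: `train` and `save` here are hypothetical placeholders that do not use MXNet, `main` is an illustrative name, and the model directory defaults to the system temporary directory only so the sketch runs outside a container.

```python
import argparse
import json
import os
import tempfile


def train(epochs, learning_rate, train_dir):
    """Hypothetical stand-in for the script's existing training logic."""
    # A real script would build and fit an MXNet model here.
    return {'epochs': epochs, 'learning_rate': learning_rate, 'train_dir': train_dir}


def save(model_dir, model):
    """Hypothetical stand-in for a save function; writes the placeholder as JSON."""
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, 'model.json')
    with open(path, 'w') as f:
        json.dump(model, f)
    return path


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.1)
    # Assumption: a temp-dir default, chosen only to keep the sketch runnable anywhere.
    parser.add_argument('--model-dir', type=str, default=tempfile.gettempdir())
    parser.add_argument('--train', type=str, default='train')
    args, _ = parser.parse_known_args(argv)

    # Train first, then persist the result to the model directory.
    model = train(args.epochs, args.learning_rate, args.train)
    return save(args.model_dir, model)


if __name__ == '__main__':
    main()
```

Keeping the parsing, training, and saving steps inside the main guard (or a function called from it) means importing the module has no side effects, which is what lets the script run as a standalone program.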
