@@ -589,18 +589,19 @@ The ``train`` function will no longer be required; instead the training script m
589
589
In this way, the training script will become similar to a training script you might run outside of SageMaker.
590
590
591
591
There are a few steps needed to make a training script with the old format compatible with the new format.
592
- You don't need to do this yet, but it's documented here for future reference.
592
+ You don't need to do this yet, but it's documented here for future reference, as this change is coming soon .
593
593
594
594
First, add a `main guard <https://docs.python.org/3/library/__main__.html >`__ (``if __name__ == '__main__': ``).
595
595
The code executed from your main guard needs to:
596
596
597
- 1. Set hyperparameters and other variables
597
+ 1. Set hyperparameters and directory locations
598
598
2. Initiate training
599
599
3. Save the model
600
600
601
- Hyperparameters will now be passed as command-line arguments to your training script.
602
- We recommend using an `argument parser <https://docs.python.org/3.5/howto/argparse.html >`__ to aid with this.
603
- Using the ``argparse `` library as an example, this part of the code would look something like this:
601
+ Hyperparameters will be passed as command-line arguments to your training script.
602
+ In addition, the locations for finding input data and saving the model and output data will need to be defined.
603
+ We recommend using `an argument parser <https://docs.python.org/3.5/howto/argparse.html >`__ for this part.
604
+ Using the ``argparse `` library as an example, the code would look something like this:
604
605
605
606
.. code :: python
606
607
@@ -614,17 +615,25 @@ Using the ``argparse`` library as an example, this part of the code would look s
614
615
parser.add_argument(' --batch-size' , type = int , default = 100 )
615
616
parser.add_argument(' --learning-rate' , type = float , default = 0.1 )
616
617
617
- # data, model, and output directories
618
- parser.add_argument(' --output-data-dir' , type = str , default = ' opt/ml/output/data' )
618
+ # input data and model directories
619
619
parser.add_argument(' --model-dir' , type = str , default = ' opt/ml/model' )
620
620
parser.add_argument(' --train' , type = str , default = ' opt/ml/input/data/train' )
621
621
parser.add_argument(' --test' , type = str , default = ' opt/ml/input/data/test' )
622
622
623
623
args, _ = parser.parse_known_args()
624
624
625
625
The code in the main guard should also take care of training and saving the model.
626
- (This can be as simple as just calling the methods used with the previous training script format.)
627
- Note now that saving the model will not be done by default; this must be done by the training script.
626
+ This can be as simple as just calling the methods used with the previous training script format:
627
+
628
+ .. code :: python
629
+
630
+ if __name__ == ' __main__' :
631
+ # arg parsing (shown above) goes here
632
+
633
+ model = train(args.batch_size, args.epochs, args.learning_rate, args.train, args.test)
634
+ save(args.model_dir, model)
635
+
636
+ Note that saving the model will no longer be done by default; this must be done by the training script.
628
637
If you were previously relying on the default save method, here is one you can copy into your code:
629
638
630
639
.. code :: python
0 commit comments