Skip to content

Commit d8d64da

Browse files
authored
doc: add more details about PyTorch eia (#1357)
* doc: add more details about PyTorch eia
1 parent 2d7dd63 commit d8d64da

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

doc/using_pytorch.rst

+33-3
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,22 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro
118118
After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data
119119
will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator.
120120

121+
If you are using Elastic Inference, you must convert your models to the TorchScript format and use ``torch.jit.save`` to save the model.
122+
For example:
123+
124+
.. code:: python
125+
126+
import os
127+
import torch
128+
129+
# ... train `model`, then save it to `model_dir`
130+
model_dir = os.path.join(model_dir, "model.pt")
131+
torch.jit.save(model, model_dir)
132+
121133
Using third-party libraries
122134
---------------------------
123135

124-
When running your training script on SageMaker, it will have access to some pre-installed third-party libraries including ``torch``, ``torchvisopm``, and ``numpy``.
136+
When running your training script on SageMaker, it will have access to some pre-installed third-party libraries including ``torch``, ``torchvision``, and ``numpy``.
125137
For more information on the runtime environment, including specific package versions, see `SageMaker PyTorch Docker containers <#id4>`__.
126138

127139
If there are other packages you want to use with your script, you can include a ``requirements.txt`` file in the same directory as your training script to install other dependencies at runtime. Both ``requirements.txt`` and your training script should be put in the same folder. You must specify this folder in ``source_dir`` argument when creating PyTorch estimator.
@@ -303,7 +315,8 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
303315
304316
However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn`` since the PyTorch serving
305317
container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
306-
yor parameter file as ``model.pt`` instead of ``model.pth``. For more information on inference script, please refer to:
318+
your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
319+
to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load``. For more information on inference script, please refer to:
307320
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
308321

309322
Serve a PyTorch Model
@@ -461,6 +474,23 @@ If you implement your own prediction function, you should take care to ensure th
461474
first argument to ``output_fn``. If you use the default
462475
``output_fn``, this should be a torch.Tensor.
463476

477+
The default Elastic Inference ``predict_fn`` is similar but runs the TorchScript model using ``torch.jit.optimized_execution``.
478+
If you are implementing your own ``predict_fn``, please also use the ``torch.jit.optimized_execution``
479+
block, for example:
480+
481+
.. code:: python
482+
483+
import torch
484+
import numpy as np
485+
486+
def predict_fn(input_data, model):
487+
device = torch.device("cpu")
488+
model = model.to(device)
489+
input_data = data.to(device)
490+
model.eval()
491+
with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
492+
output = model(input_data)
493+
464494
Process Model Output
465495
^^^^^^^^^^^^^^^^^^^^
466496

@@ -671,6 +701,6 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
671701
SageMaker PyTorch Docker Containers
672702
***********************************
673703

674-
For information about SageMaker PyTorch containers, see `the SageMaker PyTorch containers repository <https://github.com/aws/sagemaker-pytorch-container>`_.
704+
For information about SageMaker PyTorch containers, see `the SageMaker PyTorch container repository <https://github.com/aws/sagemaker-pytorch-container>`_ and `SageMaker PyTorch Serving container repository <https://github.com/aws/sagemaker-pytorch-serving-container>`__.
675705

676706
For information about SageMaker PyTorch container dependencies, see `SageMaker PyTorch Containers <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#sagemaker-pytorch-docker-containers>`_.

0 commit comments

Comments
 (0)