doc: add more details about PyTorch eia #1357
@@ -118,10 +118,22 @@ to a certain filesystem path called ``model_dir``. This value is accessible thro
After your training job is complete, SageMaker will compress and upload the serialized model to S3, and your model data
will be available in the S3 ``output_path`` you specified when you created the PyTorch Estimator.

If you are using Elastic Inference, you must convert your model to the TorchScript format and use ``torch.jit.save`` to save it.
For example:
.. code:: python

    import os
    import torch

    # ... train `model` (a ScriptModule), then save it under `model_dir`
    model_path = os.path.join(model_dir, "model.pt")
    torch.jit.save(model, model_path)
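Note that ``torch.jit.save`` expects a ScriptModule rather than a plain ``torch.nn.Module``. As a minimal sketch of the conversion step (``model`` and ``example_input`` are illustrative names, not part of the example above), either ``torch.jit.script`` or ``torch.jit.trace`` produces a ScriptModule:

.. code:: python

    import torch

    # Compile the model from its Python source
    scripted_model = torch.jit.script(model)

    # Or record the operations executed on an example input batch
    traced_model = torch.jit.trace(model, example_input)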
Using third-party libraries
---------------------------

When running your training script on SageMaker, it will have access to some pre-installed third-party libraries, including ``torch``, ``torchvision``, and ``numpy``.
For more information on the runtime environment, including specific package versions, see `SageMaker PyTorch Docker containers <#id4>`__.

If there are other packages you want to use with your script, you can include a ``requirements.txt`` file in the same directory as your training script to install other dependencies at runtime. Both ``requirements.txt`` and your training script should be put in the same folder. You must specify this folder in the ``source_dir`` argument when creating a PyTorch estimator.
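As an illustration, if ``train.py`` and ``requirements.txt`` both live in a folder named ``my_code`` (the folder name, IAM role, framework version, and instance settings below are hypothetical), the estimator could be created along these lines:

.. code:: python

    from sagemaker.pytorch import PyTorch

    # source_dir is the folder that contains both train.py and requirements.txt
    estimator = PyTorch(
        entry_point="train.py",
        source_dir="my_code",
        role="SageMakerRole",  # hypothetical IAM role
        framework_version="1.3.1",
        train_instance_count=1,
        train_instance_type="ml.p2.xlarge",
    )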
@@ -303,7 +315,8 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model d
However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn``, since the PyTorch serving
container has a default one for you. But please note that if you are utilizing the default ``model_fn``, please save
your ScriptModule as ``model.pt``. If you are implementing your own ``model_fn``, please use TorchScript and ``torch.jit.save``
to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load``. For more information on the inference script, please refer to
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.
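For example, a custom ``model_fn`` along these lines would match the save step shown earlier (a minimal sketch; the ``model.pt`` filename mirrors the default handler's convention):

.. code:: python

    import os
    import torch

    def model_fn(model_dir):
        # Load the TorchScript archive saved with torch.jit.save during training
        model = torch.jit.load(
            os.path.join(model_dir, "model.pt"),
            map_location=torch.device("cpu"),
        )
        return model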
Serve a PyTorch Model
@@ -461,6 +474,23 @@ If you implement your own prediction function, you should take care to ensure th

first argument to ``output_fn``. If you use the default
``output_fn``, this should be a ``torch.Tensor``.
The default Elastic Inference ``predict_fn`` is similar, but it uses TorchScript and ``torch.jit.optimized_execution``
to run the model. If you are implementing your own ``predict_fn``, please also run the model inside a
``torch.jit.optimized_execution`` block, for example:
.. code:: python

    import torch

    def predict_fn(input_data, model):
        device = torch.device("cpu")
        model = model.to(device)
        input_data = input_data.to(device)
        model.eval()
        # Run the ScriptModule inside the Elastic Inference execution context
        with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
            output = model(input_data)
        return output
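In this sketch, the ``eia:0`` device string selects the first Elastic Inference accelerator attached to the endpoint (per the AWS Elastic Inference documentation), and the returned ``torch.Tensor`` is what gets passed on to ``output_fn``.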
Process Model Output
^^^^^^^^^^^^^^^^^^^^
@@ -671,6 +701,6 @@ The following are optional arguments. When you create a ``PyTorch`` object, you
SageMaker PyTorch Docker Containers
***********************************

For information about SageMaker PyTorch containers, see `the SageMaker PyTorch container repository <https://github.com/aws/sagemaker-pytorch-container>`_ and the `SageMaker PyTorch serving container repository <https://github.com/aws/sagemaker-pytorch-serving-container>`__.

For information about SageMaker PyTorch container dependencies, see `SageMaker PyTorch Containers <https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/pytorch#sagemaker-pytorch-docker-containers>`_.
Review discussion:

    What do you think about changing the existing examples to use ``model.pt`` as well? Or are you against it because people have to implement their own ``model_fn`` without EI? I'm worried people will get confused.

    As discussed offline, changing all ``.pth`` to ``.pt`` will make us change all previous test scripts, docs, and notebook examples.