
doc: Explain why default model_fn loads PyTorch-EI models to CPU by default #1404

Merged 2 commits on Apr 13, 2020
8 changes: 6 additions & 2 deletions doc/using_pytorch.rst
@@ -290,7 +290,7 @@ Load a Model
------------

Before a model can be served, it must be loaded. The SageMaker PyTorch model server loads your model by invoking a
``model_fn`` function that you must provide in your script. The ``model_fn`` should have the following signature:
``model_fn`` function that you must provide in your script when you are not using Elastic Inference. The ``model_fn`` should have the following signature:

.. code:: python

@@ -316,7 +316,11 @@ It loads the model parameters from a ``model.pth`` file in the SageMaker model directory.
However, if you are using PyTorch Elastic Inference, you do not have to provide a ``model_fn``, since the PyTorch serving
container has a default one for you. Note that if you rely on the default ``model_fn``, you must save
your ScriptModule as ``model.pt``. If you implement your own ``model_fn``, use TorchScript and ``torch.jit.save``
to save your ScriptModule, then load it in your ``model_fn`` with ``torch.jit.load(..., map_location=torch.device('cpu'))``.
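
For example, a minimal sketch of saving a traced ScriptModule as ``model.pt`` for the default ``model_fn`` might look like the following (the tiny network and example input shape are purely illustrative):

.. code:: python

    import torch


    class TinyNet(torch.nn.Module):
        """Illustrative stand-in for your real model."""

        def forward(self, x):
            return torch.relu(x)


    model = TinyNet()
    model.eval()

    # Trace the model into a ScriptModule using a representative input.
    example_input = torch.randn(1, 10)
    traced = torch.jit.trace(model, example_input)

    # Save under the file name the default model_fn expects.
    torch.jit.save(traced, "model.pt")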

The client-side Elastic Inference framework is CPU-only, even though inference still happens in a CUDA context on the server. For this reason, the default ``model_fn`` for Elastic Inference loads the model to CPU. Tracing a model may create tensors on a specific device, which can cause device-related errors when the model is later loaded onto a different device. Passing an explicit ``map_location=torch.device('cpu')`` argument forces all tensors to CPU.
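
A minimal sketch of a custom ``model_fn`` that follows this advice might look like the following (the file name ``model.pt`` matches the convention above; everything else is illustrative):

.. code:: python

    import os

    import torch


    def model_fn(model_dir):
        """Load a TorchScript model onto the CPU.

        map_location forces every tensor to CPU, which matches the
        CPU-only client-side Elastic Inference framework.
        """
        model_path = os.path.join(model_dir, "model.pt")
        return torch.jit.load(model_path, map_location=torch.device("cpu"))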

For more information on the default inference handler functions, please refer to:
`SageMaker PyTorch Default Inference Handler <https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py>`_.

Serve a PyTorch Model