Commit f38b620

documentation: add context for pytorch (#3352)
Co-authored-by: Basil Beirouti <[email protected]>
1 parent 8e2a68e

File tree

1 file changed: +42 −13 lines changed

doc/frameworks/pytorch/using_pytorch.rst

+42 −13
@@ -415,20 +415,25 @@ Before a model can be served, it must be loaded. The SageMaker PyTorch model ser

 .. code:: python

-    def model_fn(model_dir)
+    def model_fn(model_dir, context)
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 SageMaker will inject the directory where your model files and sub-directories, saved by ``save``, have been mounted.
 Your model function should return a model object that can be used for model serving.

 The following code-snippet shows an example ``model_fn`` implementation.
-It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``.
+It loads the model parameters from a ``model.pth`` file in the SageMaker model directory ``model_dir``. As explained in the preceding example,
+``context`` is an optional argument that passes additional information.

 .. code:: python

     import torch
     import os

-    def model_fn(model_dir):
+    def model_fn(model_dir, context):
         model = Your_Model()
         with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
             model.load_state_dict(torch.load(f))
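
For readers trying out the new signature, here is a minimal sketch of a ``model_fn`` that also uses ``context`` to place the model on the GPU assigned by the model server. It assumes ``Your_Model`` is a placeholder for your own model class, as in the snippet above, and reuses the ``context.system_properties.get("gpu_id")`` lookup from the multi-GPU example later in this diff.

.. code:: python

    import os
    import torch

    def model_fn(model_dir, context):
        # Your_Model is a placeholder for your own model class.
        model = Your_Model()
        with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
            model.load_state_dict(torch.load(f))
        # context is optional; fall back to CPU when it is absent
        # or no GPU is available.
        if context is not None and torch.cuda.is_available():
            device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")))
        else:
            device = torch.device("cpu")
        return model.to(device)
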
@@ -482,13 +487,13 @@ function in the chain. Inside the SageMaker PyTorch model server, the process lo
 .. code:: python

     # Deserialize the Invoke request body into an object we can perform prediction on
-    input_object = input_fn(request_body, request_content_type)
+    input_object = input_fn(request_body, request_content_type, context)

     # Perform prediction on the deserialized object, with the loaded model
-    prediction = predict_fn(input_object, model)
+    prediction = predict_fn(input_object, model, context)

     # Serialize the prediction result into the desired response content type
-    output = output_fn(prediction, response_content_type)
+    output = output_fn(prediction, response_content_type, context)

 The above code sample shows the three function definitions:

@@ -536,9 +541,13 @@ it should return an object that can be passed to ``predict_fn`` and have the fol

 .. code:: python

-    def input_fn(request_body, request_content_type)
+    def input_fn(request_body, request_content_type, context)

-Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string
+Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string.
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 The SageMaker PyTorch model server provides a default implementation of ``input_fn``.
 This function deserializes JSON, CSV, or NPY encoded data into a torch.Tensor.
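
As an illustration of the new ``input_fn`` signature, here is a minimal sketch that accepts the optional ``context`` argument and deserializes a JSON request body into a ``torch.Tensor``; the error handling shown is an assumption, not part of this change.

.. code:: python

    import json
    import torch

    def input_fn(request_body, request_content_type, context):
        # Deserialize a JSON request body into a tensor.
        if request_content_type == 'application/json':
            return torch.tensor(json.loads(request_body))
        raise ValueError('Unsupported content type: ' + request_content_type)
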
@@ -586,16 +595,19 @@ The ``predict_fn`` function has the following signature:

 .. code:: python

-    def predict_fn(input_object, model)
+    def predict_fn(input_object, model, context)

 Where ``input_object`` is the object returned from ``input_fn`` and
 ``model`` is the model loaded by ``model_fn``.
+If you are using multiple GPUs, then specify the ``context`` argument, which contains information such as the GPU ID for a dynamically-selected GPU and the batch size.
+One of the examples below demonstrates how to configure ``predict_fn`` with the ``context`` argument to handle multiple GPUs. For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.
+If you are using CPUs or a single GPU, then you do not need to specify the ``context`` argument.

 The default implementation of ``predict_fn`` invokes the loaded model's ``__call__`` function on ``input_object``,
 and returns the resulting value. The return-type should be a torch.Tensor to be compatible with the default
 ``output_fn``.

-The example below shows an overridden ``predict_fn``:
+The following example shows an overridden ``predict_fn``:

 .. code:: python

@@ -609,6 +621,20 @@ The example below shows an overridden ``predict_fn``:
         with torch.no_grad():
             return model(input_data.to(device))

+The following example is for use cases with multiple GPUs and shows an overridden ``predict_fn`` that uses the ``context`` argument to dynamically select a GPU device for making predictions:
+
+.. code:: python
+
+    import torch
+    import numpy as np
+
+    def predict_fn(input_data, model, context):
+        device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
+        model.to(device)
+        model.eval()
+        with torch.no_grad():
+            return model(input_data.to(device))
+
 If you implement your own prediction function, you should take care to ensure that:

 - The first argument is expected to be the return value from input_fn.
@@ -664,11 +690,14 @@ The ``output_fn`` has the following signature:

 .. code:: python

-    def output_fn(prediction, content_type)
+    def output_fn(prediction, content_type, context)

 Where ``prediction`` is the result of invoking ``predict_fn`` and
-the content type for the response, as specified by the InvokeEndpoint request.
-The function should return a byte array of data serialized to content_type.
+``content_type`` is the content type for the response, as specified by the InvokeEndpoint request. The function should return a byte array of data serialized to ``content_type``.
+
+``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
+If specified in the function declaration, the context will be created and passed to the function by SageMaker.
+For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

 The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY.
 It accepts response content types of "application/json", "text/csv", and "application/x-npy".
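
To round out the handler chain, here is a minimal sketch of an overridden ``output_fn`` that accepts the optional ``context`` argument and serializes a ``torch.Tensor`` prediction to JSON. It is an illustrative sketch, not part of this change; the default implementation already handles the content types listed above.

.. code:: python

    import json

    def output_fn(prediction, content_type, context):
        # Serialize a torch.Tensor prediction to the requested content type,
        # returning a byte array as the model server expects.
        if content_type == 'application/json':
            return json.dumps(prediction.cpu().numpy().tolist()).encode('utf-8')
        raise ValueError('Unsupported content type: ' + content_type)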
