Where ``request_body`` is a byte buffer and ``request_content_type`` is a Python string.

``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
If specified in the function declaration, the context will be created and passed to the function by SageMaker.
For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

The SageMaker PyTorch model server provides a default implementation of ``input_fn``.
This function deserializes JSON, CSV, or NPY encoded data into a torch.Tensor.
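
As an illustrative sketch (the deserialization logic here is an assumption for illustration, not the library's actual default code), an overridden ``input_fn`` that accepts the optional ``context`` argument and handles NPY-encoded data might look like this:

.. code:: python

    import io

    import numpy as np
    import torch

    def input_fn(request_body, request_content_type, context):
        # Sketch: deserialize NPY-encoded bytes into a torch.Tensor.
        # The context argument is optional; SageMaker passes it only when declared.
        if request_content_type == "application/x-npy":
            array = np.load(io.BytesIO(request_body), allow_pickle=False)
            return torch.from_numpy(array)
        raise ValueError("Unsupported content type: " + request_content_type)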

The ``predict_fn`` function has the following signature:

.. code:: python

    def predict_fn(input_object, model, context)

Where ``input_object`` is the object returned from ``input_fn`` and
``model`` is the model loaded by ``model_fn``.
If you are using multiple GPUs, then specify the ``context`` argument, which contains information such as the GPU ID for a dynamically-selected GPU and the batch size.
One of the examples below demonstrates how to configure ``predict_fn`` with the ``context`` argument to handle multiple GPUs. For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.
If you are using CPUs or a single GPU, then you do not need to specify the ``context`` argument.

The default implementation of ``predict_fn`` invokes the loaded model's ``__call__`` function on ``input_object``,
and returns the resulting value. The return-type should be a torch.Tensor to be compatible with the default
``output_fn``.

The following example shows an overridden ``predict_fn``:

.. code:: python

    import torch
    import numpy as np

    def predict_fn(input_data, model):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        with torch.no_grad():
            return model(input_data.to(device))

The following example is for use cases with multiple GPUs and shows an overridden ``predict_fn`` that uses the ``context`` argument to dynamically select a GPU device for making predictions:

.. code:: python

    import torch
    import numpy as np

    def predict_fn(input_data, model, context):
        device = torch.device("cuda:" + str(context.system_properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        with torch.no_grad():
            return model(input_data.to(device))

If you implement your own prediction function, you should take care to ensure that:

- The first argument is expected to be the return value from ``input_fn``.

The ``output_fn`` has the following signature:

.. code:: python

    def output_fn(prediction, content_type, context)

Where ``prediction`` is the result of invoking ``predict_fn`` and ``content_type`` is
the content type for the response, as specified by the InvokeEndpoint request. The function should return a byte array of data serialized to ``content_type``.

``context`` is an optional argument that contains additional serving information, such as the GPU ID and batch size.
If specified in the function declaration, the context will be created and passed to the function by SageMaker.
For more information about ``context``, see the `Serving Context class <https://github.com/pytorch/serve/blob/master/ts/context.py>`_.

The default implementation expects ``prediction`` to be a torch.Tensor and can serialize the result to JSON, CSV, or NPY.
It accepts response content types of "application/json", "text/csv", and "application/x-npy".
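
As an illustrative sketch (not the library's actual default implementation), an overridden ``output_fn`` that accepts the optional ``context`` argument and serializes a tensor to NPY-encoded bytes might look like this:

.. code:: python

    import io

    import numpy as np

    def output_fn(prediction, content_type, context):
        # Sketch: serialize a torch.Tensor to NPY bytes for "application/x-npy".
        # The context argument is optional; SageMaker passes it only when declared.
        if content_type == "application/x-npy":
            buffer = io.BytesIO()
            np.save(buffer, prediction.detach().cpu().numpy())
            return buffer.getvalue()
        raise ValueError("Unsupported content type: " + content_type)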