Skip to content

Commit b43439c

Browse files
waytrue17 and Wei Chu authored
feature: pass context to handler functions (#109)
update unit test; run black locally; fix signature for py2.7; pin flake8; update readme. Co-authored-by: Wei Chu <[email protected]>
1 parent 52cd814 commit b43439c

File tree

7 files changed

+185
-49
lines changed

7 files changed

+185
-49
lines changed

README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@ To use the SageMaker Inference Toolkit, you need to do the following:
4646

4747
class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
4848

49-
def default_model_fn(self, model_dir):
49+
def default_model_fn(self, model_dir, context=None):
5050
"""Loads a model. For PyTorch, a default function to load a model cannot be provided.
5151
Users should provide customized model_fn() in script.
5252
5353
Args:
5454
model_dir: a directory where model is saved.
55+
context (obj): the request context (default: None).
5556
5657
Returns: A PyTorch model.
5758
"""
@@ -60,40 +61,54 @@ To use the SageMaker Inference Toolkit, you need to do the following:
6061
See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
6162
"""))
6263

63-
def default_input_fn(self, input_data, content_type):
64+
def default_input_fn(self, input_data, content_type, context=None):
6465
"""A default input_fn that can handle JSON, CSV and NPZ formats.
6566
6667
Args:
6768
input_data: the request payload serialized in the content_type format
6869
content_type: the request content_type
70+
context (obj): the request context (default: None).
6971
7072
Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor, depending on whether cuda is available.
7173
"""
7274
return decoder.decode(input_data, content_type)
7375

74-
def default_predict_fn(self, data, model):
76+
def default_predict_fn(self, data, model, context=None):
7577
"""A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
7678
Runs prediction on GPU if cuda is available.
7779
7880
Args:
7981
data: input data (torch.Tensor) for prediction deserialized by input_fn
8082
model: PyTorch model loaded in memory by model_fn
83+
context (obj): the request context (default: None).
8184
8285
Returns: a prediction
8386
"""
8487
return model(input_data)
8588

86-
def default_output_fn(self, prediction, accept):
89+
def default_output_fn(self, prediction, accept, context=None):
8790
"""A default output_fn for PyTorch. Serializes predictions from predict_fn to JSON, CSV or NPY format.
8891
8992
Args:
9093
prediction: a prediction result from predict_fn
9194
accept: type which the output data needs to be serialized
95+
context (obj): the request context (default: None).
9296
9397
Returns: output data serialized
9498
"""
9599
return encoder.encode(prediction, accept)
96100
```
101+
Note: passing context as an argument to the handler functions is optional. Customers can omit context from the function declaration if it is not needed at runtime. For example, the following handler function declarations will also work:
102+
103+
```
104+
def default_model_fn(self, model_dir)
105+
106+
def default_input_fn(self, input_data, content_type)
107+
108+
def default_predict_fn(self, data, model)
109+
110+
def default_output_fn(self, prediction, accept)
111+
```
97112

98113
2. Implement a handler service that is executed by the model server.
99114
([Here is an example](https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/handler_service.py) of a handler service.)

src/sagemaker_inference/default_handler_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@ def initialize(self, context):
6363
else:
6464
os.environ[PYTHON_PATH_ENV] = code_dir_path
6565

66-
self._service.validate_and_initialize(model_dir=model_dir)
66+
self._service.validate_and_initialize(model_dir=model_dir, context=context)

src/sagemaker_inference/default_inference_handler.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
class DefaultInferenceHandler(object):
2222
"""Bare-bones implementation of default inference functions."""
2323

24-
def default_model_fn(self, model_dir):
24+
def default_model_fn(self, model_dir, context=None):
2525
"""Function responsible for loading the model.
2626
2727
Args:
2828
model_dir (str): The directory where model files are stored.
29+
context (obj): the request context (default: None).
2930
3031
Returns:
3132
obj: the loaded model.
@@ -40,25 +41,28 @@ def default_model_fn(self, model_dir):
4041
)
4142
)
4243

43-
def default_input_fn(self, input_data, content_type): # pylint: disable=no-self-use
44+
def default_input_fn(self, input_data, content_type, context=None):
45+
# pylint: disable=unused-argument, no-self-use
4446
"""Function responsible for deserializing the input data into an object for prediction.
4547
4648
Args:
4749
input_data (obj): the request data.
4850
content_type (str): the request content type.
51+
context (obj): the request context (default: None).
4952
5053
Returns:
5154
obj: data ready for prediction.
5255
5356
"""
5457
return decoder.decode(input_data, content_type)
5558

56-
def default_predict_fn(self, data, model):
59+
def default_predict_fn(self, data, model, context=None):
5760
"""Function responsible for model predictions.
5861
5962
Args:
60-
model (obj): model loaded by the model_fn
61-
data: deserialized data returned by the input_fn
63+
model (obj): model loaded by the model_fn.
64+
data: deserialized data returned by the input_fn.
65+
context (obj): the request context (default: None).
6266
6367
Returns:
6468
obj: prediction result.
@@ -73,12 +77,13 @@ def default_predict_fn(self, data, model):
7377
)
7478
)
7579

76-
def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use
80+
def default_output_fn(self, prediction, accept, context=None): # pylint: disable=no-self-use
7781
"""Function responsible for serializing the prediction result to the desired accept type.
7882
7983
Args:
8084
prediction (obj): prediction result returned by the predict_fn.
8185
accept (str): accept header expected by the client.
86+
context (obj): the request context (default: None).
8287
8388
Returns:
8489
obj: prediction data.

src/sagemaker_inference/transformer.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
import importlib
2020
import traceback
2121

22+
try:
23+
from inspect import signature # pylint: disable=ungrouped-imports
24+
except ImportError:
25+
# for Python2.7
26+
import subprocess
27+
import sys
28+
29+
subprocess.check_call([sys.executable, "-m", "pip", "install", "inspect2"])
30+
from inspect2 import signature
31+
2232
try:
2333
from importlib.util import find_spec # pylint: disable=ungrouped-imports
2434
except ImportError:
@@ -73,6 +83,7 @@ def __init__(self, default_inference_handler=None):
7383
self._input_fn = None
7484
self._predict_fn = None
7585
self._output_fn = None
86+
self._context = None
7687

7788
@staticmethod
7889
def handle_error(context, inference_exception, trace):
@@ -109,7 +120,7 @@ def transform(self, data, context):
109120
try:
110121
properties = context.system_properties
111122
model_dir = properties.get("model_dir")
112-
self.validate_and_initialize(model_dir=model_dir)
123+
self.validate_and_initialize(model_dir=model_dir, context=context)
113124

114125
input_data = data[0].get("body")
115126

@@ -125,7 +136,9 @@ def transform(self, data, context):
125136
if content_type in content_types.UTF8_TYPES:
126137
input_data = input_data.decode("utf-8")
127138

128-
result = self._transform_fn(self._model, input_data, content_type, accept)
139+
result = self._run_handler_function(
140+
self._transform_fn, *(self._model, input_data, content_type, accept)
141+
)
129142

130143
response = result
131144
response_content_type = accept
@@ -148,20 +161,25 @@ def transform(self, data, context):
148161
trace,
149162
)
150163

151-
def validate_and_initialize(self, model_dir=environment.model_dir): # type: () -> None
164+
def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
152165
"""Validates the user module against the SageMaker inference contract.
153166
154167
Load the model as defined by the ``model_fn`` to prepare handling predictions.
155168
156169
"""
157170
if not self._initialized:
171+
self._context = context
158172
self._environment = environment.Environment()
159173
self._validate_user_module_and_set_functions()
174+
160175
if self._pre_model_fn is not None:
161-
self._pre_model_fn(model_dir)
162-
self._model = self._model_fn(model_dir)
176+
self._run_handler_function(self._pre_model_fn, *(model_dir,))
177+
178+
self._model = self._run_handler_function(self._model_fn, *(model_dir,))
179+
163180
if self._model_warmup_fn is not None:
164-
self._model_warmup_fn(model_dir, self._model)
181+
self._run_handler_function(self._model_warmup_fn, *(model_dir, self._model))
182+
165183
self._initialized = True
166184

167185
def _validate_user_module_and_set_functions(self):
@@ -214,7 +232,8 @@ def _validate_user_module_and_set_functions(self):
214232

215233
self._transform_fn = self._default_transform_fn
216234

217-
def _default_transform_fn(self, model, input_data, content_type, accept):
235+
def _default_transform_fn(self, model, input_data, content_type, accept, context=None):
236+
# pylint: disable=unused-argument
218237
"""Make predictions against the model and return a serialized response.
219238
This serves as the default implementation of transform_fn, used when the
220239
user has not provided an implementation.
@@ -224,13 +243,36 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
224243
input_data (obj): the request data.
225244
content_type (str): the request content type.
226245
accept (str): accept header expected by the client.
246+
context (obj): the request context (default: None).
227247
228248
Returns:
229249
obj: the serialized prediction result or a tuple of the form
230250
(response_data, content_type)
231251
232252
"""
233-
data = self._input_fn(input_data, content_type)
234-
prediction = self._predict_fn(data, model)
235-
result = self._output_fn(prediction, accept)
253+
data = self._run_handler_function(self._input_fn, *(input_data, content_type))
254+
prediction = self._run_handler_function(self._predict_fn, *(data, model))
255+
result = self._run_handler_function(self._output_fn, *(prediction, accept))
256+
return result
257+
258+
def _run_handler_function(self, func, *argv):
259+
"""Helper to call the handler function which covers 2 cases:
260+
1. the handler function takes context
261+
2. the handler function does not take context
262+
"""
263+
num_func_input = len(signature(func).parameters)
264+
if num_func_input == len(argv):
265+
# function does not take context
266+
result = func(*argv)
267+
elif num_func_input == len(argv) + 1:
268+
# function takes context
269+
argv_context = argv + (self._context,)
270+
result = func(*argv_context)
271+
else:
272+
raise TypeError(
273+
"{} takes {} arguments but {} were given.".format(
274+
func.__name__, num_func_input, len(argv)
275+
)
276+
)
277+
236278
return result

test/unit/test_default_inference_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from mock import patch
13+
from mock import Mock, patch
1414
import pytest
1515

1616
from sagemaker_inference import content_types
@@ -19,7 +19,8 @@
1919

2020
@patch("sagemaker_inference.decoder.decode")
2121
def test_default_input_fn(loads):
22-
assert DefaultInferenceHandler().default_input_fn(42, content_types.JSON)
22+
context = Mock()
23+
assert DefaultInferenceHandler().default_input_fn(42, content_types.JSON, context)
2324

2425
loads.assert_called_with(42, content_types.JSON)
2526

@@ -34,7 +35,8 @@ def test_default_input_fn(loads):
3435
)
3536
@patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction**2)
3637
def test_default_output_fn(accept, expected_content_type):
37-
result, content_type = DefaultInferenceHandler().default_output_fn(2, accept)
38+
context = Mock()
39+
result, content_type = DefaultInferenceHandler().default_output_fn(2, accept, context)
3840
assert result == 4
3941
assert content_type == expected_content_type
4042

0 commit comments

Comments
 (0)