
Commit fbe1802

documentation: TFS support for pre/processing functions (aws#807)
1 parent c921ead commit fbe1802

1 file changed: +186 -0 lines changed

src/sagemaker/tensorflow/deploying_tensorflow_serving.rst

Lines changed: 186 additions & 0 deletions
@@ -269,6 +269,192 @@ More information on how to create ``export_outputs`` can be found in `specifying
refer to TensorFlow's `Save and Restore <https://www.tensorflow.org/guide/saved_model>`_ documentation for other ways to control the
inference-time behavior of your SavedModels.

Providing Python scripts for pre/post-processing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can add your customized Python code to process your input and output data:

.. code::

    from sagemaker.tensorflow.serving import Model

    model = Model(entry_point='inference.py',
                  model_data='s3://mybucket/model.tar.gz',
                  role='MySageMakerRole')
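
After the model object is created, it can be deployed in the usual way; a minimal
sketch (the instance type and count below are placeholders):

.. code::

    predictor = model.deploy(initial_instance_count=1,
                             instance_type='ml.c5.xlarge')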

How to implement the pre- and/or post-processing handler(s)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Your entry point file should implement either a pair of ``input_handler``
and ``output_handler`` functions or a single ``handler`` function.
Note that if a ``handler`` function is implemented, ``input_handler``
and ``output_handler`` are ignored.

To implement pre- and/or post-processing handler(s), use the Context
object that the Python service creates. The Context object is a namedtuple with the following attributes:

- ``model_name (string)``: the name of the model to use for
  inference. For example, 'half-plus-three'

- ``model_version (string)``: version of the model. For example, '5'

- ``method (string)``: inference method. For example, 'predict',
  'classify' or 'regress'. For more information on methods, see the
  `Classify and Regress
  API <https://www.tensorflow.org/tfx/serving/api_rest#classify_and_regress_api>`__
  and the `Predict
  API <https://www.tensorflow.org/tfx/serving/api_rest#predict_api>`__

- ``rest_uri (string)``: the TFS REST URI generated by the Python
  service. For example,
  'http://localhost:8501/v1/models/half_plus_three:predict'

- ``grpc_uri (string)``: the GRPC port number generated by the Python
  service. For example, '9000'

- ``custom_attributes (string)``: content of the
  'X-Amzn-SageMaker-Custom-Attributes' header from the original
  request. For example,
  'tfs-model-name=half_plus_three,tfs-method=predict' (see the parsing
  sketch after this list)

- ``request_content_type (string)``: the original request content type,
  defaulted to 'application/json' if not provided

- ``accept_header (string)``: the original request accept type,
  defaulted to 'application/json' if not provided

- ``content_length (int)``: content length of the original request
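
For example, ``custom_attributes`` arrives as a single comma-separated string,
so a handler that needs individual values has to parse it itself. A minimal
sketch (the helper name and attribute keys are only illustrative):

.. code::

    def _parse_custom_attributes(context):
        """Split a 'key=value,key=value' custom attributes header into a dict."""
        header = context.custom_attributes or ''
        return dict(item.split('=', 1) for item in header.split(',') if '=' in item)

    # inside a handler:
    #   attributes = _parse_custom_attributes(context)
    #   model_name = attributes.get('tfs-model-name', context.model_name)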

The following code example implements ``input_handler`` and
``output_handler``. By providing these, the Python service posts the
request to the TFS REST URI with the data pre-processed by ``input_handler``
and passes the response to ``output_handler`` for post-processing.

.. code::

    import json


    def input_handler(data, context):
        """Pre-process request input before it is sent to TensorFlow Serving REST API.

        Args:
            data (obj): the request data, in format of dict or string
            context (Context): an object containing request and configuration details

        Returns:
            (dict or str): a JSON-serializable dict or string that contains the request body
        """
        if context.request_content_type == 'application/json':
            # pass through json (assumes it's correctly formed)
            d = data.read().decode('utf-8')
            return d if len(d) else ''

        if context.request_content_type == 'text/csv':
            # very simple csv handler
            return json.dumps({
                'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
            })

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def output_handler(data, context):
        """Post-process TensorFlow Serving output before it is returned to the client.

        Args:
            data (obj): the TensorFlow Serving response
            context (Context): an object containing request and configuration details

        Returns:
            (bytes, string): data to return to client, response content type
        """
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))

        response_content_type = context.accept_header
        prediction = data.content
        return prediction, response_content_type
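
Once a model with this entry point is deployed, the endpoint accepts both
content types handled above. A minimal client sketch using ``boto3`` (the
endpoint name and payload are placeholders):

.. code::

    import boto3

    runtime = boto3.client('sagemaker-runtime')

    # CSV request: input_handler converts it to a TFS 'instances' payload
    response = runtime.invoke_endpoint(
        EndpointName='my-tfs-endpoint',   # placeholder endpoint name
        ContentType='text/csv',           # becomes context.request_content_type
        Accept='application/json',        # becomes context.accept_header
        Body='1.0,2.0,5.0')
    print(response['Body'].read())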

You might want to have complete control over the request.
For example, you might want to make a TFS request (REST or GRPC) to the first model,
inspect the results, and then make a request to a second model. In this case, implement
the ``handler`` method instead of the ``input_handler`` and ``output_handler`` methods, as demonstrated
in the following code:

.. code::

    import json
    import requests


    def handler(data, context):
        """Handle request.

        Args:
            data (obj): the request data
            context (Context): an object containing request and configuration details

        Returns:
            (bytes, string): data to return to client, (optional) response content type
        """
        processed_input = _process_input(data, context)
        response = requests.post(context.rest_uri, data=processed_input)
        return _process_output(response, context)


    def _process_input(data, context):
        if context.request_content_type == 'application/json':
            # pass through json (assumes it's correctly formed)
            d = data.read().decode('utf-8')
            return d if len(d) else ''

        if context.request_content_type == 'text/csv':
            # very simple csv handler
            return json.dumps({
                'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
            })

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def _process_output(data, context):
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))

        response_content_type = context.accept_header
        prediction = data.content
        return prediction, response_content_type
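
To chain two models as mentioned above, one option is to call a second TFS REST
URI from inside ``handler``. The following is only a rough sketch: it assumes a
second model named 'classifier' is also loaded in the container (see the later
section on deploying more than one model) and that its REST URI follows the same
pattern as the ``rest_uri`` example shown earlier:

.. code::

    import json
    import requests


    def handler(data, context):
        """Call the configured model, then feed its predictions to a second model."""
        payload = data.read().decode('utf-8')

        # first model: context.rest_uri points at the model configured for this endpoint
        first = requests.post(context.rest_uri, data=payload)
        if first.status_code != 200:
            raise ValueError(first.content.decode('utf-8'))

        # second model: assumed URI pattern 'http://localhost:8501/v1/models/<name>:predict'
        second_uri = 'http://localhost:8501/v1/models/classifier:predict'
        second = requests.post(second_uri, data=json.dumps(
            {'instances': first.json()['predictions']}))
        return second.content, context.accept_header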

You can also bring in external dependencies to help with your data
processing. There are two ways to do this (a sketch of an entry point that
uses such a dependency follows the list):

1. If you included ``requirements.txt`` in your ``source_dir`` or in
   your dependencies, the container installs the Python dependencies at runtime using ``pip install -r``:

   .. code::

       from sagemaker.tensorflow.serving import Model

       model = Model(entry_point='inference.py',
                     dependencies=['requirements.txt'],
                     model_data='s3://mybucket/model.tar.gz',
                     role='MySageMakerRole')

2. If you are working in a network-isolation situation or if you don't
   want to install dependencies at runtime every time your endpoint starts or a batch
   transform job runs, you might want to put
   pre-downloaded dependencies under a ``lib`` directory and include this
   directory as a dependency. The container adds the modules to the Python
   path. Note that if both ``lib`` and ``requirements.txt``
   are present in the model archive, ``requirements.txt`` is ignored:

   .. code::

       from sagemaker.tensorflow.serving import Model

       model = Model(entry_point='inference.py',
                     dependencies=['/path/to/folder/named/lib'],
                     model_data='s3://mybucket/model.tar.gz',
                     role='MySageMakerRole')
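
In either case, the entry point can then import the dependency directly. A
minimal sketch, assuming ``numpy`` was listed in ``requirements.txt`` or
vendored under ``lib`` (the scaling shown is only an example):

.. code::

    import json

    import numpy as np  # assumed external dependency


    def input_handler(data, context):
        """Scale CSV input to [0, 1] with numpy before sending it to TFS."""
        if context.request_content_type == 'text/csv':
            values = np.array(
                [float(x) for x in data.read().decode('utf-8').split(',')])
            # assumes the row is not constant, so max() != min()
            scaled = (values - values.min()) / (values.max() - values.min())
            return json.dumps({'instances': [scaled.tolist()]})

        raise ValueError('{{"error": "unsupported content type {}"}}'.format(
            context.request_content_type or "unknown"))


    def output_handler(data, context):
        if data.status_code != 200:
            raise ValueError(data.content.decode('utf-8'))
        return data.content, context.accept_header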

Deploying more than one model to your Endpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
