@@ -269,6 +269,192 @@ More information on how to create ``export_outputs`` can be found in `specifying
refer to TensorFlow's `Save and Restore <https://www.tensorflow.org/guide/saved_model>`_ documentation for other ways to control the
inference-time behavior of your SavedModels.

+ Providing Python scripts for pre/post-processing
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ You can add customized Python code to process your input and output data:
+
+ .. code:: python
+
+     from sagemaker.tensorflow.serving import Model
+
+     model = Model(entry_point='inference.py',
+                   model_data='s3://mybucket/model.tar.gz',
+                   role='MySageMakerRole')
+
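+ A minimal deployment sketch might then look like the following (the
+ instance type and sample payload here are illustrative, not requirements):
+
+ .. code:: python
+
+     predictor = model.deploy(initial_instance_count=1,
+                              instance_type='ml.m5.xlarge')
+
+     # the payload mirrors the TFS REST Predict API request format
+     result = predictor.predict({'instances': [1.0, 2.0, 5.0]})
+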
+ How to implement the pre- and/or post-processing handler(s)
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ Your entry point file should implement either a pair of ``input_handler``
+ and ``output_handler`` functions or a single ``handler`` function.
+ Note that if the ``handler`` function is implemented, ``input_handler``
+ and ``output_handler`` are ignored.
+
+ To implement pre- and/or post-processing handler(s), use the Context
+ object that the Python service creates. The Context object is a namedtuple with the following attributes:
+
+ - ``model_name (string)``: the name of the model to use for
+   inference. For example, 'half-plus-three'
+
+ - ``model_version (string)``: the version of the model. For example, '5'
+
+ - ``method (string)``: the inference method. For example, 'predict',
+   'classify', or 'regress'. For more information on methods, please see the
+   `Classify and Regress
+   API <https://www.tensorflow.org/tfx/serving/api_rest#classify_and_regress_api>`__
+   and the `Predict
+   API <https://www.tensorflow.org/tfx/serving/api_rest#predict_api>`__
+
+ - ``rest_uri (string)``: the TFS REST URI generated by the Python
+   service. For example,
+   'http://localhost:8501/v1/models/half_plus_three:predict'
+
+ - ``grpc_uri (string)``: the GRPC port number generated by the Python
+   service. For example, '9000'
+
+ - ``custom_attributes (string)``: the content of the
+   'X-Amzn-SageMaker-Custom-Attributes' header from the original
+   request. For example,
+   'tfs-model-name=half_plus_three,tfs-method=predict'
+
+ - ``request_content_type (string)``: the original request content type,
+   defaulting to 'application/json' if not provided
+
+ - ``accept_header (string)``: the original request accept type,
+   defaulting to 'application/json' if not provided
+
+ - ``content_length (int)``: the content length of the original request
+
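+ As an illustration, the following sketch (a hypothetical helper, not part
+ of the container) shows how a handler might unpack the comma-separated
+ ``custom_attributes`` string into a dict:
+
+ .. code:: python
+
+     def _parse_custom_attributes(context):
+         # e.g. 'tfs-model-name=half_plus_three,tfs-method=predict'
+         # becomes {'tfs-model-name': 'half_plus_three', 'tfs-method': 'predict'}
+         attributes = {}
+         for pair in (context.custom_attributes or '').split(','):
+             if '=' in pair:
+                 key, _, value = pair.partition('=')
+                 attributes[key.strip()] = value.strip()
+         return attributes
+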
+ The following code example implements ``input_handler`` and
+ ``output_handler``. By providing these, the Python service posts the
+ request to the TFS REST URI with the data pre-processed by ``input_handler``
+ and passes the response to ``output_handler`` for post-processing.
+
+ .. code:: python
+
+     import json
+
+
+     def input_handler(data, context):
+         """Pre-process request input before it is sent to TensorFlow Serving REST API
+
+         Args:
+             data (obj): the request data, in format of dict or string
+             context (Context): an object containing request and configuration details
+
+         Returns:
+             (dict): a JSON-serializable dict that contains request body and headers
+         """
+         if context.request_content_type == 'application/json':
+             # pass through json (assumes it's correctly formed)
+             d = data.read().decode('utf-8')
+             return d if len(d) else ''
+
+         if context.request_content_type == 'text/csv':
+             # very simple csv handler
+             return json.dumps({
+                 'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
+             })
+
+         raise ValueError('{{"error": "unsupported content type {}"}}'.format(
+             context.request_content_type or "unknown"))
+
+
+     def output_handler(data, context):
+         """Post-process TensorFlow Serving output before it is returned to the client.
+
+         Args:
+             data (obj): the TensorFlow Serving response
+             context (Context): an object containing request and configuration details
+
+         Returns:
+             (bytes, string): data to return to client, response content type
+         """
+         if data.status_code != 200:
+             raise ValueError(data.content.decode('utf-8'))
+
+         response_content_type = context.accept_header
+         prediction = data.content
+         return prediction, response_content_type
+
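+ If your clients request a different accept type, ``output_handler`` is the
+ natural place to honor it. Here is a minimal sketch that converts the
+ response to CSV, assuming the default TFS JSON response layout with a
+ top-level 'predictions' list:
+
+ .. code:: python
+
+     import csv
+     import io
+     import json
+
+     def output_handler(data, context):
+         if data.status_code != 200:
+             raise ValueError(data.content.decode('utf-8'))
+
+         if context.accept_header == 'text/csv':
+             predictions = json.loads(data.content.decode('utf-8'))['predictions']
+             buf = io.StringIO()
+             writer = csv.writer(buf)
+             for row in predictions:
+                 # each prediction may be a scalar or a list of values
+                 writer.writerow(row if isinstance(row, list) else [row])
+             return buf.getvalue(), 'text/csv'
+
+         # otherwise pass the TFS response through unchanged
+         return data.content, context.accept_header
+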
+ You might want to have complete control over the request.
+ For example, you might want to make a TFS request (REST or GRPC) to the first model,
+ inspect the results, and then make a request to a second model. In this case, implement
+ the ``handler`` method instead of the ``input_handler`` and ``output_handler`` methods, as demonstrated
+ in the following code:
+
+ .. code:: python
+
+     import json
+     import requests
+
+
+     def handler(data, context):
+         """Handle request.
+
+         Args:
+             data (obj): the request data
+             context (Context): an object containing request and configuration details
+
+         Returns:
+             (bytes, string): data to return to client, (optional) response content type
+         """
+         processed_input = _process_input(data, context)
+         response = requests.post(context.rest_uri, data=processed_input)
+         return _process_output(response, context)
+
+
+     def _process_input(data, context):
+         if context.request_content_type == 'application/json':
+             # pass through json (assumes it's correctly formed)
+             d = data.read().decode('utf-8')
+             return d if len(d) else ''
+
+         if context.request_content_type == 'text/csv':
+             # very simple csv handler
+             return json.dumps({
+                 'instances': [float(x) for x in data.read().decode('utf-8').split(',')]
+             })
+
+         raise ValueError('{{"error": "unsupported content type {}"}}'.format(
+             context.request_content_type or "unknown"))
+
+
+     def _process_output(data, context):
+         if data.status_code != 200:
+             raise ValueError(data.content.decode('utf-8'))
+
+         response_content_type = context.accept_header
+         prediction = data.content
+         return prediction, response_content_type
+
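+ For the two-model scenario mentioned above, ``handler`` can derive the
+ second model's REST URI from ``context.rest_uri``. The sketch below is
+ hypothetical: it assumes a second model named 'refiner' is served by the
+ same container (see the multi-model section below) and that both models
+ use the default JSON Predict API:
+
+ .. code:: python
+
+     import json
+     import requests
+
+
+     def handler(data, context):
+         # call the first model at the URI generated by the Python service
+         first = requests.post(context.rest_uri, data=data.read().decode('utf-8'))
+         if first.status_code != 200:
+             raise ValueError(first.content.decode('utf-8'))
+
+         # reuse the host/port of the generated URI for the second model
+         base = context.rest_uri.rsplit('/', 1)[0]  # e.g. 'http://localhost:8501/v1/models'
+         payload = json.dumps({'instances': json.loads(first.content)['predictions']})
+         second = requests.post(base + '/refiner:predict', data=payload)
+         if second.status_code != 200:
+             raise ValueError(second.content.decode('utf-8'))
+
+         return second.content, context.accept_header
+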
+ You can also bring in external dependencies to help with your data
+ processing. There are two ways to do this:
+
+ 1. If you include ``requirements.txt`` in your ``source_dir`` or in
+    your dependencies, the container installs the Python dependencies at runtime using ``pip install -r``:
+
+ .. code:: python
+
+     from sagemaker.tensorflow.serving import Model
+
+     model = Model(entry_point='inference.py',
+                   dependencies=['requirements.txt'],
+                   model_data='s3://mybucket/model.tar.gz',
+                   role='MySageMakerRole')
+
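+ A hypothetical ``requirements.txt`` would pin only what your
+ ``inference.py`` actually imports; the packages and versions here are
+ placeholders:
+
+ .. code::
+
+     numpy==1.17.4
+     requests==2.22.0
+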
+ 2. If you are working in a network-isolation situation, or if you don't
+    want to install dependencies at runtime every time your endpoint starts or a batch
+    transform job runs, you might want to put
+    pre-downloaded dependencies under a ``lib`` directory and include this
+    directory as a dependency (for example, you can populate it ahead of time
+    with ``pip install -r requirements.txt -t lib``). The container adds the modules to the Python
+    path. Note that if both ``lib`` and ``requirements.txt``
+    are present in the model archive, ``requirements.txt`` is ignored:
+
+ .. code:: python
+
+     from sagemaker.tensorflow.serving import Model
+
+     model = Model(entry_point='inference.py',
+                   dependencies=['/path/to/folder/named/lib'],
+                   model_data='s3://mybucket/model.tar.gz',
+                   role='MySageMakerRole')
+
Deploying more than one model to your Endpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~