1
+ print ('******* in inference.py *******' )
2
+ import tensorflow as tf
3
+ print (f'TensorFlow version is: { tf .version .VERSION } ' )
4
+
5
+ from tensorflow .keras .preprocessing import image
6
+ from tensorflow .keras .applications .mobilenet_v2 import MobileNetV2 , preprocess_input
7
+ print (f'Keras version is: { tf .keras .__version__ } ' )
8
+
9
+ import io
10
+ import base64
11
+ import json
12
+ import numpy as np
13
+ from numpy import argmax
14
+ from collections import namedtuple
15
+ from PIL import Image
16
+ import time
17
+ import requests
18
+
19
+ # Imports for GRPC invoke on TFS
20
+ import grpc
21
+ from tensorflow .compat .v1 import make_tensor_proto
22
+ from tensorflow_serving .apis import predict_pb2
23
+ from tensorflow_serving .apis import prediction_service_pb2_grpc
24
+
25
+ import os
26
+ # default to use of GRPC
27
+ PREDICT_USING_GRPC = os .environ .get ('PREDICT_USING_GRPC' , 'true' )
28
+ if PREDICT_USING_GRPC == 'true' :
29
+ USE_GRPC = True
30
+ else :
31
+ USE_GRPC = False
32
+
33
+ MAX_GRPC_MESSAGE_LENGTH = 512 * 1024 * 1024
34
+
35
+ HEIGHT = 224
36
+ WIDTH = 224
37
+
38
+ # Restrict memory growth on GPU's
39
+ physical_gpus = tf .config .experimental .list_physical_devices ('GPU' )
40
+ if physical_gpus :
41
+ try :
42
+ # Currently, memory growth needs to be the same across GPUs
43
+ for gpu in physical_gpus :
44
+ tf .config .experimental .set_memory_growth (gpu , True )
45
+ logical_gpus = tf .config .experimental .list_logical_devices ('GPU' )
46
+ print (len (physical_gpus ), 'Physical GPUs,' , len (logical_gpus ), 'Logical GPUs' )
47
+ except RuntimeError as e :
48
+ # Memory growth must be set before GPUs have been initialized
49
+ print (e )
50
+ else :
51
+ print ('**** NO physical GPUs' )
52
+
53
+
54
+ num_inferences = 0
55
+ print (f'num_inferences: { num_inferences } ' )
56
+
57
+ Context = namedtuple ('Context' ,
58
+ 'model_name, model_version, method, rest_uri, grpc_uri, '
59
+ 'custom_attributes, request_content_type, accept_header' )
60
+
61
+ def handler (data , context ):
62
+
63
+ global num_inferences
64
+ num_inferences += 1
65
+
66
+ print (f'\n ************ inference #: { num_inferences } ' )
67
+ if context .request_content_type == 'application/x-image' :
68
+ stream = io .BytesIO (data .read ())
69
+ img = Image .open (stream ).convert ('RGB' )
70
+ _print_image_metadata (img )
71
+
72
+ img = img .resize ((WIDTH , HEIGHT ))
73
+ img_array = image .img_to_array (img ) #, data_format = "channels_first")
74
+ # the image is now in an array of shape (224, 224, 3) or (3, 224, 224) based on data_format
75
+ # need to expand it to add dim for num samples, e.g. (1, 224, 224, 3)
76
+ x = img_array .reshape ((1 ,) + img_array .shape )
77
+ instance = preprocess_input (x )
78
+ print (f' final image shape: { instance .shape } ' )
79
+ del x , img
80
+ else :
81
+ _return_error (415 , 'Unsupported content type "{}"' .format (context .request_content_type or 'Unknown' ))
82
+
83
+ start_time = time .time ()
84
+
85
+ if USE_GRPC :
86
+ prediction = _predict_using_grpc (context , instance )
87
+
88
+ else : # use TFS REST API
89
+ inst_json = json .dumps ({'instances' : instance .tolist ()})
90
+ response = requests .post (context .rest_uri , data = inst_json )
91
+ if response .status_code != 200 :
92
+ raise Exception (response .content .decode ('utf-8' ))
93
+ prediction = response .content
94
+
95
+ end_time = time .time ()
96
+ latency = int ((end_time - start_time ) * 1000 )
97
+ print (f'=== TFS invoke took: { latency } ms' )
98
+
99
+ response_content_type = context .accept_header
100
+ return prediction , response_content_type
101
+
102
+ def _return_error (code , message ):
103
+ raise ValueError ('Error: {}, {}' .format (str (code ), message ))
104
+
105
+ def _predict_using_grpc (context , instance ):
106
+ request = predict_pb2 .PredictRequest ()
107
+ request .model_spec .name = 'model'
108
+ request .model_spec .signature_name = 'serving_default'
109
+
110
+ request .inputs ['input_1' ].CopyFrom (make_tensor_proto (instance ))
111
+ options = [
112
+ ('grpc.max_send_message_length' , MAX_GRPC_MESSAGE_LENGTH ),
113
+ ('grpc.max_receive_message_length' , MAX_GRPC_MESSAGE_LENGTH )
114
+ ]
115
+ channel = grpc .insecure_channel (f'0.0.0.0:{ context .grpc_port } ' , options = options )
116
+ stub = prediction_service_pb2_grpc .PredictionServiceStub (channel )
117
+ result_future = stub .Predict .future (request , 30 ) # 5 seconds
118
+ output_tensor_proto = result_future .result ().outputs ['output' ]
119
+ output_shape = [dim .size for dim in output_tensor_proto .tensor_shape .dim ]
120
+ output_np = np .array (output_tensor_proto .float_val ).reshape (output_shape )
121
+ predicted_class_idx = argmax (output_np )
122
+ print (f' Predicted class: { predicted_class_idx } ' )
123
+ prediction_json = {'predictions' : output_np .tolist ()}
124
+ return json .dumps (prediction_json )
125
+
126
+ def _print_image_metadata (img ):
127
+ # Retrieve the attributes of the image
128
+ fileFormat = img .format
129
+ imageMode = img .mode
130
+ imageSize = img .size # (width, height)
131
+ colorPalette = img .palette
132
+
133
+ print (f' File format: { fileFormat } ' )
134
+ print (f' Image mode: { imageMode } ' )
135
+ print (f' Image size: { imageSize } ' )
136
+ print (f' Color pal: { colorPalette } ' )
137
+
138
+ print (f' Keys from image.info dictionary:' )
139
+ for key , value in img .info .items ():
140
+ print (f' { key } ' )
0 commit comments