Commit 0b58a2c

Merge pull request #4 from absynthe/05_deploy_mc
Updates to the 05_deployment
2 parents: e7ad5fa + 0cabb51

20 files changed: +534 −4372 lines

02_training/code/inference.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
print('******* in inference.py *******')

import tensorflow as tf
print(f'TensorFlow version is: {tf.version.VERSION}')

from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
print(f'Keras version is: {tf.keras.__version__}')

import io
import json
import numpy as np
from collections import namedtuple
from PIL import Image
import time
import requests

# Imports for gRPC invocation of TensorFlow Serving (TFS)
import grpc
from tensorflow.compat.v1 import make_tensor_proto
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

import os

# Default to gRPC; set PREDICT_USING_GRPC=false to fall back to the TFS REST API
USE_GRPC = os.environ.get('PREDICT_USING_GRPC', 'true') == 'true'

MAX_GRPC_MESSAGE_LENGTH = 512 * 1024 * 1024

HEIGHT = 224
WIDTH = 224

# Enable memory growth on GPUs so TensorFlow allocates GPU memory as needed
# instead of claiming all of it at startup
physical_gpus = tf.config.experimental.list_physical_devices('GPU')
if physical_gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in physical_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(physical_gpus), 'Physical GPUs,', len(logical_gpus), 'Logical GPUs')
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
else:
    print('**** NO physical GPUs')


num_inferences = 0
print(f'num_inferences: {num_inferences}')

Context = namedtuple('Context',
                     'model_name, model_version, method, rest_uri, grpc_uri, '
                     'custom_attributes, request_content_type, accept_header')

def handler(data, context):
    global num_inferences
    num_inferences += 1

    print(f'\n************ inference #: {num_inferences}')
    if context.request_content_type == 'application/x-image':
        stream = io.BytesIO(data.read())
        img = Image.open(stream).convert('RGB')
        _print_image_metadata(img)

        img = img.resize((WIDTH, HEIGHT))
        img_array = image.img_to_array(img)  # , data_format='channels_first')
        # The image is now an array of shape (224, 224, 3), or (3, 224, 224)
        # depending on data_format; add a batch dimension, e.g. (1, 224, 224, 3)
        x = img_array.reshape((1,) + img_array.shape)
        instance = preprocess_input(x)
        print(f'    final image shape: {instance.shape}')
        del x, img
    else:
        # Raises, so execution never continues with an undefined instance
        _return_error(415, 'Unsupported content type "{}"'.format(context.request_content_type or 'Unknown'))

    start_time = time.time()

    if USE_GRPC:
        prediction = _predict_using_grpc(context, instance)
    else:  # use the TFS REST API
        inst_json = json.dumps({'instances': instance.tolist()})
        response = requests.post(context.rest_uri, data=inst_json)
        if response.status_code != 200:
            raise Exception(response.content.decode('utf-8'))
        prediction = response.content

    end_time = time.time()
    latency = int((end_time - start_time) * 1000)
    print(f'=== TFS invoke took: {latency} ms')

    response_content_type = context.accept_header
    return prediction, response_content_type

def _return_error(code, message):
    raise ValueError('Error: {}, {}'.format(str(code), message))

def _predict_using_grpc(context, instance):
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'model'
    request.model_spec.signature_name = 'serving_default'

    # The input tensor name must match the SavedModel's serving signature
    request.inputs['input_1'].CopyFrom(make_tensor_proto(instance))
    options = [
        ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_LENGTH),
        ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_LENGTH)
    ]
    # context.grpc_port is supplied by the hosting container at runtime
    # (it is not part of the Context namedtuple defined above)
    channel = grpc.insecure_channel(f'0.0.0.0:{context.grpc_port}', options=options)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    result_future = stub.Predict.future(request, 30)  # 30-second timeout
    output_tensor_proto = result_future.result().outputs['output']
    output_shape = [dim.size for dim in output_tensor_proto.tensor_shape.dim]
    output_np = np.array(output_tensor_proto.float_val).reshape(output_shape)
    predicted_class_idx = np.argmax(output_np)
    print(f'    Predicted class: {predicted_class_idx}')
    prediction_json = {'predictions': output_np.tolist()}
    return json.dumps(prediction_json)

def _print_image_metadata(img):
    # Retrieve the attributes of the image
    file_format = img.format
    image_mode = img.mode
    image_size = img.size  # (width, height)
    color_palette = img.palette

    print(f'    File format: {file_format}')
    print(f'    Image mode: {image_mode}')
    print(f'    Image size: {image_size}')
    print(f'    Color pal: {color_palette}')

    print('    Keys from image.info dictionary:')
    for key in img.info:
        print(f'        {key}')
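For reference, the handler can be smoke-tested outside a serving container by stubbing the context. The sketch below is illustrative, not part of the commit: it assumes TensorFlow Serving is already running locally with its REST API on port 8501 and that a test.jpg exists next to the script. The stubbed Context has no grpc_port field, so PREDICT_USING_GRPC is disabled to force the REST path.

# Hypothetical local smoke test for handler() -- all names and ports
# below are assumptions, not part of this commit.
import os
os.environ['PREDICT_USING_GRPC'] = 'false'   # must be set before importing inference

from inference import handler, Context       # assumes cwd is 02_training/code

context = Context(
    model_name='model',
    model_version='1',
    method='predict',
    rest_uri='http://localhost:8501/v1/models/model:predict',  # assumed local TFS
    grpc_uri='localhost:8500',
    custom_attributes=None,
    request_content_type='application/x-image',
    accept_header='application/json',
)

# handler() calls data.read(), so an open binary file works as the payload
with open('test.jpg', 'rb') as f:
    prediction, content_type = handler(f, context)

print(content_type)
print(prediction[:200])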

02_training/code/requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# This is the set of Python packages that will get pip-installed
# at startup of the Amazon SageMaker endpoint or batch transform job.
Pillow
numpy
tensorflow==2.4.1
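For context on how these two files are consumed: when inference.py and requirements.txt are passed as the entry point and source directory, the SageMaker Python SDK packages them with the model and pip-installs the requirements at endpoint startup. A minimal deployment sketch follows, assuming SageMaker Python SDK v2; the S3 URI, role, and instance type are placeholders, not values from this commit.

# Hypothetical deployment sketch (SageMaker Python SDK v2); the S3 URI,
# role ARN, and instance type are placeholders.
from sagemaker.tensorflow import TensorFlowModel

model = TensorFlowModel(
    model_data='s3://<bucket>/<prefix>/model.tar.gz',  # placeholder
    role='<execution-role-arn>',                       # placeholder
    framework_version='2.4.1',                         # matches requirements.txt
    entry_point='inference.py',
    source_dir='02_training/code',
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge',  # placeholder
)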

05_deployment/README.md

Lines changed: 0 additions & 43 deletions
This file was deleted.
