Commit dfee185

increase grpc message size limit to 2gb (#180)
1 parent f12e81b commit dfee185

6 files changed (+173, -11 lines)

src/tf_container/proxy_client.py

Lines changed: 9 additions & 4 deletions
@@ -14,21 +14,22 @@

 import numpy as np
 from google.protobuf import json_format
-from grpc.beta import implementations
+import grpc
 from tensorflow import make_tensor_proto
 from tensorflow.core.example import example_pb2, feature_pb2
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, \
     PREDICT_INPUTS
 from tensorflow_serving.apis import get_model_metadata_pb2
 from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc

 from tf_container.run import logger as _logger

 INFERENCE_ACCELERATOR_PRESENT_ENV = 'SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'
 TF_SERVING_GRPC_REQUEST_TIMEOUT_ENV = 'SAGEMAKER_TFS_GRPC_REQUEST_TIMEOUT'

+MAX_GRPC_MESSAGE_SIZE = 1024 ** 3 * 2 - 1  # 2GB - 1
 DEFAULT_GRPC_REQUEST_TIMEOUT_FOR_INFERENCE_ACCELERATOR = 30.0

 REGRESSION = 'tensorflow/serving/regress'
@@ -82,8 +83,12 @@ def request(self, data):
         return request_fn(data)

     def cache_prediction_metadata(self):
-        channel = implementations.insecure_channel(self.host, self.tf_serving_port)
-        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+        channel = grpc.insecure_channel(
+            '{}:{}'.format(self.host, self.tf_serving_port),
+            options=[
+                ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
+                ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])
+        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
         request = get_model_metadata_pb2.GetModelMetadataRequest()

         request.model_spec.name = self.model_name
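
For context, gRPC's Python channels default to a 4MB receive limit; the two channel options above raise both the send and receive caps to 2GB - 1 byte, the largest value a signed 32-bit int can hold. A minimal standalone sketch of the new channel setup; the host, port, and model name here are illustrative assumptions, not values from this commit:

import grpc
from tensorflow_serving.apis import get_model_metadata_pb2, prediction_service_pb2_grpc

MAX_GRPC_MESSAGE_SIZE = 1024 ** 3 * 2 - 1  # 2GB - 1, the max signed int32

# assumption: a TF Serving instance listening on localhost:9000
channel = grpc.insecure_channel(
    'localhost:9000',
    options=[('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
             ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = get_model_metadata_pb2.GetModelMetadataRequest()
request.model_spec.name = 'my_model'  # hypothetical model name
request.metadata_field.append('signature_def')
metadata = stub.GetModelMetadata(request, 10.0)  # 10-second timeout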
test/integ/container_tests/large_grpc_message.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import json
+
+import requests
+
+
+def test_grpc_message_4m_json():
+    # this generates a request just over the original
+    # 4MB limit (1024 * 1024 * 1 * 4 bytes) plus overhead,
+    # matching the message size from the original issue report;
+    # the response has the same size, so both ends are tested
+    data = {
+        'shape': [1, 1024, 1024, 1],
+        'dtype': 'float32'
+    }
+
+    response = requests.post("http://localhost:8080/invocations",
+                             data=json.dumps(data),
+                             headers={'Content-Type': 'application/json',
+                                      'Accept': 'application/json'}).content
+
+    prediction = json.loads(response)
+
+    expected_shape = {
+        'dim': [
+            {'size': '1'},
+            {'size': '1024'},
+            {'size': '1024'},
+            {'size': '1'},
+        ]
+    }
+
+    assert expected_shape == prediction['outputs']['y']['tensorShape']
+    assert 2.0 == prediction['outputs']['y']['floatVal'][-1]
+
+
+def test_large_grpc_message_512m_pb2():
+    # this generates a request of roughly 512MB
+    # (1024 * 1024 * 128 * 4 bytes) plus overhead;
+    # the response has the same size, so both ends are tested.
+    # returning bytes (serialized pb2) instead of json, because
+    # the default json output function runs out of memory with
+    # much smaller messages (around 128MB if gunicorn runs in 8GB)
+    data = {
+        'shape': [1, 1024, 1024, 128],
+        'dtype': 'float32'
+    }
+
+    response = requests.post("http://localhost:8080/invocations",
+                             data=json.dumps(data),
+                             headers={'Content-Type': 'application/json',
+                                      'Accept': 'application/octet-stream'})
+
+    assert 200 == response.status_code
+    assert 512 * 1024 ** 2 <= int(response.headers['Content-Length'])
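
A quick worked check of the sizes these two tests exercise (not part of the commit; float32 means 4 bytes per element):

small = 1 * 1024 * 1024 * 1 * 4    # 4 MiB: just over the old 4MB gRPC default
large = 1 * 1024 * 1024 * 128 * 4  # 512 MiB: far beyond the old limit
assert small == 4 * 1024 ** 2
assert large == 512 * 1024 ** 2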

test/integ/docker_utils.py

Lines changed: 3 additions & 1 deletion
@@ -156,14 +156,15 @@ def execute_pytest(self, tests_path):

 class HostingContainer(Container):
     def __init__(self, image, opt_ml, script_name, processor, requirements_file=None,
-                 startup_delay=5):
+                 startup_delay=5, region=None):
         super(HostingContainer, self).__init__(image=image,
                                                processor=processor,
                                                startup_delay=startup_delay)
         self.opt_ml = opt_ml
         self.script_name = script_name
         self.opt_ml = opt_ml
         self.requirements_file = requirements_file
+        self.region = region

     def __enter__(self):
         cmd = [self.docker,
@@ -174,6 +175,7 @@ def __enter__(self):
                '-e', 'AWS_ACCESS_KEY_ID',
                '-e', 'AWS_SECRET_ACCESS_KEY',
                '-e', 'AWS_SESSION_TOKEN',
+               '-e', 'SAGEMAKER_REGION={}'.format(self.region if self.region else ''),
                '-e', 'SAGEMAKER_CONTAINER_LOG_LEVEL=20',
                '-e', 'SAGEMAKER_PROGRAM={}'.format(self.script_name),
                '-e', 'SAGEMAKER_REQUIREMENTS={}'.format(self.requirements_file),
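
The new '-e SAGEMAKER_REGION=...' flag exports the test region into the container's environment. Inside the container, consuming it is a one-liner; a sketch, assuming the empty-string default mirrors the flag above:

import os

region = os.environ.get('SAGEMAKER_REGION', '')  # '' when no region was passed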

test/integ/test_large_grpc_message.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+import os
+
+import tensorflow as tf
+
+from test.integ.docker_utils import HostingContainer
+from test.integ.utils import copy_resource
+from test.integ.conftest import SCRIPT_PATH
+
+
+def create_model(export_dir):
+    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+    with tf.Session() as session:
+        x = tf.placeholder(tf.float32, shape=[None, 1024, 1024, 1], name='x')
+        a = tf.constant(2.0)
+
+        y = tf.multiply(a, x, name='y')
+        predict_signature_def = (
+            tf.saved_model.signature_def_utils.predict_signature_def({
+                'x': x
+            }, {'y': y}))
+        signature_def_map = {
+            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+                predict_signature_def
+        }
+        session.run(tf.global_variables_initializer())
+        builder.add_meta_graph_and_variables(
+            session, [tf.saved_model.tag_constants.SERVING],
+            signature_def_map=signature_def_map)
+        builder.save()
+
+
+def test_large_grpc_message(docker_image, opt_ml, processor, region):
+    resource_path = os.path.join(SCRIPT_PATH, '../resources/large_grpc_message')
+    copy_resource(resource_path, opt_ml, 'code', 'code')
+    export_dir = os.path.join(opt_ml, 'model', 'export', 'Servo', '1')
+    create_model(export_dir)
+
+    with HostingContainer(opt_ml=opt_ml, image=docker_image,
+                          script_name='inference.py',
+                          processor=processor,
+                          region=region) as c:
+        c.execute_pytest('test/integ/container_tests/large_grpc_message.py')
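
For reference, the toy model just doubles its input (y = 2x), which is what the container tests assert when they send all-ones tensors. A local sanity check might look like this, using the same TF 1.x APIs; the export path is an illustrative assumption:

import numpy as np
import tensorflow as tf

export_dir = '/tmp/export/Servo/1'  # hypothetical path; run create_model(export_dir) first
with tf.Session() as session:
    tf.saved_model.loader.load(
        session, [tf.saved_model.tag_constants.SERVING], export_dir)
    y = session.run('y:0', feed_dict={'x:0': np.ones([1, 1024, 1024, 1], np.float32)})
    assert (y == 2.0).all()  # every element of the all-ones input is doubled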
test/resources/large_grpc_message/code/inference.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import json
+
+import numpy as np
+
+
+def input_fn(data, content_type):
+    """
+    Args:
+        data: json string containing tensor shape info,
+              e.g. {"shape": [1, 1024, 1024, 1], "dtype": "float32"}
+        content_type: ignored
+    Returns:
+        a dict of {'x': numpy array with the specified shape and dtype}
+    """
+
+    input_info = json.loads(data)
+    shape = input_info['shape']
+    dtype = getattr(np, input_info['dtype'])
+    return {'x': np.ones(shape, dtype)}
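
Called locally, the hook produces the all-ones tensor that the integration tests push through the stack; a quick illustration, not part of the commit:

import json
import numpy as np

payload = json.dumps({'shape': [1, 1024, 1024, 1], 'dtype': 'float32'})
batch = input_fn(payload, 'application/json')
assert batch['x'].shape == (1, 1024, 1024, 1)
assert batch['x'].dtype == np.float32
assert batch['x'].nbytes == 4 * 1024 ** 2  # the ~4MB body from the first container test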

test/unit/test_proxy_client.py

Lines changed: 8 additions & 6 deletions
@@ -17,9 +17,9 @@

 import pytest
 from google import protobuf
 from mock import MagicMock, patch, ANY
-from tensorflow_serving.apis import prediction_service_pb2, get_model_metadata_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc, get_model_metadata_pb2

-from tf_container.proxy_client import GRPCProxyClient
+from tf_container.proxy_client import GRPCProxyClient, MAX_GRPC_MESSAGE_SIZE

 REGRESSION = 'tensorflow/serving/regress'
 INFERENCE = 'tensorflow/serving/inference'
@@ -360,14 +360,16 @@ def test_classification_protobuf(proxy_client):

 @patch('tensorflow_serving.apis.get_model_metadata_pb2.SignatureDefMap')
 @patch('tensorflow_serving.apis.get_model_metadata_pb2.GetModelMetadataRequest')
-@patch('tensorflow_serving.apis.prediction_service_pb2.beta_create_PredictionService_stub')
-@patch('grpc.beta.implementations.insecure_channel')
+@patch('tensorflow_serving.apis.prediction_service_pb2_grpc.PredictionServiceStub')
+@patch('grpc.insecure_channel')
 def test_cache_prediction_metadata(channel, stub, request, signature_def_map, proxy_client):
     proxy_client.cache_prediction_metadata()

-    channel.assert_called_once_with('localhost', DEFAULT_PORT)
+    channel.assert_called_once_with('localhost:{}'.format(DEFAULT_PORT), options=[
+        ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
+        ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])

-    stub = prediction_service_pb2.beta_create_PredictionService_stub
+    stub = prediction_service_pb2_grpc.PredictionServiceStub
     stub.assert_called_once_with(channel())

     request = get_model_metadata_pb2.GetModelMetadataRequest
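
One detail worth remembering when reading the patched test: mock.patch decorators apply bottom-up, so the decorator nearest the function supplies the first mock argument ('grpc.insecure_channel' becomes 'channel', and so on upward). A tiny illustration with hypothetical targets:

from mock import patch

@patch('os.getcwd')       # bound to the third argument
@patch('os.path.exists')  # bound to the second argument
@patch('os.listdir')      # nearest decorator: bound to the first argument
def check(listdir_mock, exists_mock, getcwd_mock):
    assert listdir_mock is not exists_mock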
