
increase grpc message size limit to 2gb #180


Merged
5 commits merged on Apr 18, 2019
13 changes: 9 additions & 4 deletions src/tf_container/proxy_client.py
@@ -14,21 +14,22 @@

import numpy as np
from google.protobuf import json_format
-from grpc.beta import implementations
+import grpc
from tensorflow import make_tensor_proto
from tensorflow.core.example import example_pb2, feature_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, \
PREDICT_INPUTS
from tensorflow_serving.apis import get_model_metadata_pb2
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc

from tf_container.run import logger as _logger

INFERENCE_ACCELERATOR_PRESENT_ENV = 'SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'
TF_SERVING_GRPC_REQUEST_TIMEOUT_ENV = 'SAGEMAKER_TFS_GRPC_REQUEST_TIMEOUT'

+MAX_GRPC_MESSAGE_SIZE = 1024 ** 3 * 2 - 1  # 2GB - 1
DEFAULT_GRPC_REQUEST_TIMEOUT_FOR_INFERENCE_ACCELERATOR = 30.0

REGRESSION = 'tensorflow/serving/regress'
@@ -82,8 +83,12 @@ def request(self, data):
return request_fn(data)

def cache_prediction_metadata(self):
-        channel = implementations.insecure_channel(self.host, self.tf_serving_port)
-        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+        channel = grpc.insecure_channel(
+            '{}:{}'.format(self.host, self.tf_serving_port),
+            options=[
+                ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
+                ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])
+        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = get_model_metadata_pb2.GetModelMetadataRequest()

request.model_spec.name = self.model_name
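For reference, a minimal standalone sketch of how a client configured like the change above could push a large tensor through TensorFlow Serving. The endpoint `localhost:9000`, model name `generic_model`, signature `serving_default`, input key `x`, and the 300-second timeout are illustrative assumptions, not values taken from this PR:

```python
# Hypothetical usage sketch of the raised gRPC message size limit.
import grpc
import numpy as np
from tensorflow import make_tensor_proto
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

MAX_GRPC_MESSAGE_SIZE = 1024 ** 3 * 2 - 1  # 2GB - 1, mirroring the constant above

# Raise both send and receive caps; the default 4MB receive limit would
# reject a payload this large.
channel = grpc.insecure_channel(
    'localhost:9000',  # assumed TF Serving address
    options=[('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
             ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'generic_model'               # assumed model name
request.model_spec.signature_name = 'serving_default'   # assumed signature
# ~512MB payload: 1 * 1024 * 1024 * 128 float32 values at 4 bytes each.
request.inputs['x'].CopyFrom(
    make_tensor_proto(np.ones([1, 1024, 1024, 128], dtype=np.float32)))

result = stub.Predict(request, 300.0)  # timeout in seconds
```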
67 changes: 67 additions & 0 deletions test/integ/container_tests/large_grpc_message.py
@@ -0,0 +1,67 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import json

import requests


def test_grpc_message_4m_json():
# this will generate a request just over the original
# 4MB limit (1024 * 1024 * 1 * 4 bytes) + overhead
# this matches the message size from the original issue report
# response will have same size, so we are testing both ends
data = {
'shape': [1, 1024, 1024, 1],
'dtype': 'float32'
}

response = requests.post("http://localhost:8080/invocations",
data=json.dumps(data),
headers={'Content-type': 'application/json',
'Accept': 'application/json'}).content

prediction = json.loads(response)

expected_shape = {
'dim': [
{'size': '1'},
{'size': '1024'},
{'size': '1024'},
{'size': '1'},
]
}

assert expected_shape == prediction['outputs']['y']['tensorShape']
assert 2.0 == prediction['outputs']['y']['floatVal'][-1]


def test_large_grpc_message_512m_pb2():
# this will generate a request of ~512MB
# (1024 * 1024 * 128 * 4 bytes) + overhead
# response will have same size, so we are testing both ends
# returning bytes (serialized pb2) instead of json, because
# our default json output function runs out of memory with
# much smaller messages (around 128MB if gunicorn is running with 8GB)
data = {
'shape': [1, 1024, 1024, 128],
'dtype': 'float32'
}

response = requests.post("http://localhost:8080/invocations",
data=json.dumps(data),
headers={'Content-Type': 'application/json',
'Accept': 'application/octet-stream'})

assert 200 == response.status_code
assert 512 * 1024 ** 2 <= int(response.headers['Content-Length'])
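As a back-of-the-envelope check (not code from this PR), the payload sizes the two tests target work out as follows:

```python
# float32 tensors: 4 bytes per element.
small = 1 * 1024 * 1024 * 1 * 4    # 4 MiB, just over gRPC's default 4MB cap once overhead is added
large = 1 * 1024 * 1024 * 128 * 4  # 512 MiB, comfortably under the new 2GB cap

assert small == 4 * 1024 ** 2
assert large == 512 * 1024 ** 2
```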
4 changes: 3 additions & 1 deletion test/integ/docker_utils.py
@@ -156,14 +156,15 @@ def execute_pytest(self, tests_path):

class HostingContainer(Container):
def __init__(self, image, opt_ml, script_name, processor, requirements_file=None,
-                 startup_delay=5):
+                 startup_delay=5, region=None):
super(HostingContainer, self).__init__(image=image,
processor=processor,
startup_delay=startup_delay)
self.opt_ml = opt_ml
self.script_name = script_name
self.opt_ml = opt_ml
self.requirements_file = requirements_file
+        self.region = region

def __enter__(self):
cmd = [self.docker,
@@ -174,6 +175,7 @@ def __enter__(self):
'-e', 'AWS_ACCESS_KEY_ID',
'-e', 'AWS_SECRET_ACCESS_KEY',
'-e', 'AWS_SESSION_TOKEN',
+               '-e', 'SAGEMAKER_REGION={}'.format(self.region if self.region else ''),
'-e', 'SAGEMAKER_CONTAINER_LOG_LEVEL=20',
'-e', 'SAGEMAKER_PROGRAM={}'.format(self.script_name),
'-e', 'SAGEMAKER_REQUIREMENTS={}'.format(self.requirements_file),
54 changes: 54 additions & 0 deletions test/integ/test_large_grpc_message.py
@@ -0,0 +1,54 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import os

import tensorflow as tf

from test.integ.docker_utils import HostingContainer
from test.integ.utils import copy_resource
from test.integ.conftest import SCRIPT_PATH


def create_model(export_dir):
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session() as session:
x = tf.placeholder(tf.float32, shape=[None, 1024, 1024, 1], name='x')
a = tf.constant(2.0)

y = tf.multiply(a, x, name='y')
predict_signature_def = (
tf.saved_model.signature_def_utils.predict_signature_def({
'x': x
}, {'y': y}))
signature_def_map = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
predict_signature_def
}
session.run(tf.global_variables_initializer())
builder.add_meta_graph_and_variables(
session, [tf.saved_model.tag_constants.SERVING],
signature_def_map=signature_def_map)
builder.save()


def test_large_grpc_message(docker_image, opt_ml, processor, region):
resource_path = os.path.join(SCRIPT_PATH, '../resources/large_grpc_message')
copy_resource(resource_path, opt_ml, 'code', 'code')
export_dir = os.path.join(opt_ml, 'model', 'export', 'Servo', '1')
create_model(export_dir)

with HostingContainer(opt_ml=opt_ml, image=docker_image,
script_name='inference.py',
processor=processor,
region=region) as c:
c.execute_pytest('test/integ/container_tests/large_grpc_message.py')
32 changes: 32 additions & 0 deletions test/resources/large_grpc_message/code/inference.py
@@ -0,0 +1,32 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json

import numpy as np


def input_fn(data, content_type):
"""
Args:
data: json string containing tensor shape info
e.g. {"shape": [1, 1024, 1024, 1], "dtype": "float32"}
content_type: ignored
Returns:
a dict with { 'x': numpy array with the specified shape and dtype }
"""

input_info = json.loads(data)
shape = input_info['shape']
dtype = getattr(np, input_info['dtype'])
return {'x': np.ones(shape, dtype)}
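As a quick local sanity check (not part of the PR), this `input_fn` can be exercised directly; importing the module as `inference` assumes the script above is on the Python path:

```python
import json
import numpy as np
from inference import input_fn  # assumed import path for the script above

payload = json.dumps({'shape': [1, 1024, 1024, 1], 'dtype': 'float32'})
tensors = input_fn(payload, 'application/json')

assert tensors['x'].shape == (1, 1024, 1024, 1)
assert tensors['x'].dtype == np.float32
assert tensors['x'].nbytes == 4 * 1024 ** 2  # the ~4MB case from the integration test
```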
14 changes: 8 additions & 6 deletions test/unit/test_proxy_client.py
@@ -17,9 +17,9 @@
import pytest
from google import protobuf
from mock import MagicMock, patch, ANY
-from tensorflow_serving.apis import prediction_service_pb2, get_model_metadata_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc, get_model_metadata_pb2

-from tf_container.proxy_client import GRPCProxyClient
+from tf_container.proxy_client import GRPCProxyClient, MAX_GRPC_MESSAGE_SIZE

REGRESSION = 'tensorflow/serving/regress'
INFERENCE = 'tensorflow/serving/inference'
@@ -360,14 +360,16 @@ def test_classification_protobuf(proxy_client):

@patch('tensorflow_serving.apis.get_model_metadata_pb2.SignatureDefMap')
@patch('tensorflow_serving.apis.get_model_metadata_pb2.GetModelMetadataRequest')
-@patch('tensorflow_serving.apis.prediction_service_pb2.beta_create_PredictionService_stub')
-@patch('grpc.beta.implementations.insecure_channel')
+@patch('tensorflow_serving.apis.prediction_service_pb2_grpc.PredictionServiceStub')
+@patch('grpc.insecure_channel')
def test_cache_prediction_metadata(channel, stub, request, signature_def_map, proxy_client):
proxy_client.cache_prediction_metadata()

-    channel.assert_called_once_with('localhost', DEFAULT_PORT)
+    channel.assert_called_once_with('localhost:{}'.format(DEFAULT_PORT), options=[
+        ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
+        ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])

-    stub = prediction_service_pb2.beta_create_PredictionService_stub
+    stub = prediction_service_pb2_grpc.PredictionServiceStub
stub.assert_called_once_with(channel())

request = get_model_metadata_pb2.GetModelMetadataRequest