diff --git a/src/tf_container/proxy_client.py b/src/tf_container/proxy_client.py
index 0d0d671f..9dab0934 100644
--- a/src/tf_container/proxy_client.py
+++ b/src/tf_container/proxy_client.py
@@ -14,7 +14,7 @@
 import numpy as np
 from google.protobuf import json_format
-from grpc.beta import implementations
+import grpc
 from tensorflow import make_tensor_proto
 from tensorflow.core.example import example_pb2, feature_pb2
 from tensorflow.core.framework import tensor_pb2
@@ -22,13 +22,14 @@ PREDICT_INPUTS
 from tensorflow_serving.apis import get_model_metadata_pb2
 from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 from tf_container.run import logger as _logger
 
 INFERENCE_ACCELERATOR_PRESENT_ENV = 'SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT'
 TF_SERVING_GRPC_REQUEST_TIMEOUT_ENV = 'SAGEMAKER_TFS_GRPC_REQUEST_TIMEOUT'
+MAX_GRPC_MESSAGE_SIZE = 1024 ** 3 * 2 - 1  # 2GB - 1
 DEFAULT_GRPC_REQUEST_TIMEOUT_FOR_INFERENCE_ACCELERATOR = 30.0
 
 REGRESSION = 'tensorflow/serving/regress'
@@ -82,8 +83,12 @@
     def request(self, data):
         return request_fn(data)
 
     def cache_prediction_metadata(self):
-        channel = implementations.insecure_channel(self.host, self.tf_serving_port)
-        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+        channel = grpc.insecure_channel(
+            '{}:{}'.format(self.host, self.tf_serving_port),
+            options=[
+                ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE),
+                ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)])
+        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
         request = get_model_metadata_pb2.GetModelMetadataRequest()
         request.model_spec.name = self.model_name
diff --git a/test/integ/container_tests/large_grpc_message.py b/test/integ/container_tests/large_grpc_message.py
new file mode 100644
index 00000000..e697d16a
--- /dev/null
+++ b/test/integ/container_tests/large_grpc_message.py
@@ -0,0 +1,67 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
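+
+# In-container tests for gRPC messages above the old 4MB default limit.
+# Both tests POST to the local /invocations endpoint and run against the
+# SavedModel built in test/integ/test_large_grpc_message.py, which computes
+# y = 2.0 * x, so each response tensor matches the request tensor's shape.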
+
+import json
+
+import requests
+
+
+def test_grpc_message_4m_json():
+    # this will generate a request just over the original
+    # 4MB limit (1024 * 1024 * 1 * 4 bytes) + overhead
+    # this matches the message size from the original issue report
+    # the response will have the same size, so we are testing both ends
+    data = {
+        'shape': [1, 1024, 1024, 1],
+        'dtype': 'float32'
+    }
+
+    response = requests.post("http://localhost:8080/invocations",
+                             data=json.dumps(data),
+                             headers={'Content-Type': 'application/json',
+                                      'Accept': 'application/json'}).content
+
+    prediction = json.loads(response)
+
+    expected_shape = {
+        'dim': [
+            {'size': '1'},
+            {'size': '1024'},
+            {'size': '1024'},
+            {'size': '1'},
+        ]
+    }
+
+    assert expected_shape == prediction['outputs']['y']['tensorShape']
+    assert 2.0 == prediction['outputs']['y']['floatVal'][-1]
+
+
+def test_large_grpc_message_512m_pb2():
+    # this will generate a request of ~512MB
+    # (1024 * 1024 * 128 * 4 bytes) + overhead
+    # the response will have the same size, so we are testing both ends
+    # returning bytes (serialized pb2) instead of json, because
+    # our default json output function runs out of memory with
+    # much smaller messages (around 128MB if gunicorn is running in 8GB)
+    data = {
+        'shape': [1, 1024, 1024, 128],
+        'dtype': 'float32'
+    }
+
+    response = requests.post("http://localhost:8080/invocations",
+                             data=json.dumps(data),
+                             headers={'Content-Type': 'application/json',
+                                      'Accept': 'application/octet-stream'})
+
+    assert 200 == response.status_code
+    assert 512 * 1024 ** 2 <= int(response.headers['Content-Length'])
diff --git a/test/integ/docker_utils.py b/test/integ/docker_utils.py
index 9328efee..5de1a729 100644
--- a/test/integ/docker_utils.py
+++ b/test/integ/docker_utils.py
@@ -156,7 +156,7 @@ def execute_pytest(self, tests_path):
 
 class HostingContainer(Container):
     def __init__(self, image, opt_ml, script_name, processor, requirements_file=None,
-                 startup_delay=5):
+                 startup_delay=5, region=None):
         super(HostingContainer, self).__init__(image=image,
                                                processor=processor,
                                                startup_delay=startup_delay)
@@ -164,6 +164,7 @@ def __init__(self, image, opt_ml, script_name, processor, requirements_file=None
         self.script_name = script_name
         self.opt_ml = opt_ml
         self.requirements_file = requirements_file
+        self.region = region
 
     def __enter__(self):
         cmd = [self.docker,
@@ -174,6 +175,7 @@ def __enter__(self):
               '-e', 'AWS_ACCESS_KEY_ID',
               '-e', 'AWS_SECRET_ACCESS_KEY',
               '-e', 'AWS_SESSION_TOKEN',
+              '-e', 'SAGEMAKER_REGION={}'.format(self.region if self.region else ''),
               '-e', 'SAGEMAKER_CONTAINER_LOG_LEVEL=20',
               '-e', 'SAGEMAKER_PROGRAM={}'.format(self.script_name),
               '-e', 'SAGEMAKER_REQUIREMENTS={}'.format(self.requirements_file),
diff --git a/test/integ/test_large_grpc_message.py b/test/integ/test_large_grpc_message.py
new file mode 100644
index 00000000..e94aed43
--- /dev/null
+++ b/test/integ/test_large_grpc_message.py
@@ -0,0 +1,54 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
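+
+# Builds a SavedModel whose serving signature computes y = 2.0 * x, then runs
+# the in-container tests from test/integ/container_tests/large_grpc_message.py
+# against it in the hosting container.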
+import os
+
+import tensorflow as tf
+
+from test.integ.docker_utils import HostingContainer
+from test.integ.utils import copy_resource
+from test.integ.conftest import SCRIPT_PATH
+
+
+def create_model(export_dir):
+    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+    with tf.Session() as session:
+        # the last dimension is left dynamic so both the 4MB (1-channel) and
+        # 512MB (128-channel) container tests fit this signature
+        x = tf.placeholder(tf.float32, shape=[None, 1024, 1024, None], name='x')
+        a = tf.constant(2.0)
+
+        y = tf.multiply(a, x, name='y')
+        predict_signature_def = (
+            tf.saved_model.signature_def_utils.predict_signature_def({
+                'x': x
+            }, {'y': y}))
+        signature_def_map = {
+            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+                predict_signature_def
+        }
+        session.run(tf.global_variables_initializer())
+        builder.add_meta_graph_and_variables(
+            session, [tf.saved_model.tag_constants.SERVING],
+            signature_def_map=signature_def_map)
+        builder.save()
+
+
+def test_large_grpc_message(docker_image, opt_ml, processor, region):
+    resource_path = os.path.join(SCRIPT_PATH, '../resources/large_grpc_message')
+    copy_resource(resource_path, opt_ml, 'code', 'code')
+    export_dir = os.path.join(opt_ml, 'model', 'export', 'Servo', '1')
+    create_model(export_dir)
+
+    with HostingContainer(opt_ml=opt_ml, image=docker_image,
+                          script_name='inference.py',
+                          processor=processor,
+                          region=region) as c:
+        c.execute_pytest('test/integ/container_tests/large_grpc_message.py')
diff --git a/test/resources/large_grpc_message/code/inference.py b/test/resources/large_grpc_message/code/inference.py
new file mode 100644
index 00000000..1aa81c33
--- /dev/null
+++ b/test/resources/large_grpc_message/code/inference.py
@@ -0,0 +1,32 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+
+import json
+
+import numpy as np
+
+
+def input_fn(data, content_type):
+    """
+    Args:
+        data: json string containing tensor shape info
{"shape": [1, 1024, 1024, 1], "dtype": "float32"} + content_type: ignored + Returns: + a dict with { 'x': numpy array with the specified shape and dtype } + """ + + input_info = json.loads(data) + shape = input_info['shape'] + dtype = getattr(np, input_info['dtype']) + return {'x': np.ones(shape, dtype)} diff --git a/test/unit/test_proxy_client.py b/test/unit/test_proxy_client.py index 468c206b..1550e8b1 100644 --- a/test/unit/test_proxy_client.py +++ b/test/unit/test_proxy_client.py @@ -17,9 +17,9 @@ import pytest from google import protobuf from mock import MagicMock, patch, ANY -from tensorflow_serving.apis import prediction_service_pb2, get_model_metadata_pb2 +from tensorflow_serving.apis import prediction_service_pb2_grpc, get_model_metadata_pb2 -from tf_container.proxy_client import GRPCProxyClient +from tf_container.proxy_client import GRPCProxyClient, MAX_GRPC_MESSAGE_SIZE REGRESSION = 'tensorflow/serving/regress' INFERENCE = 'tensorflow/serving/inference' @@ -360,14 +360,16 @@ def test_classification_protobuf(proxy_client): @patch('tensorflow_serving.apis.get_model_metadata_pb2.SignatureDefMap') @patch('tensorflow_serving.apis.get_model_metadata_pb2.GetModelMetadataRequest') -@patch('tensorflow_serving.apis.prediction_service_pb2.beta_create_PredictionService_stub') -@patch('grpc.beta.implementations.insecure_channel') +@patch('tensorflow_serving.apis.prediction_service_pb2_grpc.PredictionServiceStub') +@patch('grpc.insecure_channel') def test_cache_prediction_metadata(channel, stub, request, signature_def_map, proxy_client): proxy_client.cache_prediction_metadata() - channel.assert_called_once_with('localhost', DEFAULT_PORT) + channel.assert_called_once_with('localhost:{}'.format(DEFAULT_PORT), options=[ + ('grpc.max_send_message_length', MAX_GRPC_MESSAGE_SIZE), + ('grpc.max_receive_message_length', MAX_GRPC_MESSAGE_SIZE)]) - stub = prediction_service_pb2.beta_create_PredictionService_stub + stub = prediction_service_pb2_grpc.PredictionServiceStub stub.assert_called_once_with(channel()) request = get_model_metadata_pb2.GetModelMetadataRequest