From 13a60af65322b4c1dbb90da5fe301472ac3baa1d Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 11 May 2022 07:57:14 +0000 Subject: [PATCH 1/5] Return batch in transformer --- .../handler_service.py | 5 +- .../transformer.py | 69 +++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 src/sagemaker_pytorch_serving_container/transformer.py diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index 07e81c7b..7d547f3c 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -13,7 +13,8 @@ from __future__ import absolute_import from sagemaker_inference.default_handler_service import DefaultHandlerService -from sagemaker_inference.transformer import Transformer +#from sagemaker_inference.transformer import Transformer +from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler import os @@ -38,7 +39,7 @@ class HandlerService(DefaultHandlerService): def __init__(self): self._initialized = False - transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler()) + transformer = PyTorchTransformer(default_inference_handler=DefaultPytorchInferenceHandler()) super(HandlerService, self).__init__(transformer=transformer) def initialize(self, context): diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py new file mode 100644 index 00000000..19cab4ee --- /dev/null +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -0,0 +1,69 @@ +from six.moves import http_client + +from sagemaker_inference.transformer import Transformer +from sagemaker_inference import content_types, environment, utils +from sagemaker_inference.default_inference_handler import 
DefaultInferenceHandler +from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError + +class PyTorchTransformer(Transformer): + def transform(self, data, context): + """Take a request with input data, deserialize it, make a prediction, and return a + serialized response. + + Args: + data (obj): the request data. + context (obj): metadata on the incoming request data. + + Returns: + list[obj]: The serialized prediction result wrapped in a list if + inference is successful. Otherwise returns an error message + with the context set appropriately. + """ + try: + properties = context.system_properties + model_dir = properties.get("model_dir") + self.validate_and_initialize(model_dir=model_dir) + + response_list = [] + + for i in range(len(data)): + print(f"Processing Data: {data[i]}") + input_data = data[i].get("body") + + request_processor = context.request_processor[0] + + request_property = request_processor.get_request_properties() + content_type = utils.retrieve_content_type_header(request_property) + accept = request_property.get("Accept") or request_property.get("accept") + + if not accept or accept == content_types.ANY: + accept = self._environment.default_accept + + if content_type in content_types.UTF8_TYPES: + input_data = input_data.decode("utf-8") + + result = self._transform_fn(self._model, input_data, content_type, accept) + + response = result + response_content_type = accept + + if isinstance(result, tuple): + # handles tuple for backwards compatibility + response = result[0] + response_content_type = result[1] + + context.set_response_content_type(0, response_content_type) + + response_list.append(response) + + return response_list + except Exception as e: # pylint: disable=broad-except + trace = traceback.format_exc() + if isinstance(e, BaseInferenceToolkitError): + return self.handle_error(context, e, trace) + else: + return self.handle_error( + context, + 
GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)), + trace, + ) \ No newline at end of file From 96880a6da3dcbcf574a0ef3b2166901b577aa110 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 11 May 2022 10:02:04 +0000 Subject: [PATCH 2/5] Format --- src/sagemaker_pytorch_serving_container/transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py index 19cab4ee..741e80d8 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -5,9 +5,11 @@ from sagemaker_inference.default_inference_handler import DefaultInferenceHandler from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError + class PyTorchTransformer(Transformer): - def transform(self, data, context): - """Take a request with input data, deserialize it, make a prediction, and return a + def transform(self, data, context): + """ + Take a request with input data, deserialize it, make a prediction, and return a serialized response. 
Args: @@ -66,4 +68,4 @@ def transform(self, data, context): context, GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)), trace, - ) \ No newline at end of file + ) From cb9739b59612e1ac6fc279dc2e0e18872f1e002b Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 11 May 2022 10:26:51 +0000 Subject: [PATCH 3/5] Remove *.mar --- src/sagemaker_pytorch_serving_container/torchserve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py index 048a06b4..701f61aa 100644 --- a/src/sagemaker_pytorch_serving_container/torchserve.py +++ b/src/sagemaker_pytorch_serving_container/torchserve.py @@ -140,7 +140,6 @@ def _generate_ts_config_properties(handler_service): "{DEFAULT_TS_MODEL_NAME}": {{\\ "1.0": {{\\ "defaultVersion": true,\\ - "marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\ "minWorkers": {ts_env._min_workers},\\ "maxWorkers": {ts_env._max_workers},\\ "batchSize": {ts_env._batch_size},\\ From e5bcd46b9290157bce76839cf84c2a528da9545f Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Wed, 11 May 2022 14:16:21 +0000 Subject: [PATCH 4/5] Add *.mar back again for testing --- src/sagemaker_pytorch_serving_container/torchserve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py index 701f61aa..048a06b4 100644 --- a/src/sagemaker_pytorch_serving_container/torchserve.py +++ b/src/sagemaker_pytorch_serving_container/torchserve.py @@ -140,6 +140,7 @@ def _generate_ts_config_properties(handler_service): "{DEFAULT_TS_MODEL_NAME}": {{\\ "1.0": {{\\ "defaultVersion": true,\\ + "marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\ "minWorkers": {ts_env._min_workers},\\ "maxWorkers": {ts_env._max_workers},\\ "batchSize": {ts_env._batch_size},\\ From 4eacfb6dc7559f990e303e522419afb3305ce930 Mon Sep 17 00:00:00 2001 From: Nikhil Kulkarni Date: Fri, 22 Jul 2022 
00:03:41 +0000 Subject: [PATCH 5/5] Remove log --- src/sagemaker_pytorch_serving_container/handler_service.py | 1 - src/sagemaker_pytorch_serving_container/torchserve.py | 2 +- src/sagemaker_pytorch_serving_container/transformer.py | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py index 7d547f3c..518dc231 100644 --- a/src/sagemaker_pytorch_serving_container/handler_service.py +++ b/src/sagemaker_pytorch_serving_container/handler_service.py @@ -13,7 +13,6 @@ from __future__ import absolute_import from sagemaker_inference.default_handler_service import DefaultHandlerService -#from sagemaker_inference.transformer import Transformer from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py index 048a06b4..64211221 100644 --- a/src/sagemaker_pytorch_serving_container/torchserve.py +++ b/src/sagemaker_pytorch_serving_container/torchserve.py @@ -138,7 +138,7 @@ def _generate_ts_config_properties(handler_service): if ts_env.is_env_set() and not ENABLE_MULTI_MODEL: models_string = f'''{{\\ "{DEFAULT_TS_MODEL_NAME}": {{\\ - "1.0": {{\\ + "1": {{\\ "defaultVersion": true,\\ "marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\ "minWorkers": {ts_env._min_workers},\\ diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py index 741e80d8..5550e2a1 100644 --- a/src/sagemaker_pytorch_serving_container/transformer.py +++ b/src/sagemaker_pytorch_serving_container/transformer.py @@ -28,8 +28,7 @@ def transform(self, data, context): response_list = [] - for i in range(len(data)): - print(f"Processing Data: {data[i]}") + for i in 
range(len(data)): input_data = data[i].get("body") request_processor = context.request_processor[0]