diff --git a/src/sagemaker_pytorch_serving_container/handler_service.py b/src/sagemaker_pytorch_serving_container/handler_service.py
index 07e81c7b..518dc231 100644
--- a/src/sagemaker_pytorch_serving_container/handler_service.py
+++ b/src/sagemaker_pytorch_serving_container/handler_service.py
@@ -13,7 +13,7 @@
 from __future__ import absolute_import
 
 from sagemaker_inference.default_handler_service import DefaultHandlerService
-from sagemaker_inference.transformer import Transformer
+from sagemaker_pytorch_serving_container.transformer import PyTorchTransformer
 from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import DefaultPytorchInferenceHandler
 
 import os
@@ -38,7 +38,7 @@ class HandlerService(DefaultHandlerService):
     def __init__(self):
         self._initialized = False
 
-        transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
+        transformer = PyTorchTransformer(default_inference_handler=DefaultPytorchInferenceHandler())
         super(HandlerService, self).__init__(transformer=transformer)
 
     def initialize(self, context):
diff --git a/src/sagemaker_pytorch_serving_container/torchserve.py b/src/sagemaker_pytorch_serving_container/torchserve.py
index 048a06b4..64211221 100644
--- a/src/sagemaker_pytorch_serving_container/torchserve.py
+++ b/src/sagemaker_pytorch_serving_container/torchserve.py
@@ -138,7 +138,7 @@ def _generate_ts_config_properties(handler_service):
     if ts_env.is_env_set() and not ENABLE_MULTI_MODEL:
         models_string = f'''{{\\
             "{DEFAULT_TS_MODEL_NAME}": {{\\
-                "1.0": {{\\
+                "1": {{\\
                     "defaultVersion": true,\\
                     "marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\
                     "minWorkers": {ts_env._min_workers},\\
diff --git a/src/sagemaker_pytorch_serving_container/transformer.py b/src/sagemaker_pytorch_serving_container/transformer.py
new file mode 100644
index 00000000..5550e2a1
--- /dev/null
+++ b/src/sagemaker_pytorch_serving_container/transformer.py
@@ -0,0 +1,80 @@
+import traceback
+
+from six.moves import http_client
+
+from sagemaker_inference.transformer import Transformer
+from sagemaker_inference import content_types, environment, utils
+from sagemaker_inference.default_inference_handler import DefaultInferenceHandler
+from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError
+
+
+class PyTorchTransformer(Transformer):
+    """Transformer whose ``transform`` handles every entry of a request batch."""
+
+    def transform(self, data, context):
+        """
+        Take a batch of requests, deserialize each entry, make a prediction for it,
+        and return the serialized responses.
+
+        Args:
+            data (list[obj]): the request data; each entry carries its payload
+                under the "body" key.
+            context (obj): metadata on the incoming request data.
+
+        Returns:
+            list[obj]: One serialized prediction result per entry of ``data`` if
+                inference is successful. Otherwise returns an error message
+                with the context set appropriately.
+        """
+        try:
+            properties = context.system_properties
+            model_dir = properties.get("model_dir")
+            self.validate_and_initialize(model_dir=model_dir)
+
+            response_list = []
+
+            for request in data:
+                input_data = request.get("body")
+
+                # NOTE(review): request headers are read from slot 0 and the response
+                # content type is written to slot 0 for every entry in the batch.
+                # Confirm whether the context keeps one slot per request before
+                # switching these to a per-entry index.
+                request_processor = context.request_processor[0]
+
+                request_property = request_processor.get_request_properties()
+                content_type = utils.retrieve_content_type_header(request_property)
+                accept = request_property.get("Accept") or request_property.get("accept")
+
+                if not accept or accept == content_types.ANY:
+                    accept = self._environment.default_accept
+
+                if content_type in content_types.UTF8_TYPES:
+                    input_data = input_data.decode("utf-8")
+
+                result = self._transform_fn(self._model, input_data, content_type, accept)
+
+                response = result
+                response_content_type = accept
+
+                if isinstance(result, tuple):
+                    # handles tuple for backwards compatibility
+                    response = result[0]
+                    response_content_type = result[1]
+
+                context.set_response_content_type(0, response_content_type)
+
+                response_list.append(response)
+
+            return response_list
+        except Exception as e:  # pylint: disable=broad-except
+            # Report every failure through handle_error instead of letting it propagate.
+            trace = traceback.format_exc()
+            if isinstance(e, BaseInferenceToolkitError):
+                return self.handle_error(context, e, trace)
+            else:
+                return self.handle_error(
+                    context,
+                    GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)),
+                    trace,
+                )