diff --git a/src/sagemaker_pytorch_serving_container/default_inference_handler.py b/src/sagemaker_pytorch_serving_container/default_inference_handler.py
index d9b2558d..176d3afc 100644
--- a/src/sagemaker_pytorch_serving_container/default_inference_handler.py
+++ b/src/sagemaker_pytorch_serving_container/default_inference_handler.py
@@ -12,29 +12,39 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import os
 import textwrap
 
 import torch
-
 from sagemaker_inference import content_types, decoder, default_inference_handler, encoder
 
+INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
+DEFAULT_MODEL_FILENAME = "model.pt"
+
 
 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)
 
     def default_model_fn(self, model_dir):
-        """Loads a model. For PyTorch, a default function to load a model cannot be provided.
-        Users should provide customized model_fn() in script.
+        """Loads a model. For PyTorch, a default function to load a model is provided only if Elastic Inference is used.
+        In other cases, users should provide a customized model_fn() in the script.
 
         Args:
            model_dir: a directory where model is saved.
 
        Returns: A PyTorch model.
        """
-        raise NotImplementedError(textwrap.dedent("""
-        Please provide a model_fn implementation.
-        See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
-        """))
+        if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
+            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
+            if not os.path.exists(model_path):
+                raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
+                                        .format(DEFAULT_MODEL_FILENAME))
+            return torch.jit.load(model_path)
+        else:
+            raise NotImplementedError(textwrap.dedent("""
+            Please provide a model_fn implementation.
+            See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
+            """))
 
     def default_input_fn(self, input_data, content_type):
         """A default input_fn that can handle JSON, CSV and NPZ formats.
@@ -62,12 +72,20 @@ def default_predict_fn(self, data, model):
 
         Returns: a prediction
         """
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        input_data = data.to(device)
-        model.eval()
         with torch.no_grad():
-            output = model(input_data)
+            if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
+                device = torch.device("cpu")
+                model = model.to(device)
+                input_data = data.to(device)
+                model.eval()
+                with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
+                    output = model(input_data)
+            else:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                model = model.to(device)
+                input_data = data.to(device)
+                model.eval()
+                output = model(input_data)
 
         return output
 
diff --git a/test/resources/mnist/model_eia/mnist.py b/test/resources/mnist/model_eia/mnist.py
index ebc0bff0..d151a3f3 100644
--- a/test/resources/mnist/model_eia/mnist.py
+++ b/test/resources/mnist/model_eia/mnist.py
@@ -10,39 +10,4 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-from __future__ import absolute_import
-import logging
-import os
-import sys
-
-import torch
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-logger.addHandler(logging.StreamHandler(sys.stdout))
-
-
-def predict_fn(input_data, model):
-    logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
-    # With EI, client instance should be CPU for cost-efficiency. Subgraphs with unsupported arguments run locally. Server runs with CUDA
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    mdoel = model.to(device)
-    input_data = input_data.to(device)
-    with torch.no_grad():
-        # Set the target device to the accelerator ordinal
-        with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
-            return model(input_data)
-
-
-def model_fn(model_dir):
-    logger.info('model_fn: Loading model with TorchScript from {}'.format(model_dir))
-    # Scripted model is serialized with torch.jit.save().
-    # No need to instantiate model definition then load state_dict
-    model = torch.jit.load('model.pth')
-    return model
-
-
-def save_model(model, model_dir):
-    logger.info("Saving the model to {}.".format(model_dir))
-    path = os.path.join(model_dir, 'model.pth')
-    torch.jit.save(model, path)
+# This file is intentionally left blank to utilize default_model_fn and default_predict_fn
diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py
index 818c6a40..e43e026e 100644
--- a/test/unit/test_default_inference_handler.py
+++ b/test/unit/test_default_inference_handler.py
@@ -15,6 +15,7 @@
 import csv
 import json
 
+import mock
 import numpy as np
 import pytest
 import torch
@@ -40,7 +41,7 @@ def __call__(self, tensor):
         return 3 * tensor
 
 
-@pytest.fixture(scope='session', name='tensor')
+@pytest.fixture(scope="session", name="tensor")
 def fixture_tensor():
     tensor = torch.rand(5, 10, 7, 9)
     return tensor.to(device)
@@ -51,9 +52,14 @@ def inference_handler():
     return default_inference_handler.DefaultPytorchInferenceHandler()
 
 
+@pytest.fixture()
+def eia_inference_handler():
+    return default_inference_handler.DefaultPytorchInferenceHandler()
+
+
 def test_default_model_fn(inference_handler):
     with pytest.raises(NotImplementedError):
-        inference_handler.default_model_fn('model_dir')
+        inference_handler.default_model_fn("model_dir")
 
 
 def test_default_input_fn_json(inference_handler, tensor):
@@ -67,7 +73,7 @@ def test_default_input_fn_json(inference_handler, tensor):
 def test_default_input_fn_csv(inference_handler):
     array = [[1, 2, 3], [4, 5, 6]]
     str_io = StringIO()
-    csv.writer(str_io, delimiter=',').writerows(array)
+    csv.writer(str_io, delimiter=",").writerows(array)
 
     deserialized_np_array = inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV)
 
@@ -78,7 +84,7 @@ def test_default_input_fn_csv_bad_columns(inference_handler):
     str_io = StringIO()
-    csv_writer = csv.writer(str_io, delimiter=',')
+    csv_writer = csv.writer(str_io, delimiter=",")
     csv_writer.writerow([1, 2, 3])
     csv_writer.writerow([1, 2, 3, 4])
 
@@ -97,7 +103,7 @@ def test_default_input_fn_npy(inference_handler, tensor):
 
 def test_default_input_fn_bad_content_type(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_input_fn('', 'application/not_supported')
+        inference_handler.default_input_fn("", "application/not_supported")
 
 
 def test_default_predict_fn(inference_handler, tensor):
@@ -162,7 +168,7 @@ def test_default_output_fn_csv_float(inference_handler):
 
 def test_default_output_fn_bad_accept(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_output_fn('', 'application/not_supported')
+        inference_handler.default_output_fn("", "application/not_supported")
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@@ -171,4 +177,34 @@ def test_default_output_fn_gpu(inference_handler):
 
     output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)
 
-    assert '1,2,3\n4,5,6\n'.encode("utf-8") == output
+    assert "1,2,3\n4,5,6\n".encode("utf-8") == output
+
+
+def test_eia_default_model_fn(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = True
+        with mock.patch("torch.jit.load") as mock_torch:
+            mock_torch.return_value = DummyModel()
+            model = eia_inference_handler.default_model_fn("model_dir")
+        assert model is not None
+
+
+def test_eia_default_model_fn_error(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = False
+        with pytest.raises(FileNotFoundError):
+            eia_inference_handler.default_model_fn("model_dir")
+
+
+def test_eia_default_predict_fn(eia_inference_handler, tensor):
+    model = DummyModel()
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        with mock.patch("torch.jit.optimized_execution") as mock_torch:
+            mock_torch.__enter__.return_value = "dummy"
+            eia_inference_handler.default_predict_fn(tensor, model)
+        mock_torch.assert_called_once()
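
Note (not part of the patch): the EIA path in default_model_fn above expects a TorchScript program saved as model.pt at the root of model_dir. A minimal sketch of producing such an artifact is shown below; the torchvision resnet18 model and the dummy input shape are placeholders chosen for illustration only.

# Illustrative only: export a TorchScript model named "model.pt" so that the new
# default_model_fn can load it when SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT=true.
import torch
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()               # placeholder model
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))   # trace with a dummy input
torch.jit.save(traced, "model.pt")                            # filename matches DEFAULT_MODEL_FILENAME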