
Commit 7b9c5d7

change: default model_fn and predict_fn in default handler (#51)
* change: default model_fn and predict_fn in default handler
1 parent 63bafda commit 7b9c5d7
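
With this change, the container's default handler can serve a TorchScript model without a user-supplied inference script when an Elastic Inference accelerator is attached: default_model_fn checks the SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT environment variable and, when it is "true", loads model.pt from the model directory via torch.jit.load. A minimal sketch of producing such an artifact, assuming an illustrative TinyNet module and ./model output directory (neither is part of this commit):

    import os
    import torch

    class TinyNet(torch.nn.Module):
        # Illustrative stand-in for a real network.
        def forward(self, x):
            return 3 * x

    model_dir = "./model"  # illustrative; SageMaker unpacks model.tar.gz into the model directory
    os.makedirs(model_dir, exist_ok=True)

    # The new default_model_fn calls torch.jit.load on <model_dir>/model.pt,
    # so the artifact must be a TorchScript archive saved under exactly that name.
    scripted = torch.jit.script(TinyNet())
    torch.jit.save(scripted, os.path.join(model_dir, "model.pt"))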

File tree

3 files changed: +74, -55 lines

src/sagemaker_pytorch_serving_container/default_inference_handler.py

Lines changed: 30 additions & 12 deletions

@@ -12,29 +12,39 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import

+import os
 import textwrap

 import torch
-
 from sagemaker_inference import content_types, decoder, default_inference_handler, encoder

+INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
+DEFAULT_MODEL_FILENAME = "model.pt"
+

 class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
     VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)

     def default_model_fn(self, model_dir):
-        """Loads a model. For PyTorch, a default function to load a model cannot be provided.
-        Users should provide customized model_fn() in script.
+        """Loads a model. For PyTorch, a default function to load a model is provided only if
+        Elastic Inference is used. In other cases, users should provide a customized model_fn() in the script.

         Args:
             model_dir: a directory where model is saved.

         Returns: A PyTorch model.
         """
-        raise NotImplementedError(textwrap.dedent("""
-        Please provide a model_fn implementation.
-        See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
-        """))
+        if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
+            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
+            if not os.path.exists(model_path):
+                raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
+                                        .format(DEFAULT_MODEL_FILENAME))
+            return torch.jit.load(model_path)
+        else:
+            raise NotImplementedError(textwrap.dedent("""
+            Please provide a model_fn implementation.
+            See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
+            """))

     def default_input_fn(self, input_data, content_type):
         """A default input_fn that can handle JSON, CSV and NPZ formats.

@@ -62,12 +72,20 @@ def default_predict_fn(self, data, model):

         Returns: a prediction
         """
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        input_data = data.to(device)
-        model.eval()
         with torch.no_grad():
-            output = model(input_data)
+            if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
+                device = torch.device("cpu")
+                model = model.to(device)
+                input_data = data.to(device)
+                model.eval()
+                with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
+                    output = model(input_data)
+            else:
+                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                model = model.to(device)
+                input_data = data.to(device)
+                model.eval()
+                output = model(input_data)

         return output
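
Taken together, the handler's behavior now forks on the environment: without the accelerator flag, default_model_fn still raises NotImplementedError, while with the flag set it loads the TorchScript archive and default_predict_fn runs it on CPU under torch.jit.optimized_execution targeting "eia:0". A quick sketch of exercising both branches of default_model_fn; the /opt/ml/model path is the conventional SageMaker model directory, assumed here only for illustration:

    import os
    from sagemaker_pytorch_serving_container import default_inference_handler

    handler = default_inference_handler.DefaultPytorchInferenceHandler()

    # Without the flag, the handler still refuses to guess how to load the model.
    os.environ.pop("SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT", None)
    try:
        handler.default_model_fn("/opt/ml/model")
    except NotImplementedError:
        print("no accelerator attached: provide a custom model_fn")

    # With the flag (the platform sets it when an accelerator is attached),
    # the handler expects /opt/ml/model/model.pt to be a TorchScript archive.
    os.environ["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] = "true"
    # model = handler.default_model_fn("/opt/ml/model")  # -> torch.jit.load(".../model.pt")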

test/resources/mnist/model_eia/mnist.py

Lines changed: 1 addition & 36 deletions

@@ -10,39 +10,4 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-from __future__ import absolute_import
-import logging
-import os
-import sys
-
-import torch
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-logger.addHandler(logging.StreamHandler(sys.stdout))
-
-
-def predict_fn(input_data, model):
-    logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
-    # With EI, client instance should be CPU for cost-efficiency. Subgraphs with unsupported arguments run locally. Server runs with CUDA
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = model.to(device)
-    input_data = input_data.to(device)
-    with torch.no_grad():
-        # Set the target device to the accelerator ordinal
-        with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
-            return model(input_data)
-
-
-def model_fn(model_dir):
-    logger.info('model_fn: Loading model with TorchScript from {}'.format(model_dir))
-    # Scripted model is serialized with torch.jit.save().
-    # No need to instantiate model definition then load state_dict
-    model = torch.jit.load('model.pth')
-    return model
-
-
-def save_model(model, model_dir):
-    logger.info("Saving the model to {}.".format(model_dir))
-    path = os.path.join(model_dir, 'model.pth')
-    torch.jit.save(model, path)
+# This file is intentionally left blank to utilize default_model_fn and default_predict_fn
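
One practical consequence of deleting this script: it saved the TorchScript model as model.pth, while the new default_model_fn looks for model.pt (DEFAULT_MODEL_FILENAME). An artifact produced by the old script therefore needs a rename before the default handler can find it; a one-step sketch with an illustrative model directory:

    import os

    model_dir = "/opt/ml/model"  # illustrative path
    # The default handler loads <model_dir>/model.pt, not the old model.pth.
    os.rename(os.path.join(model_dir, "model.pth"),
              os.path.join(model_dir, "model.pt"))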

test/unit/test_default_inference_handler.py

Lines changed: 43 additions & 7 deletions

@@ -15,6 +15,7 @@
 import csv
 import json

+import mock
 import numpy as np
 import pytest
 import torch

@@ -40,7 +41,7 @@ def __call__(self, tensor):
         return 3 * tensor


-@pytest.fixture(scope='session', name='tensor')
+@pytest.fixture(scope="session", name="tensor")
 def fixture_tensor():
     tensor = torch.rand(5, 10, 7, 9)
     return tensor.to(device)

@@ -51,9 +52,14 @@ def inference_handler():
     return default_inference_handler.DefaultPytorchInferenceHandler()


+@pytest.fixture()
+def eia_inference_handler():
+    return default_inference_handler.DefaultPytorchInferenceHandler()
+
+
 def test_default_model_fn(inference_handler):
     with pytest.raises(NotImplementedError):
-        inference_handler.default_model_fn('model_dir')
+        inference_handler.default_model_fn("model_dir")


 def test_default_input_fn_json(inference_handler, tensor):

@@ -67,7 +73,7 @@ def test_default_input_fn_json(inference_handler, tensor):
 def test_default_input_fn_csv(inference_handler):
     array = [[1, 2, 3], [4, 5, 6]]
     str_io = StringIO()
-    csv.writer(str_io, delimiter=',').writerows(array)
+    csv.writer(str_io, delimiter=",").writerows(array)

     deserialized_np_array = inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV)

@@ -78,7 +84,7 @@ def test_default_input_fn_csv(inference_handler):

 def test_default_input_fn_csv_bad_columns(inference_handler):
     str_io = StringIO()
-    csv_writer = csv.writer(str_io, delimiter=',')
+    csv_writer = csv.writer(str_io, delimiter=",")
     csv_writer.writerow([1, 2, 3])
     csv_writer.writerow([1, 2, 3, 4])

@@ -97,7 +103,7 @@ def test_default_input_fn_npy(inference_handler, tensor):

 def test_default_input_fn_bad_content_type(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_input_fn('', 'application/not_supported')
+        inference_handler.default_input_fn("", "application/not_supported")


 def test_default_predict_fn(inference_handler, tensor):

@@ -162,7 +168,7 @@ def test_default_output_fn_csv_float(inference_handler):

 def test_default_output_fn_bad_accept(inference_handler):
     with pytest.raises(errors.UnsupportedFormatError):
-        inference_handler.default_output_fn('', 'application/not_supported')
+        inference_handler.default_output_fn("", "application/not_supported")


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")

@@ -171,4 +177,34 @@ def test_default_output_fn_gpu(inference_handler):

     output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV)

-    assert '1,2,3\n4,5,6\n'.encode("utf-8") == output
+    assert "1,2,3\n4,5,6\n".encode("utf-8") == output
+
+
+def test_eia_default_model_fn(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = True
+        with mock.patch("torch.jit.load") as mock_torch:
+            mock_torch.return_value = DummyModel()
+            model = eia_inference_handler.default_model_fn("model_dir")
+            assert model is not None
+
+
+def test_eia_default_model_fn_error(eia_inference_handler):
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        mock_os.path.join.return_value = "model_dir"
+        mock_os.path.exists.return_value = False
+        with pytest.raises(FileNotFoundError):
+            eia_inference_handler.default_model_fn("model_dir")
+
+
+def test_eia_default_predict_fn(eia_inference_handler, tensor):
+    model = DummyModel()
+    with mock.patch("sagemaker_pytorch_serving_container.default_inference_handler.os") as mock_os:
+        mock_os.getenv.return_value = "true"
+        with mock.patch("torch.jit.optimized_execution") as mock_torch:
+            mock_torch.__enter__.return_value = "dummy"
+            eia_inference_handler.default_predict_fn(tensor, model)
+            mock_torch.assert_called_once()
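
The three EIA tests reuse the module-level DummyModel test double; this diff shows only its __call__ body (return 3 * tensor) as a context line, but for default_predict_fn to run, the double must also tolerate the handler's model.to(device) and model.eval() calls. A plausible reconstruction, hedged because the full class sits outside this diff:

    class DummyModel:
        # Reconstructed test double; only __call__ is visible in the diff context.
        def to(self, device):
            return self  # default_predict_fn reassigns model = model.to(device)

        def eval(self):
            pass  # default_predict_fn calls model.eval() before the forward pass

        def __call__(self, tensor):
            return 3 * tensor

The new cases can be run in isolation with, for example: pytest test/unit/test_default_inference_handler.py -k eia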
