Skip to content

change: default model_fn and predict_fn in default handler #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 28, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,40 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import textwrap

import torch

from sagemaker_inference import content_types, decoder, default_inference_handler, encoder

INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
DEFAULT_MODEL_FILENAME = "model.pt"


class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)

def default_model_fn(self, model_dir):
"""Loads a model. For PyTorch, a default function to load a model cannot be provided.
Users should provide customized model_fn() in script.
"""Loads a model. For PyTorch, a default function to load a model only if Elastic Inference is used.
In other cases, users should provide customized model_fn() in script.

Args:
model_dir: a directory where model is saved.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the docstrings be updated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring updated

Returns: A PyTorch model.
"""
raise NotImplementedError(textwrap.dedent("""
Please provide a model_fn implementation.
See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
"""))
if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == 'true':
default_model_filename = "model.pt"
model_path = os.path.join(model_dir, default_model_filename)
if not os.path.exists(model_path):
raise FileNotFoundError('Failed to load model with default model_fn: missing file {}.'
.format(DEFAULT_MODEL_FILENAME))
return torch.jit.load(model_path)
else:
raise NotImplementedError(textwrap.dedent("""
Please provide a model_fn implementation.
See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
"""))

def default_input_fn(self, input_data, content_type):
"""A default input_fn that can handle JSON, CSV and NPZ formats.
Expand Down Expand Up @@ -62,12 +73,20 @@ def default_predict_fn(self, data, model):

Returns: a prediction
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_data = data.to(device)
model.eval()
with torch.no_grad():
output = model(input_data)
if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == 'true':
device = torch.device('cpu')
model = model.to(device)
input_data = data.to(device)
model.eval()
with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
output = model(input_data)
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
input_data = data.to(device)
model.eval()
output = model(input_data)

return output

Expand Down
37 changes: 1 addition & 36 deletions test/resources/mnist/model_eia/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,4 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import logging
import os
import sys

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


def predict_fn(input_data, model):
logger.info('Performing EIA inference with Torch JIT context with input of size {}'.format(input_data.shape))
# With EI, client instance should be CPU for cost-efficiency. Subgraphs with unsupported arguments run locally. Server runs with CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mdoel = model.to(device)
input_data = input_data.to(device)
with torch.no_grad():
# Set the target device to the accelerator ordinal
with torch.jit.optimized_execution(True, {'target_device': 'eia:0'}):
return model(input_data)


def model_fn(model_dir):
logger.info('model_fn: Loading model with TorchScript from {}'.format(model_dir))
# Scripted model is serialized with torch.jit.save().
# No need to instantiate model definition then load state_dict
model = torch.jit.load('model.pth')
return model


def save_model(model, model_dir):
logger.info("Saving the model to {}.".format(model_dir))
path = os.path.join(model_dir, 'model.pth')
torch.jit.save(model, path)
# This file is intentionally left blank to utilize default_model_fn and default_predict_fn