Skip to content

Migrate base inference toolkit scripts and unit tests #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def read(fname):

# We don't declare our dependency on torch here because we build with
# different packages for different variants
install_requires=['numpy==1.24.4', 'retrying==1.3.4', 'sagemaker-inference==1.10.0'],
install_requires=['boto3==1.28.60', 'numpy==1.24.4', 'six==1.16.0',
'retrying==1.3.4', 'scipy==1.10.1', 'psutil==5.9.5'],
extras_require={
'test': ['boto3==1.28.60', 'coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0',
'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1', 'PyYAML==5.4.1',
'sagemaker==2.125.0', 'six==1.16.0', 'requests==2.31.0',
'test': ['coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0',
'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1',
'PyYAML==5.4.1', 'sagemaker==2.125.0', 'requests==2.31.0',
'requests_mock==1.11.0', 'torch==2.1.0', 'torchvision==0.16.0', 'tox==4.11.3']
},

Expand Down
Empty file.
22 changes: 22 additions & 0 deletions src/sagemaker_inference/content_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this package? Is it for backwards compatibility with older python versions?
Same applies to following files in which this line is added.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This link explains why from __future__ import absolute_import is relevant:
https://portingguide.readthedocs.io/en/latest/imports.html

Also, this line is required for the flake8 checks.


"""This module contains constants that define MIME content types."""
JSON = "application/json"
CSV = "text/csv"
OCTET_STREAM = "application/octet-stream"
ANY = "*/*"
NPY = "application/x-npy"
NPZ = "application/x-npz"
UTF8_TYPES = [JSON, CSV]
109 changes: 109 additions & 0 deletions src/sagemaker_inference/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality for converting various types of
files and objects to NumPy arrays."""
from __future__ import absolute_import

import json

import numpy as np
import scipy.sparse
from six import BytesIO, StringIO

from sagemaker_inference import content_types, errors


def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.array
"""Convert a JSON object to a numpy array.

Args:
string_like (str): JSON string.
dtype (dtype, optional): Data type of the resulting array.
If None, the dtypes will be determined by the contents
of each column, individually. This argument can only be
used to 'upcast' the array. For downcasting, use the
.astype(t) method.

Returns:
(np.array): numpy array
"""
data = json.loads(string_like)
return np.array(data, dtype=dtype)


def _csv_to_numpy(string_like, dtype=None): # type: (str) -> np.array
"""Convert a CSV object to a numpy array.

Args:
string_like (str): CSV string.
dtype (dtype, optional): Data type of the resulting array. If None,
the dtypes will be determined by the contents of each column,
individually. This argument can only be used to 'upcast' the array.
For downcasting, use the .astype(t) method.

Returns:
(np.array): numpy array
"""
stream = StringIO(string_like)
return np.genfromtxt(stream, dtype=dtype, delimiter=",")


def _npy_to_numpy(npy_array): # type: (object) -> np.array
"""Convert a NPY array into numpy.

Args:
npy_array (npy array): to be converted to numpy array

Returns:
(np.array): converted numpy array.
"""
stream = BytesIO(npy_array)
return np.load(stream, allow_pickle=True)


def _npz_to_sparse(npz_bytes): # type: (object) -> scipy.sparse.spmatrix
"""Convert .npz-formatted data to a sparse matrix.

Args:
npz_bytes (object): Bytes encoding a sparse matrix in the .npz format.

Returns:
(scipy.sparse.spmatrix): A sparse matrix.
"""
buffer = BytesIO(npz_bytes)
return scipy.sparse.load_npz(buffer)


# Dispatch table mapping each supported content type to its decoder function.
_decoder_map = {}
_decoder_map[content_types.NPY] = _npy_to_numpy
_decoder_map[content_types.CSV] = _csv_to_numpy
_decoder_map[content_types.JSON] = _json_to_numpy
_decoder_map[content_types.NPZ] = _npz_to_sparse


def decode(obj, content_type):
    """Decode an object that is encoded as one of the default content types.

    Args:
        obj (object): to be decoded.
        content_type (str): content type to be used.

    Returns:
        object: decoded object for prediction.

    Raises:
        errors.UnsupportedFormatError: if ``content_type`` has no registered
            decoder.
    """
    # Keep only the dict lookup inside the try block: if the decoder itself
    # raises a KeyError while processing obj, it must propagate unchanged
    # rather than be misreported as an unsupported content type.
    try:
        decoder = _decoder_map[content_type]
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)
    return decoder(obj)
66 changes: 66 additions & 0 deletions src/sagemaker_inference/default_handler_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality for the default handler service."""
from __future__ import absolute_import

import os

from sagemaker_inference.transformer import Transformer

PYTHON_PATH_ENV = "PYTHONPATH"


class DefaultHandlerService(object):
    """Default handler service that is executed by the model server.

    The handler service is responsible for defining an ``initialize`` and ``handle`` method.
    - The ``handle`` method is invoked for all incoming inference requests to the model server.
    - The ``initialize`` method is invoked at model server start up.

    Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md
    """

    def __init__(self, transformer=None):
        # Fall back to a default Transformer when no custom one is supplied.
        self._service = transformer or Transformer()

    def handle(self, data, context):
        """Handle an inference request by delegating to the Transformer.

        Args:
            data (obj): the request data.
            context (obj): metadata on the incoming request data.

        Returns:
            list[obj]: The return value from the Transformer.transform method,
                which is a serialized prediction result wrapped in a list if
                inference is successful. Otherwise returns an error message
                with the context set appropriately.

        """
        return self._service.transform(data, context)

    def initialize(self, context):
        """Calls the Transformer method that validates the user module against
        the SageMaker inference contract.
        """
        model_dir = context.system_properties.get("model_dir")

        # Prepend model_dir/code to PYTHONPATH so user modules are importable.
        code_dir_path = "{}:".format(model_dir + "/code")
        existing_path = os.environ.get(PYTHON_PATH_ENV)
        if existing_path:
            os.environ[PYTHON_PATH_ENV] = code_dir_path + existing_path
        else:
            os.environ[PYTHON_PATH_ENV] = code_dir_path

        self._service.validate_and_initialize(model_dir=model_dir, context=context)
97 changes: 97 additions & 0 deletions src/sagemaker_inference/default_inference_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains the definition of the default inference handler,
which provides a bare-bones implementation of default inference functions.
"""
from __future__ import absolute_import

import textwrap

from sagemaker_inference import decoder, encoder, errors, utils


class DefaultInferenceHandler(object):
    """Bare-bones implementation of default inference functions."""

    def default_model_fn(self, model_dir, context=None):
        """Load the model from ``model_dir``.

        Args:
            model_dir (str): The directory where model files are stored.
            context (obj): the request context (default: None).

        Returns:
            obj: the loaded model.

        Raises:
            NotImplementedError: always; users must supply their own model_fn.
        """
        raise NotImplementedError(
            textwrap.dedent(
                """
            Please provide a model_fn implementation.
            See documentation for model_fn at https://sagemaker.readthedocs.io/en/stable/
            """
            )
        )

    def default_input_fn(self, input_data, content_type, context=None):
        # pylint: disable=unused-argument, no-self-use
        """Deserialize the request payload into an object ready for prediction.

        Args:
            input_data (obj): the request data.
            content_type (str): the request content type.
            context (obj): the request context (default: None).

        Returns:
            obj: data ready for prediction.

        """
        return decoder.decode(input_data, content_type)

    def default_predict_fn(self, data, model, context=None):
        """Run the model on deserialized input data.

        Args:
            model (obj): model loaded by the model_fn.
            data: deserialized data returned by the input_fn.
            context (obj): the request context (default: None).

        Returns:
            obj: prediction result.

        Raises:
            NotImplementedError: always; users must supply their own predict_fn.
        """
        raise NotImplementedError(
            textwrap.dedent(
                """
            Please provide a predict_fn implementation.
            See documentation for predict_fn at https://sagemaker.readthedocs.io/en/stable/
            """
            )
        )

    def default_output_fn(self, prediction, accept, context=None):  # pylint: disable=no-self-use
        """Serialize the prediction result into the first supported accept type.

        Args:
            prediction (obj): prediction result returned by the predict_fn.
            accept (str): accept header expected by the client.
            context (obj): the request context (default: None).

        Returns:
            obj: prediction data.

        """
        chosen = next(
            (ct for ct in utils.parse_accept(accept) if ct in encoder.SUPPORTED_CONTENT_TYPES),
            None,
        )
        if chosen is None:
            raise errors.UnsupportedFormatError(accept)
        return encoder.encode(prediction, chosen), chosen
Loading