-
Notifications
You must be signed in to change notification settings - Fork 72
Migrate base inference toolkit scripts and unit tests #157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the 'License'). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the 'license' file accompanying this file. This file is | ||
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
"""This module contains constants that define MIME content types.""" | ||
JSON = "application/json" | ||
CSV = "text/csv" | ||
OCTET_STREAM = "application/octet-stream" | ||
ANY = "*/*" | ||
NPY = "application/x-npy" | ||
NPZ = "application/x-npz" | ||
UTF8_TYPES = [JSON, CSV] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the 'License'). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the 'license' file accompanying this file. This file is | ||
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains functionality for converting various types of | ||
files and objects to NumPy arrays.""" | ||
from __future__ import absolute_import | ||
|
||
import json | ||
|
||
import numpy as np | ||
import scipy.sparse | ||
from six import BytesIO, StringIO | ||
|
||
from sagemaker_inference import content_types, errors | ||
|
||
|
||
def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.array | ||
"""Convert a JSON object to a numpy array. | ||
|
||
Args: | ||
string_like (str): JSON string. | ||
dtype (dtype, optional): Data type of the resulting array. | ||
If None, the dtypes will be determined by the contents | ||
of each column, individually. This argument can only be | ||
used to 'upcast' the array. For downcasting, use the | ||
.astype(t) method. | ||
|
||
Returns: | ||
(np.array): numpy array | ||
""" | ||
data = json.loads(string_like) | ||
return np.array(data, dtype=dtype) | ||
|
||
|
||
def _csv_to_numpy(string_like, dtype=None): # type: (str) -> np.array | ||
"""Convert a CSV object to a numpy array. | ||
|
||
Args: | ||
string_like (str): CSV string. | ||
dtype (dtype, optional): Data type of the resulting array. If None, | ||
the dtypes will be determined by the contents of each column, | ||
individually. This argument can only be used to 'upcast' the array. | ||
For downcasting, use the .astype(t) method. | ||
|
||
Returns: | ||
(np.array): numpy array | ||
""" | ||
stream = StringIO(string_like) | ||
return np.genfromtxt(stream, dtype=dtype, delimiter=",") | ||
|
||
|
||
def _npy_to_numpy(npy_array): # type: (object) -> np.array | ||
"""Convert a NPY array into numpy. | ||
|
||
Args: | ||
npy_array (npy array): to be converted to numpy array | ||
|
||
Returns: | ||
(np.array): converted numpy array. | ||
""" | ||
stream = BytesIO(npy_array) | ||
return np.load(stream, allow_pickle=True) | ||
|
||
|
||
def _npz_to_sparse(npz_bytes): # type: (object) -> scipy.sparse.spmatrix | ||
"""Convert .npz-formatted data to a sparse matrix. | ||
|
||
Args: | ||
npz_bytes (object): Bytes encoding a sparse matrix in the .npz format. | ||
|
||
Returns: | ||
(scipy.sparse.spmatrix): A sparse matrix. | ||
""" | ||
buffer = BytesIO(npz_bytes) | ||
return scipy.sparse.load_npz(buffer) | ||
|
||
|
||
_decoder_map = { | ||
content_types.NPY: _npy_to_numpy, | ||
content_types.CSV: _csv_to_numpy, | ||
content_types.JSON: _json_to_numpy, | ||
content_types.NPZ: _npz_to_sparse, | ||
} | ||
|
||
|
||
def decode(obj, content_type): | ||
"""Decode an object that is encoded as one of the default content types. | ||
|
||
Args: | ||
obj (object): to be decoded. | ||
content_type (str): content type to be used. | ||
|
||
Returns: | ||
object: decoded object for prediction. | ||
""" | ||
try: | ||
decoder = _decoder_map[content_type] | ||
return decoder(obj) | ||
except KeyError: | ||
raise errors.UnsupportedFormatError(content_type) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the 'License'). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the 'license' file accompanying this file. This file is | ||
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains functionality for the default handler service.""" | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
from sagemaker_inference.transformer import Transformer | ||
|
||
PYTHON_PATH_ENV = "PYTHONPATH" | ||
|
||
|
||
class DefaultHandlerService(object): | ||
"""Default handler service that is executed by the model server. | ||
|
||
The handler service is responsible for defining an ``initialize`` and ``handle`` method. | ||
- The ``handle`` method is invoked for all incoming inference requests to the model server. | ||
- The ``initialize`` method is invoked at model server start up. | ||
|
||
Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md | ||
""" | ||
|
||
def __init__(self, transformer=None): | ||
self._service = transformer if transformer else Transformer() | ||
|
||
def handle(self, data, context): | ||
"""Handles an inference request with input data and makes a prediction. | ||
|
||
Args: | ||
data (obj): the request data. | ||
context (obj): metadata on the incoming request data. | ||
|
||
Returns: | ||
list[obj]: The return value from the Transformer.transform method, | ||
which is a serialized prediction result wrapped in a list if | ||
inference is successful. Otherwise returns an error message | ||
with the context set appropriately. | ||
|
||
""" | ||
return self._service.transform(data, context) | ||
|
||
def initialize(self, context): | ||
"""Calls the Transformer method that validates the user module against | ||
the SageMaker inference contract. | ||
""" | ||
properties = context.system_properties | ||
model_dir = properties.get("model_dir") | ||
|
||
# add model_dir/code to python path | ||
code_dir_path = "{}:".format(model_dir + "/code") | ||
if PYTHON_PATH_ENV in os.environ: | ||
os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] | ||
else: | ||
os.environ[PYTHON_PATH_ENV] = code_dir_path | ||
|
||
self._service.validate_and_initialize(model_dir=model_dir, context=context) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the 'License'). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the 'license' file accompanying this file. This file is | ||
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""This module contains the definition of the default inference handler, | ||
which provides a bare-bones implementation of default inference functions. | ||
""" | ||
from __future__ import absolute_import | ||
|
||
import textwrap | ||
|
||
from sagemaker_inference import decoder, encoder, errors, utils | ||
|
||
|
||
class DefaultInferenceHandler(object): | ||
"""Bare-bones implementation of default inference functions.""" | ||
|
||
def default_model_fn(self, model_dir, context=None): | ||
"""Function responsible for loading the model. | ||
|
||
Args: | ||
model_dir (str): The directory where model files are stored. | ||
context (obj): the request context (default: None). | ||
|
||
Returns: | ||
obj: the loaded model. | ||
|
||
""" | ||
raise NotImplementedError( | ||
textwrap.dedent( | ||
""" | ||
Please provide a model_fn implementation. | ||
See documentation for model_fn at https://sagemaker.readthedocs.io/en/stable/ | ||
""" | ||
) | ||
) | ||
|
||
def default_input_fn(self, input_data, content_type, context=None): | ||
# pylint: disable=unused-argument, no-self-use | ||
"""Function responsible for deserializing the input data into an object for prediction. | ||
|
||
Args: | ||
input_data (obj): the request data. | ||
content_type (str): the request content type. | ||
context (obj): the request context (default: None). | ||
|
||
Returns: | ||
obj: data ready for prediction. | ||
|
||
""" | ||
return decoder.decode(input_data, content_type) | ||
|
||
def default_predict_fn(self, data, model, context=None): | ||
"""Function responsible for model predictions. | ||
|
||
Args: | ||
model (obj): model loaded by the model_fn. | ||
data: deserialized data returned by the input_fn. | ||
context (obj): the request context (default: None). | ||
|
||
Returns: | ||
obj: prediction result. | ||
|
||
""" | ||
raise NotImplementedError( | ||
textwrap.dedent( | ||
""" | ||
Please provide a predict_fn implementation. | ||
See documentation for predict_fn at https://sagemaker.readthedocs.io/en/stable/ | ||
""" | ||
) | ||
) | ||
|
||
def default_output_fn(self, prediction, accept, context=None): # pylint: disable=no-self-use | ||
"""Function responsible for serializing the prediction result to the desired accept type. | ||
|
||
Args: | ||
prediction (obj): prediction result returned by the predict_fn. | ||
accept (str): accept header expected by the client. | ||
context (obj): the request context (default: None). | ||
|
||
Returns: | ||
obj: prediction data. | ||
|
||
""" | ||
for content_type in utils.parse_accept(accept): | ||
if content_type in encoder.SUPPORTED_CONTENT_TYPES: | ||
return encoder.encode(prediction, content_type), content_type | ||
raise errors.UnsupportedFormatError(accept) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this package? Is it for backwards compatibility with older python versions?
Same applies to following files in which this line is added.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This link offers details to the relevance of
from __future__ import absolute_import
:https://portingguide.readthedocs.io/en/latest/imports.html
Also, this line is required for the
flake8
checks.