Skip to content

Commit 1e8e043

Browse files
authored
Migrate base inference toolkit scripts and unit tests (#157)
* Merge base inference toolkit scripts and unit tests * Remove methods from environment.py
1 parent 8a7bf16 commit 1e8e043

29 files changed

+2731
-486
lines changed

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ def read(fname):
5252

5353
# We don't declare our dependency on torch here because we build with
5454
# different packages for different variants
55-
install_requires=['numpy==1.24.4', 'retrying==1.3.4', 'sagemaker-inference==1.10.0'],
55+
install_requires=['boto3==1.28.60', 'numpy==1.24.4', 'six==1.16.0',
56+
'retrying==1.3.4', 'scipy==1.10.1', 'psutil==5.9.5'],
5657
extras_require={
57-
'test': ['boto3==1.28.60', 'coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0',
58-
'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1', 'PyYAML==5.4.1',
59-
'sagemaker==2.125.0', 'six==1.16.0', 'requests==2.31.0',
58+
'test': ['coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0',
59+
'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1',
60+
'PyYAML==5.4.1', 'sagemaker==2.125.0', 'requests==2.31.0',
6061
'requests_mock==1.11.0', 'torch==2.1.0', 'torchvision==0.16.0', 'tox==4.11.3']
6162
},
6263

src/sagemaker_inference/__init__.py

Whitespace-only changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains constants that define MIME content types."""
# NOTE: the docstring must precede the __future__ import; otherwise it is not
# recognized as the module docstring (it was previously placed after it).
from __future__ import absolute_import

JSON = "application/json"
CSV = "text/csv"
OCTET_STREAM = "application/octet-stream"
ANY = "*/*"
NPY = "application/x-npy"
NPZ = "application/x-npz"

# Content types whose payloads are UTF-8 text rather than raw bytes.
UTF8_TYPES = [JSON, CSV]

src/sagemaker_inference/decoder.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functionality for converting various types of
14+
files and objects to NumPy arrays."""
15+
from __future__ import absolute_import
16+
17+
import json
18+
19+
import numpy as np
20+
import scipy.sparse
21+
from six import BytesIO, StringIO
22+
23+
from sagemaker_inference import content_types, errors
24+
25+
26+
def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.array
27+
"""Convert a JSON object to a numpy array.
28+
29+
Args:
30+
string_like (str): JSON string.
31+
dtype (dtype, optional): Data type of the resulting array.
32+
If None, the dtypes will be determined by the contents
33+
of each column, individually. This argument can only be
34+
used to 'upcast' the array. For downcasting, use the
35+
.astype(t) method.
36+
37+
Returns:
38+
(np.array): numpy array
39+
"""
40+
data = json.loads(string_like)
41+
return np.array(data, dtype=dtype)
42+
43+
44+
def _csv_to_numpy(string_like, dtype=None): # type: (str) -> np.array
45+
"""Convert a CSV object to a numpy array.
46+
47+
Args:
48+
string_like (str): CSV string.
49+
dtype (dtype, optional): Data type of the resulting array. If None,
50+
the dtypes will be determined by the contents of each column,
51+
individually. This argument can only be used to 'upcast' the array.
52+
For downcasting, use the .astype(t) method.
53+
54+
Returns:
55+
(np.array): numpy array
56+
"""
57+
stream = StringIO(string_like)
58+
return np.genfromtxt(stream, dtype=dtype, delimiter=",")
59+
60+
61+
def _npy_to_numpy(npy_array): # type: (object) -> np.array
62+
"""Convert a NPY array into numpy.
63+
64+
Args:
65+
npy_array (npy array): to be converted to numpy array
66+
67+
Returns:
68+
(np.array): converted numpy array.
69+
"""
70+
stream = BytesIO(npy_array)
71+
return np.load(stream, allow_pickle=True)
72+
73+
74+
def _npz_to_sparse(npz_bytes): # type: (object) -> scipy.sparse.spmatrix
75+
"""Convert .npz-formatted data to a sparse matrix.
76+
77+
Args:
78+
npz_bytes (object): Bytes encoding a sparse matrix in the .npz format.
79+
80+
Returns:
81+
(scipy.sparse.spmatrix): A sparse matrix.
82+
"""
83+
buffer = BytesIO(npz_bytes)
84+
return scipy.sparse.load_npz(buffer)
85+
86+
87+
# Registry mapping each supported MIME type to its deserialization routine.
_decoder_map = {
    content_types.NPY: _npy_to_numpy,
    content_types.CSV: _csv_to_numpy,
    content_types.JSON: _json_to_numpy,
    content_types.NPZ: _npz_to_sparse,
}


def decode(obj, content_type):
    """Decode an object that is encoded as one of the default content types.

    Args:
        obj (object): to be decoded.
        content_type (str): content type to be used.

    Returns:
        object: decoded object for prediction.

    Raises:
        errors.UnsupportedFormatError: if ``content_type`` has no registered
            decoder.
    """
    # Keep the try body limited to the lookup: a KeyError raised *inside* a
    # decoder must propagate as-is rather than being misreported as an
    # unsupported content type.
    try:
        decoder = _decoder_map[content_type]
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)
    return decoder(obj)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains functionality for the default handler service."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
18+
from sagemaker_inference.transformer import Transformer
19+
20+
PYTHON_PATH_ENV = "PYTHONPATH"


class DefaultHandlerService(object):
    """Default handler service that is executed by the model server.

    The handler service is responsible for defining an ``initialize`` and ``handle`` method.
    - The ``handle`` method is invoked for all incoming inference requests to the model server.
    - The ``initialize`` method is invoked at model server start up.

    Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md
    """

    def __init__(self, transformer=None):
        # Use the caller-supplied transformer, falling back to the default one.
        self._service = transformer or Transformer()

    def handle(self, data, context):
        """Forward an inference request to the underlying transformer.

        Args:
            data (obj): the request data.
            context (obj): metadata on the incoming request data.

        Returns:
            list[obj]: The return value from the Transformer.transform method,
                which is a serialized prediction result wrapped in a list if
                inference is successful. Otherwise returns an error message
                with the context set appropriately.
        """
        return self._service.transform(data, context)

    def initialize(self, context):
        """Validate the user module against the SageMaker inference contract.

        Also prepends ``<model_dir>/code`` to ``PYTHONPATH`` so user-provided
        modules are importable.
        """
        model_dir = context.system_properties.get("model_dir")

        # Prepend "<model_dir>/code:" to whatever PYTHONPATH already holds.
        prefix = "{}:".format(model_dir + "/code")
        os.environ[PYTHON_PATH_ENV] = prefix + os.environ.get(PYTHON_PATH_ENV, "")

        self._service.validate_and_initialize(model_dir=model_dir, context=context)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License'). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains the definition of the default inference handler,
14+
which provides a bare-bones implementation of default inference functions.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import textwrap
19+
20+
from sagemaker_inference import decoder, encoder, errors, utils
21+
22+
23+
class DefaultInferenceHandler(object):
    """Bare-bones implementation of default inference functions."""

    def default_model_fn(self, model_dir, context=None):
        """Load the model; the default raises, so users must override this.

        Args:
            model_dir (str): The directory where model files are stored.
            context (obj): the request context (default: None).

        Returns:
            obj: the loaded model.
        """
        raise NotImplementedError(
            textwrap.dedent(
                """
            Please provide a model_fn implementation.
            See documentation for model_fn at https://sagemaker.readthedocs.io/en/stable/
            """
            )
        )

    def default_input_fn(self, input_data, content_type, context=None):
        # pylint: disable=unused-argument, no-self-use
        """Deserialize the request payload into an object for prediction.

        Args:
            input_data (obj): the request data.
            content_type (str): the request content type.
            context (obj): the request context (default: None).

        Returns:
            obj: data ready for prediction.
        """
        return decoder.decode(input_data, content_type)

    def default_predict_fn(self, data, model, context=None):
        """Run the prediction; the default raises, so users must override this.

        Args:
            model (obj): model loaded by the model_fn.
            data: deserialized data returned by the input_fn.
            context (obj): the request context (default: None).

        Returns:
            obj: prediction result.
        """
        raise NotImplementedError(
            textwrap.dedent(
                """
            Please provide a predict_fn implementation.
            See documentation for predict_fn at https://sagemaker.readthedocs.io/en/stable/
            """
            )
        )

    def default_output_fn(self, prediction, accept, context=None):  # pylint: disable=no-self-use
        """Serialize the prediction result to the first supported accept type.

        Args:
            prediction (obj): prediction result returned by the predict_fn.
            accept (str): accept header expected by the client.
            context (obj): the request context (default: None).

        Returns:
            obj: prediction data.
        """
        chosen = next(
            (ct for ct in utils.parse_accept(accept) if ct in encoder.SUPPORTED_CONTENT_TYPES),
            None,
        )
        if chosen is None:
            raise errors.UnsupportedFormatError(accept)
        return encoder.encode(prediction, chosen), chosen

0 commit comments

Comments
 (0)