diff --git a/setup.py b/setup.py index 2358668..83fd58e 100644 --- a/setup.py +++ b/setup.py @@ -52,11 +52,12 @@ def read(fname): # We don't declare our dependency on torch here because we build with # different packages for different variants - install_requires=['numpy==1.24.4', 'retrying==1.3.4', 'sagemaker-inference==1.10.0'], + install_requires=['boto3==1.28.60', 'numpy==1.24.4', 'six==1.16.0', + 'retrying==1.3.4', 'scipy==1.10.1', 'psutil==5.9.5'], extras_require={ - 'test': ['boto3==1.28.60', 'coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0', - 'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1', 'PyYAML==5.4.1', - 'sagemaker==2.125.0', 'six==1.16.0', 'requests==2.31.0', + 'test': ['coverage==7.3.2', 'docker-compose==1.29.2', 'flake8==6.1.0', 'Flask==3.0.0', + 'mock==5.1.0', 'pytest==7.4.2', 'pytest-cov==4.1.0', 'pytest-xdist==3.3.1', + 'PyYAML==5.4.1', 'sagemaker==2.125.0', 'requests==2.31.0', 'requests_mock==1.11.0', 'torch==2.1.0', 'torchvision==0.16.0', 'tox==4.11.3'] }, diff --git a/src/sagemaker_inference/__init__.py b/src/sagemaker_inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/sagemaker_inference/content_types.py b/src/sagemaker_inference/content_types.py new file mode 100644 index 0000000..905817a --- /dev/null +++ b/src/sagemaker_inference/content_types.py @@ -0,0 +1,22 @@ +# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +"""This module contains constants that define MIME content types.""" +JSON = "application/json" +CSV = "text/csv" +OCTET_STREAM = "application/octet-stream" +ANY = "*/*" +NPY = "application/x-npy" +NPZ = "application/x-npz" +UTF8_TYPES = [JSON, CSV] diff --git a/src/sagemaker_inference/decoder.py b/src/sagemaker_inference/decoder.py new file mode 100644 index 0000000..3fb8b67 --- /dev/null +++ b/src/sagemaker_inference/decoder.py @@ -0,0 +1,109 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality for converting various types of +files and objects to NumPy arrays.""" +from __future__ import absolute_import + +import json + +import numpy as np +import scipy.sparse +from six import BytesIO, StringIO + +from sagemaker_inference import content_types, errors + + +def _json_to_numpy(string_like, dtype=None): # type: (str) -> np.array + """Convert a JSON object to a numpy array. + + Args: + string_like (str): JSON string. 
+ dtype (dtype, optional): Data type of the resulting array. + If None, the dtypes will be determined by the contents + of each column, individually. This argument can only be + used to 'upcast' the array. For downcasting, use the + .astype(t) method. + + Returns: + (np.array): numpy array + """ + data = json.loads(string_like) + return np.array(data, dtype=dtype) + + +def _csv_to_numpy(string_like, dtype=None): # type: (str) -> np.array + """Convert a CSV object to a numpy array. + + Args: + string_like (str): CSV string. + dtype (dtype, optional): Data type of the resulting array. If None, + the dtypes will be determined by the contents of each column, + individually. This argument can only be used to 'upcast' the array. + For downcasting, use the .astype(t) method. + + Returns: + (np.array): numpy array + """ + stream = StringIO(string_like) + return np.genfromtxt(stream, dtype=dtype, delimiter=",") + + +def _npy_to_numpy(npy_array): # type: (object) -> np.array + """Convert a NPY array into numpy. + + Args: + npy_array (npy array): to be converted to numpy array + + Returns: + (np.array): converted numpy array. + """ + stream = BytesIO(npy_array) + return np.load(stream, allow_pickle=True) + + +def _npz_to_sparse(npz_bytes): # type: (object) -> scipy.sparse.spmatrix + """Convert .npz-formatted data to a sparse matrix. + + Args: + npz_bytes (object): Bytes encoding a sparse matrix in the .npz format. + + Returns: + (scipy.sparse.spmatrix): A sparse matrix. + """ + buffer = BytesIO(npz_bytes) + return scipy.sparse.load_npz(buffer) + + +_decoder_map = { + content_types.NPY: _npy_to_numpy, + content_types.CSV: _csv_to_numpy, + content_types.JSON: _json_to_numpy, + content_types.NPZ: _npz_to_sparse, +} + + +def decode(obj, content_type): + """Decode an object that is encoded as one of the default content types. + + Args: + obj (object): to be decoded. + content_type (str): content type to be used. + + Returns: + object: decoded object for prediction. + """ + try: + decoder = _decoder_map[content_type] + return decoder(obj) + except KeyError: + raise errors.UnsupportedFormatError(content_type) diff --git a/src/sagemaker_inference/default_handler_service.py b/src/sagemaker_inference/default_handler_service.py new file mode 100644 index 0000000..cc6e390 --- /dev/null +++ b/src/sagemaker_inference/default_handler_service.py @@ -0,0 +1,66 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality for the default handler service.""" +from __future__ import absolute_import + +import os + +from sagemaker_inference.transformer import Transformer + +PYTHON_PATH_ENV = "PYTHONPATH" + + +class DefaultHandlerService(object): + """Default handler service that is executed by the model server. + + The handler service is responsible for defining an ``initialize`` and ``handle`` method. + - The ``handle`` method is invoked for all incoming inference requests to the model server. + - The ``initialize`` method is invoked at model server start up. 
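A quick usage sketch for the decode() entry point shown above (the payload literals are invented for illustration):

import numpy as np
from sagemaker_inference import content_types, decoder

json_payload = "[[1.0, 2.0], [3.0, 4.0]]"
csv_payload = "1.0,2.0\n3.0,4.0"
# Both payloads describe the same 2x2 matrix, so the decoded arrays match.
assert np.array_equal(
    decoder.decode(json_payload, content_types.JSON),
    decoder.decode(csv_payload, content_types.CSV),
)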
+ + Implementation of: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md + """ + + def __init__(self, transformer=None): + self._service = transformer if transformer else Transformer() + + def handle(self, data, context): + """Handles an inference request with input data and makes a prediction. + + Args: + data (obj): the request data. + context (obj): metadata on the incoming request data. + + Returns: + list[obj]: The return value from the Transformer.transform method, + which is a serialized prediction result wrapped in a list if + inference is successful. Otherwise returns an error message + with the context set appropriately. + + """ + return self._service.transform(data, context) + + def initialize(self, context): + """Calls the Transformer method that validates the user module against + the SageMaker inference contract. + """ + properties = context.system_properties + model_dir = properties.get("model_dir") + + # add model_dir/code to python path + code_dir_path = "{}:".format(model_dir + "/code") + if PYTHON_PATH_ENV in os.environ: + os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] + else: + os.environ[PYTHON_PATH_ENV] = code_dir_path + + self._service.validate_and_initialize(model_dir=model_dir, context=context) diff --git a/src/sagemaker_inference/default_inference_handler.py b/src/sagemaker_inference/default_inference_handler.py new file mode 100644 index 0000000..897b10f --- /dev/null +++ b/src/sagemaker_inference/default_inference_handler.py @@ -0,0 +1,97 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains the definition of the default inference handler, +which provides a bare-bones implementation of default inference functions. +""" +from __future__ import absolute_import + +import textwrap + +from sagemaker_inference import decoder, encoder, errors, utils + + +class DefaultInferenceHandler(object): + """Bare-bones implementation of default inference functions.""" + + def default_model_fn(self, model_dir, context=None): + """Function responsible for loading the model. + + Args: + model_dir (str): The directory where model files are stored. + context (obj): the request context (default: None). + + Returns: + obj: the loaded model. + + """ + raise NotImplementedError( + textwrap.dedent( + """ + Please provide a model_fn implementation. + See documentation for model_fn at https://sagemaker.readthedocs.io/en/stable/ + """ + ) + ) + + def default_input_fn(self, input_data, content_type, context=None): + # pylint: disable=unused-argument, no-self-use + """Function responsible for deserializing the input data into an object for prediction. + + Args: + input_data (obj): the request data. + content_type (str): the request content type. + context (obj): the request context (default: None). + + Returns: + obj: data ready for prediction. 
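For context, a user-provided module that these default handlers fall back from might look roughly like the sketch below; the weights.json artifact and its layout are assumptions for illustration, not part of this change:

# inference.py, hypothetical module selected via the SAGEMAKER_PROGRAM env var.
import json
import os


def model_fn(model_dir, context=None):
    # Load whatever artifact training produced; a JSON file of coefficients is assumed here.
    with open(os.path.join(model_dir, "weights.json")) as f:
        return json.load(f)


def predict_fn(data, model, context=None):
    # Toy prediction: scale each input by a stored coefficient.
    return [model["coefficient"] * x for x in data]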
+ + """ + return decoder.decode(input_data, content_type) + + def default_predict_fn(self, data, model, context=None): + """Function responsible for model predictions. + + Args: + model (obj): model loaded by the model_fn. + data: deserialized data returned by the input_fn. + context (obj): the request context (default: None). + + Returns: + obj: prediction result. + + """ + raise NotImplementedError( + textwrap.dedent( + """ + Please provide a predict_fn implementation. + See documentation for predict_fn at https://sagemaker.readthedocs.io/en/stable/ + """ + ) + ) + + def default_output_fn(self, prediction, accept, context=None): # pylint: disable=no-self-use + """Function responsible for serializing the prediction result to the desired accept type. + + Args: + prediction (obj): prediction result returned by the predict_fn. + accept (str): accept header expected by the client. + context (obj): the request context (default: None). + + Returns: + obj: prediction data. + + """ + for content_type in utils.parse_accept(accept): + if content_type in encoder.SUPPORTED_CONTENT_TYPES: + return encoder.encode(prediction, content_type), content_type + raise errors.UnsupportedFormatError(accept) diff --git a/src/sagemaker_inference/encoder.py b/src/sagemaker_inference/encoder.py new file mode 100644 index 0000000..fdf38a0 --- /dev/null +++ b/src/sagemaker_inference/encoder.py @@ -0,0 +1,110 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality for converting array-like objects +to various types of objects and files.""" +from __future__ import absolute_import + +import json + +import numpy as np +from six import BytesIO, StringIO + +from sagemaker_inference import content_types, errors + + +def _array_to_json(array_like): + """Convert an array-like object to JSON. + + To understand better what an array-like object is see: + https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays + + Args: + array_like (np.array or Iterable or int or float): array-like object + to be converted to JSON. + + Returns: + (str): object serialized to JSON + """ + + def default(_array_like): + if hasattr(_array_like, "tolist"): + return _array_like.tolist() + return json.JSONEncoder().default(_array_like) + + return json.dumps(array_like, default=default) + + +def _array_to_npy(array_like): + """Convert an array-like object to the NPY format. + + To understand better what an array-like object is see: + https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays + + Args: + array_like (np.array or Iterable or int or float): array-like object + to be converted to NPY. + + Returns: + (obj): NPY array. + """ + buffer = BytesIO() + np.save(buffer, array_like) + return buffer.getvalue() + + +def _array_to_csv(array_like): + """Convert an array-like object to CSV. 
+ + To understand better what an array-like object is see: + https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays + + Args: + array_like (np.array or Iterable or int or float): array-like object + to be converted to CSV. + + Returns: + (str): object serialized to CSV + """ + stream = StringIO() + np.savetxt(stream, array_like, delimiter=",", fmt="%s") + return stream.getvalue() + + +_encoder_map = { + content_types.NPY: _array_to_npy, + content_types.CSV: _array_to_csv, + content_types.JSON: _array_to_json, +} + + +SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys()) + + +def encode(array_like, content_type): + """Encode an array-like object in a specific content_type to a numpy array. + + To understand better what an array-like object is see: + https://docs.scipy.org/doc/numpy/user/basics.creation.html#converting-python-array-like-objects-to-numpy-arrays + + Args: + array_like (np.array or Iterable or int or float): to be converted to numpy. + content_type (str): content type to be used. + + Returns: + (np.array): object converted as numpy array. + """ + try: + encoder = _encoder_map[content_type] + return encoder(array_like) + except KeyError: + raise errors.UnsupportedFormatError(content_type) diff --git a/src/sagemaker_inference/environment.py b/src/sagemaker_inference/environment.py new file mode 100644 index 0000000..8e22675 --- /dev/null +++ b/src/sagemaker_inference/environment.py @@ -0,0 +1,133 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality that provides access to system +characteristics, environment variables and configuration settings. +""" +from __future__ import absolute_import + +import os +from typing import Optional + +from sagemaker_inference import content_types, parameters + +DEFAULT_MODULE_NAME = "inference.py" +DEFAULT_MODEL_SERVER_TIMEOUT = "60" +DEFAULT_HTTP_PORT = "8080" + +SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str + +base_dir = os.environ.get(parameters.BASE_PATH_ENV, SAGEMAKER_BASE_PATH) # type: str + +if os.environ.get(parameters.MULTI_MODEL_ENV) == "true": + model_dir = os.path.join(base_dir, "models") # type: str +else: + model_dir = os.path.join(base_dir, "model") # type: str +# str: the directory where models should be saved, e.g., /opt/ml/model/ + +code_dir = os.path.join(model_dir, "code") # type: str +"""str: the path of the user's code directory, e.g., /opt/ml/model/code/""" + + +class Environment(object): + """Provides access to aspects of the serving environment relevant to serving containers, + including system characteristics, environment variables and configuration settings. + + The Environment is a read-only snapshot of the container environment. + It is a dictionary-like object, allowing any builtin function that works with dictionary. + + Attributes: + module_name (str): The name of the user-provided module. Default is inference.py. + model_server_timeout (int): Timeout for the model server. Default is 60. 
+ model_server_workers (str): Number of worker processes the model server will use. + + default_accept (str): The desired default MIME type of the inference in the response + as specified in the user-supplied SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT environment + variable. Otherwise, returns 'application/json' by default. + For example: application/json + http_port (str): Port that SageMaker will use to handle invocations and pings against the + running Docker container. Default is 8080. For example: 8080 + safe_port_range (str): HTTP port range that can be used by customers to avoid collisions + with the HTTP port specified by SageMaker for handling pings and invocations. + For example: 1111-2222 + + """ + + def __init__(self): + self._module_name = os.environ.get(parameters.USER_PROGRAM_ENV, DEFAULT_MODULE_NAME) + self._model_server_timeout = int( + os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT) + ) + + self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV) + + self._default_accept = os.environ.get( + parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON + ) + self._inference_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) + self._management_http_port = os.environ.get(parameters.BIND_TO_PORT_ENV, DEFAULT_HTTP_PORT) + self._safe_port_range = os.environ.get(parameters.SAFE_PORT_RANGE_ENV) + + @staticmethod + def _parse_module_name(program_param): + """Given a module name or a script name, return the module name. + + Args: + program_param (str): Module or script name. + + Returns: + str: Module name. + + """ + if program_param and program_param.endswith(".py"): + return program_param[:-3] + return program_param + + @property + def module_name(self): # type: () -> str + """str: Name of the user-provided module.""" + return self._parse_module_name(self._module_name) + + @property + def model_server_timeout(self) -> int: + """int: Timeout used for model server's backend workers before they are + deemed unresponsive and rebooted. + + """ + return self._model_server_timeout + + @property + def model_server_workers(self) -> Optional[str]: + """str: Number of worker processes the model server is configured to use.""" + return self._model_server_workers + + @property + def default_accept(self) -> str: + """str: The desired default MIME type of the inference in the response.""" + return self._default_accept + + @property + def inference_http_port(self) -> str: + """str: HTTP port that SageMaker uses to handle invocations and pings.""" + return self._inference_http_port + + @property + def management_http_port(self) -> str: + """str: HTTP port that SageMaker uses to handle model management requests.""" + return self._management_http_port + + @property + def safe_port_range(self) -> Optional[str]: + """str: HTTP port range that can be used by users to avoid collisions with the HTTP port + specified by SageMaker for handling pings and invocations. + """ + return self._safe_port_range diff --git a/src/sagemaker_inference/errors.py b/src/sagemaker_inference/errors.py new file mode 100644 index 0000000..ae0f573 --- /dev/null +++ b/src/sagemaker_inference/errors.py @@ -0,0 +1,71 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. 
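All of the Environment attributes above are driven by environment variables (named in parameters.py further down); a minimal sketch:

import os

os.environ["SAGEMAKER_MODEL_SERVER_TIMEOUT"] = "120"
os.environ["SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"] = "text/csv"

from sagemaker_inference.environment import Environment

env = Environment()
assert env.model_server_timeout == 120      # parsed to int
assert env.default_accept == "text/csv"     # falls back to application/json when unset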
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains custom exceptions.""" +from __future__ import absolute_import + +import textwrap + + +class UnsupportedFormatError(Exception): + """Exception used to indicate that an unsupported content type was provided.""" + + def __init__(self, content_type, **kwargs): + self._message = textwrap.dedent( + """Content type %s is not supported by this framework. + + Please implement input_fn to to deserialize the request data or an output_fn to + serialize the response. For more information, see the SageMaker Python SDK README.""" + % content_type + ) + super(UnsupportedFormatError, self).__init__(self._message, **kwargs) + + +class BaseInferenceToolkitError(Exception): + """Exception used to indicate a problem that occurred during inference. + + This is meant to be extended from so that customers may handle errors + within inference servers. + """ + + def __init__(self, status_code, message, phrase): + """Initializes an instance of BaseInferenceToolkitError. + + Args: + status_code (int): HTTP Error Status Code to send to client + message (str): Response message to send to client + phrase (str): Response body to send to client + """ + self.status_code = status_code + self.message = message + self.phrase = phrase + super(BaseInferenceToolkitError, self).__init__(status_code, message, phrase) + + +class GenericInferenceToolkitError(BaseInferenceToolkitError): + """Exception used to indicate a problem that occurred during inference. + + This is meant to be a generic implementation of the BaseInferenceToolkitError + for re-raising unexpected exceptions in a way that can be sent back to the client. + """ + + def __init__(self, status_code, message=None, phrase=None): + """Initializes an instance of GenericInferenceToolkitError. + + Args: + status_code (int): HTTP Error Status Code to send to client + message (str): Response message to send to client + phrase (str): Response body to send to client + """ + message = message or "Invalid Request" + phrase = phrase or message + super(GenericInferenceToolkitError, self).__init__(status_code, message, phrase) diff --git a/src/sagemaker_inference/logging.py b/src/sagemaker_inference/logging.py new file mode 100644 index 0000000..702e575 --- /dev/null +++ b/src/sagemaker_inference/logging.py @@ -0,0 +1,36 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains logging functionality.""" +from __future__ import absolute_import + +import logging + + +def configure_logger(): + """Add a handler to the library's logger with a formatter that + includes a timestamp along with the message. 
+ """ + handler = logging.StreamHandler() + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + get_logger().addHandler(handler) + + +def get_logger(): + """Return a logger with the name "sagemaker-inference", + creating it if necessary. + + Returns: + logging.Logger: Instance of the Logger class. + """ + return logging.getLogger("sagemaker-inference") diff --git a/src/sagemaker_inference/model_server.py b/src/sagemaker_inference/model_server.py new file mode 100644 index 0000000..1b29ee2 --- /dev/null +++ b/src/sagemaker_inference/model_server.py @@ -0,0 +1,94 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality to configure and start the +multi-model server.""" +from __future__ import absolute_import + +import os +import re +import subprocess +import sys + +import boto3 + +from sagemaker_inference import logging +from sagemaker_inference.environment import code_dir + +logging.configure_logger() +logger = logging.get_logger() + +REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt") + + +def _install_requirements(): + logger.info("installing packages from requirements.txt...") + pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH] + if os.getenv("CA_REPOSITORY_ARN"): + index = _get_codeartifact_index() + pip_install_cmd.append("-i") + pip_install_cmd.append(index) + try: + subprocess.check_call(pip_install_cmd) + except subprocess.CalledProcessError: + logger.error("failed to install required packages, exiting") + raise ValueError("failed to install required packages") + + +def _get_codeartifact_index(): + """ + Build the authenticated codeartifact index url + https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html + https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + :return: authenticated codeartifact index url + """ + repository_arn = os.getenv("CA_REPOSITORY_ARN") + arn_regex = ( + "arn:(?P[^:]+):codeartifact:(?P[^:]+):(?P[^:]+)" + ":repository/(?P[^/]+)/(?P.+)" + ) + m = re.match(arn_regex, repository_arn) + if not m: + raise Exception("invalid CodeArtifact repository arn {}".format(repository_arn)) + domain = m.group("domain") + owner = m.group("account") + repository = m.group("repository") + region = m.group("region") + + logger.info( + "configuring pip to use codeartifact " + "(domain: %s, domain owner: %s, repository: %s, region: %s)", + domain, + owner, + repository, + region, + ) + try: + client = boto3.client("codeartifact", region_name=region) + auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner) + token = auth_token_response["authorizationToken"] + endpoint_response = client.get_repository_endpoint( + domain=domain, domainOwner=owner, repository=repository, format="pypi" + ) + unauthenticated_index = endpoint_response["repositoryEndpoint"] + return re.sub( + "https://", + 
"https://aws:{}@".format(token), + re.sub( + "{}/?$".format(repository), + "{}/simple/".format(repository), + unauthenticated_index, + ), + ) + except Exception: + logger.error("failed to configure pip to use codeartifact") + raise Exception("failed to configure pip to use codeartifact") diff --git a/src/sagemaker_inference/parameters.py b/src/sagemaker_inference/parameters.py new file mode 100644 index 0000000..0f7fa23 --- /dev/null +++ b/src/sagemaker_inference/parameters.py @@ -0,0 +1,25 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains string constants that define inference toolkit +parameters.""" +from __future__ import absolute_import + +BASE_PATH_ENV = "SAGEMAKER_BASE_DIR" # type: str +USER_PROGRAM_ENV = "SAGEMAKER_PROGRAM" # type: str +LOG_LEVEL_ENV = "SAGEMAKER_CONTAINER_LOG_LEVEL" # type: str +DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str +MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str +MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str +BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str +SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str +MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str diff --git a/src/sagemaker_inference/transformer.py b/src/sagemaker_inference/transformer.py new file mode 100644 index 0000000..ed4fab5 --- /dev/null +++ b/src/sagemaker_inference/transformer.py @@ -0,0 +1,284 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains functionality for the Transformer class, +which represents the execution workflow for handling inference +requests. +""" +from __future__ import absolute_import + +import importlib +import traceback + +try: + from inspect import signature # pylint: disable=ungrouped-imports +except ImportError: + # for Python2.7 + import subprocess + import sys + + subprocess.check_call([sys.executable, "-m", "pip", "install", "inspect2"]) + from inspect2 import signature + +try: + from importlib.util import find_spec # pylint: disable=ungrouped-imports +except ImportError: + import imp # noqa: F401 + + def find_spec(module_name): + """Function that searches for a module. + + Args: + module_name: The name of the module to search for. + + Returns: + bool: Whether the module was found. 
+ """ + try: + imp.find_module(module_name) + return True + except ImportError: + return None + + +from six.moves import http_client + +from sagemaker_inference import content_types, environment, utils +from sagemaker_inference.default_inference_handler import DefaultInferenceHandler +from sagemaker_inference.errors import BaseInferenceToolkitError, GenericInferenceToolkitError + + +class Transformer(object): + """Represents the execution workflow for handling inference requests + sent to the model server. + """ + + def __init__(self, default_inference_handler=None): + """Initialize a ``Transformer``. + + Args: + default_inference_handler (DefaultInferenceHandler): default implementation of + inference handlers to use in absence of expected serving functions within + the user module. Defaults to ``DefaultInferenceHandler``. + + """ + self._default_inference_handler = default_inference_handler or DefaultInferenceHandler() + self._initialized = False + self._environment = None + self._model = None + + self._pre_model_fn = None + self._model_warmup_fn = None + self._model_fn = None + self._transform_fn = None + self._input_fn = None + self._predict_fn = None + self._output_fn = None + self._context = None + + @staticmethod + def handle_error(context, inference_exception, trace): + """Set context appropriately for error response. + + Args: + context (mms.context.Context): The inference context. + inference_exception (sagemaker_inference.errors.BaseInferenceToolkitError): An exception + raised during inference, with information for the error response. + trace (traceback): The stacktrace of the error. + + Returns: + str: The error message and stacktrace from the exception. + """ + context.set_response_status( + code=inference_exception.status_code, + phrase=utils.remove_crlf(inference_exception.phrase), + ) + return ["{}\n{}".format(inference_exception.message, trace)] + + def transform(self, data, context): + """Take a request with input data, deserialize it, make a prediction, and return a + serialized response. + + Args: + data (obj): the request data. + context (obj): metadata on the incoming request data. + + Returns: + list[obj]: The serialized prediction result wrapped in a list if + inference is successful. Otherwise returns an error message + with the context set appropriately. 
+ """ + try: + properties = context.system_properties + model_dir = properties.get("model_dir") + self.validate_and_initialize(model_dir=model_dir, context=context) + + response_list = [] + + for i in range(len(data)): + input_data = data[i].get("body") + + request_processor = context.request_processor[0] + + request_property = request_processor.get_request_properties() + content_type = utils.retrieve_content_type_header(request_property) + accept = request_property.get("Accept") or request_property.get("accept") + + if not accept or accept == content_types.ANY: + accept = self._environment.default_accept + + if content_type in content_types.UTF8_TYPES: + input_data = input_data.decode("utf-8") + + result = self._run_handler_function( + self._transform_fn, *(self._model, input_data, content_type, accept) + ) + + response = result + response_content_type = accept + + if isinstance(result, tuple): + # handles tuple for backwards compatibility + response = result[0] + response_content_type = result[1] + + context.set_response_content_type(0, response_content_type) + + response_list.append(response) + + return response_list + except Exception as e: # pylint: disable=broad-except + trace = traceback.format_exc() + if isinstance(e, BaseInferenceToolkitError): + return self.handle_error(context, e, trace) + else: + return self.handle_error( + context, + GenericInferenceToolkitError(http_client.INTERNAL_SERVER_ERROR, str(e)), + trace, + ) + + def validate_and_initialize(self, model_dir=environment.model_dir, context=None): + """Validates the user module against the SageMaker inference contract. + + Load the model as defined by the ``model_fn`` to prepare handling predictions. + + """ + if not self._initialized: + self._context = context + self._environment = environment.Environment() + self._validate_user_module_and_set_functions() + + if self._pre_model_fn is not None: + self._run_handler_function(self._pre_model_fn, *(model_dir,)) + + self._model = self._run_handler_function(self._model_fn, *(model_dir,)) + + if self._model_warmup_fn is not None: + self._run_handler_function(self._model_warmup_fn, *(model_dir, self._model)) + + self._initialized = True + + def _validate_user_module_and_set_functions(self): + """Retrieves and validates the inference handlers provided within the user module. + + Default implementations of the inference handlers are utilized in + place of missing functions defined in the user module. 
+ + """ + user_module_name = self._environment.module_name + + self._pre_model_fn = getattr(self._default_inference_handler, "default_pre_model_fn", None) + self._model_warmup_fn = getattr( + self._default_inference_handler, "default_model_warmup_fn", None + ) + + if find_spec(user_module_name) is not None: + user_module = importlib.import_module(user_module_name) + + self._model_fn = getattr( + user_module, "model_fn", self._default_inference_handler.default_model_fn + ) + + transform_fn = getattr(user_module, "transform_fn", None) + input_fn = getattr(user_module, "input_fn", None) + predict_fn = getattr(user_module, "predict_fn", None) + output_fn = getattr(user_module, "output_fn", None) + pre_model_fn = getattr(user_module, "pre_model_fn", None) + model_warmup_fn = getattr(user_module, "model_warmup_fn", None) + + if transform_fn and (input_fn or predict_fn or output_fn): + raise ValueError( + "Cannot use transform_fn implementation in conjunction with " + "input_fn, predict_fn, and/or output_fn implementation" + ) + + self._transform_fn = transform_fn or self._default_transform_fn + self._input_fn = input_fn or self._default_inference_handler.default_input_fn + self._predict_fn = predict_fn or self._default_inference_handler.default_predict_fn + self._output_fn = output_fn or self._default_inference_handler.default_output_fn + if pre_model_fn is not None: + self._pre_model_fn = pre_model_fn + if model_warmup_fn is not None: + self._model_warmup_fn = model_warmup_fn + else: + self._model_fn = self._default_inference_handler.default_model_fn + self._input_fn = self._default_inference_handler.default_input_fn + self._predict_fn = self._default_inference_handler.default_predict_fn + self._output_fn = self._default_inference_handler.default_output_fn + + self._transform_fn = self._default_transform_fn + + def _default_transform_fn(self, model, input_data, content_type, accept, context=None): + # pylint: disable=unused-argument + """Make predictions against the model and return a serialized response. + This serves as the default implementation of transform_fn, used when the + user has not provided an implementation. + + Args: + model (obj): model loaded by model_fn. + input_data (obj): the request data. + content_type (str): the request content type. + accept (str): accept header expected by the client. + context (obj): the request context (default: None). + + Returns: + obj: the serialized prediction result or a tuple of the form + (response_data, content_type) + + """ + data = self._run_handler_function(self._input_fn, *(input_data, content_type)) + prediction = self._run_handler_function(self._predict_fn, *(data, model)) + result = self._run_handler_function(self._output_fn, *(prediction, accept)) + return result + + def _run_handler_function(self, func, *argv): + """Helper to call the handler function which covers 2 cases: + 1. the handle function takes context + 2. 
the handle function does not take context + """ + num_func_input = len(signature(func).parameters) + if num_func_input == len(argv): + # function does not take context + result = func(*argv) + elif num_func_input == len(argv) + 1: + # function takes context + argv_context = argv + (self._context,) + result = func(*argv_context) + else: + raise TypeError( + "{} takes {} arguments but {} were given.".format( + func.__name__, num_func_input, len(argv) + ) + ) + + return result diff --git a/src/sagemaker_inference/utils.py b/src/sagemaker_inference/utils.py new file mode 100644 index 0000000..3f1cb1f --- /dev/null +++ b/src/sagemaker_inference/utils.py @@ -0,0 +1,102 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utility functions related to reading files, +writing files, and retrieving information from requests. +""" +from __future__ import absolute_import + +import re + +CONTENT_TYPE_REGEX = re.compile("^[Cc]ontent-?[Tt]ype") + + +def read_file(path, mode="r"): + """Read data from a file. + + Args: + path (str): path to the file. + mode (str): mode which the file will be open. + + Returns: + (str): contents of the file. + + """ + with open(path, mode) as f: + return f.read() + + +def write_file(path, data, mode="w"): # type: (str, str, str) -> None + """Write data to a file. + + Args: + path (str): path to the file. + data (str): data to be written to the file. + mode (str): mode which the file will be open. + + """ + with open(path, mode) as f: + f.write(data) + + +def retrieve_content_type_header(request_property): + """Retrieve Content-Type header from incoming request. + + This function handles multiple spellings of Content-Type based on the presence of + the dash and initial capitalization in each respective word. + + Args: + request_property (dict): incoming request metadata + + Returns: + (str): the request content type. + + """ + for key in request_property: + if CONTENT_TYPE_REGEX.match(key): + return request_property[key] + + return None + + +def parse_accept(accept): + """Parses the Accept header sent with a request. + + Args: + accept (str): the value of an Accept header. + + Returns: + (list): A list containing the MIME types that the client is able to + understand. + """ + return accept.replace(" ", "").split(",") + + +def remove_crlf(illegal_string): + """Removes characters prohibited by the MMS dependency Netty. + + https://github.com/netty/netty/issues/8312 + + Args: + illegal_string: The string containing prohibited characters. + + Returns: + str: The input string with the prohibited characters removed. 
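A small sketch of the two request-header helpers above (header values invented):

from sagemaker_inference import utils

headers = {"content-type": "text/csv", "Accept": "application/json, text/csv"}
assert utils.retrieve_content_type_header(headers) == "text/csv"
assert utils.parse_accept(headers["Accept"]) == ["application/json", "text/csv"]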
+ """ + prohibited = ("\r", "\n") + + sanitized_string = illegal_string + + for character in prohibited: + sanitized_string = sanitized_string.replace(character, " ") + + return sanitized_string diff --git a/test/container/2.0.0/Dockerfile.dlc.cpu b/test/container/2.0.0/Dockerfile.dlc.cpu index edb7064..98a912c 100644 --- a/test/container/2.0.0/Dockerfile.dlc.cpu +++ b/test/container/2.0.0/Dockerfile.dlc.cpu @@ -3,6 +3,6 @@ FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:2.0.0-cpu-py31 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz -RUN pip uninstall -y sagemaker_pytorch_inference && \ +RUN pip uninstall -y sagemaker_inference sagemaker_pytorch_inference && \ pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \ rm /sagemaker_pytorch_inference.tar.gz diff --git a/test/container/2.0.0/Dockerfile.dlc.gpu b/test/container/2.0.0/Dockerfile.dlc.gpu index 01edb02..ae1ebd7 100644 --- a/test/container/2.0.0/Dockerfile.dlc.gpu +++ b/test/container/2.0.0/Dockerfile.dlc.gpu @@ -3,6 +3,6 @@ FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:2.0.0-gpu-py31 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz -RUN pip uninstall -y sagemaker_pytorch_inference && \ +RUN pip uninstall -y sagemaker_inference sagemaker_pytorch_inference && \ pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \ rm /sagemaker_pytorch_inference.tar.gz diff --git a/test/container/2.0.1/Dockerfile.dlc.cpu b/test/container/2.0.1/Dockerfile.dlc.cpu index 19c8a2f..e1bf5bd 100644 --- a/test/container/2.0.1/Dockerfile.dlc.cpu +++ b/test/container/2.0.1/Dockerfile.dlc.cpu @@ -3,6 +3,6 @@ FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:2.0.1-cpu-py31 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz -RUN pip uninstall -y sagemaker_pytorch_inference && \ +RUN pip uninstall -y sagemaker_inference sagemaker_pytorch_inference && \ pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \ rm /sagemaker_pytorch_inference.tar.gz diff --git a/test/container/2.0.1/Dockerfile.dlc.gpu b/test/container/2.0.1/Dockerfile.dlc.gpu index 67501bd..3076527 100644 --- a/test/container/2.0.1/Dockerfile.dlc.gpu +++ b/test/container/2.0.1/Dockerfile.dlc.gpu @@ -3,6 +3,6 @@ FROM 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-inference:2.0.1-gpu-py31 COPY dist/sagemaker_pytorch_inference-*.tar.gz /sagemaker_pytorch_inference.tar.gz -RUN pip uninstall -y sagemaker_pytorch_inference && \ +RUN pip uninstall -y sagemaker_inference sagemaker_pytorch_inference && \ pip install --upgrade --no-cache-dir /sagemaker_pytorch_inference.tar.gz && \ rm /sagemaker_pytorch_inference.tar.gz diff --git a/test/unit/test_decoder.py b/test/unit/test_decoder.py new file mode 100644 index 0000000..9f433a7 --- /dev/null +++ b/test/unit/test_decoder.py @@ -0,0 +1,100 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from mock import Mock, patch +import numpy as np +import pytest +import scipy.sparse +from six import BytesIO + +from sagemaker_inference import content_types, decoder, errors + + +@pytest.mark.parametrize( + "target", + ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}), +) +def test_npy_to_numpy(target): + buffer = BytesIO() + np.save(buffer, target) + input_data = buffer.getvalue() + + actual = decoder._npy_to_numpy(input_data) + + np.testing.assert_equal(actual, np.array(target)) + + +@pytest.mark.parametrize( + "target, expected", + [ + ("[42, 6, 9]", np.array([42, 6, 9])), + ("[42.0, 6.0, 9.0]", np.array([42.0, 6.0, 9.0])), + ('["42", "6", "9"]', np.array(["42", "6", "9"])), + ('["42", "6", "9"]', np.array(["42", "6", "9"])), + ], +) +def test_json_to_numpy(target, expected): + actual = decoder._json_to_numpy(target) + np.testing.assert_equal(actual, expected) + + np.testing.assert_equal(decoder._json_to_numpy(target, dtype=int), expected.astype(int)) + + np.testing.assert_equal(decoder._json_to_numpy(target, dtype=float), expected.astype(float)) + + +@pytest.mark.parametrize( + "target, expected", + [ + ("42\n6\n9\n", np.array([42, 6, 9])), + ("42.0\n6.0\n9.0\n", np.array([42.0, 6.0, 9.0])), + ("42\n6\n9\n", np.array([42, 6, 9])), + ], +) +def test_csv_to_numpy(target, expected): + actual = decoder._csv_to_numpy(target) + np.testing.assert_equal(actual, expected) + + +@pytest.mark.parametrize( + "target", + [ + scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]])), + scipy.sparse.csr_matrix(np.array([[1, 0], [0, 7]])), + scipy.sparse.coo_matrix(np.array([[6, 2], [5, 9]])), + ], +) +def test_npz_to_sparse(target): + buffer = BytesIO() + scipy.sparse.save_npz(buffer, target) + data = buffer.getvalue() + matrix = decoder._npz_to_sparse(data) + + actual = matrix.toarray() + expected = target.toarray() + + np.testing.assert_equal(actual, expected) + + +def test_decode_error(): + with pytest.raises(errors.UnsupportedFormatError): + decoder.decode(42, content_types.OCTET_STREAM) + + +@pytest.mark.parametrize("content_type", [content_types.JSON, content_types.CSV, content_types.NPY]) +def test_decode(content_type): + mock_decoder = Mock() + with patch.dict(decoder._decoder_map, {content_type: mock_decoder}, clear=True): + decoder.decode(42, content_type) + + mock_decoder.assert_called_once_with(42) diff --git a/test/unit/test_default_handler_service.py b/test/unit/test_default_handler_service.py new file mode 100644 index 0000000..7a516aa --- /dev/null +++ b/test/unit/test_default_handler_service.py @@ -0,0 +1,62 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
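The decoder tests above exercise each format one way; together with the encoder, the built-in formats round-trip, for example:

import numpy as np

from sagemaker_inference import content_types, decoder, encoder

original = np.array([[1.0, 2.0], [3.0, 4.0]])
payload = encoder.encode(original, content_types.NPY)   # bytes in .npy format
restored = decoder.decode(payload, content_types.NPY)
assert np.array_equal(original, restored)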
+from __future__ import absolute_import + +from mock import MagicMock, Mock, patch + +from sagemaker_inference.default_handler_service import DefaultHandlerService +from sagemaker_inference.transformer import Transformer + +DATA = "data" +CONTEXT = Mock() +TRANSFORMED_RESULT = "transformed_result" + + +@patch("importlib.import_module", return_value=object()) +def test_default_handler_service(import_lib): + handler_service = DefaultHandlerService() + + assert isinstance(handler_service._service, Transformer) + + +def test_default_handler_service_custom_transformer(): + transformer = Mock() + + handler_service = DefaultHandlerService(transformer) + + assert handler_service._service == transformer + + +def test_handle(): + transformer = Mock() + transformer.transform.return_value = TRANSFORMED_RESULT + + handler_service = DefaultHandlerService(transformer) + result = handler_service.handle(DATA, CONTEXT) + + assert result == TRANSFORMED_RESULT + transformer.transform.assert_called_once_with(DATA, CONTEXT) + + +def test_initialize(): + transformer = Mock() + properties = {"model_dir": "/opt/ml/models/model-name"} + + def getitem(key): + return properties[key] + + context = MagicMock() + context.system_properties.__getitem__.side_effect = getitem + DefaultHandlerService(transformer).initialize(context) + + transformer.validate_and_initialize.assert_called_once() diff --git a/test/unit/test_default_inference_handler.py b/test/unit/test_default_inference_handler.py index 427494a..1f969ab 100644 --- a/test/unit/test_default_inference_handler.py +++ b/test/unit/test_default_inference_handler.py @@ -1,271 +1,53 @@ -# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"). You +# Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
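For orientation, this is roughly how a framework container is expected to compose the pieces the tests above cover; a sketch only, not the exact handler_service module in this repository:

from sagemaker_inference.default_handler_service import DefaultHandlerService
from sagemaker_inference.transformer import Transformer
from sagemaker_pytorch_serving_container import default_pytorch_inference_handler


class HandlerService(DefaultHandlerService):
    def __init__(self):
        # Wire the PyTorch-specific defaults into the generic Transformer workflow.
        transformer = Transformer(
            default_inference_handler=default_pytorch_inference_handler.DefaultPytorchInferenceHandler()
        )
        super(HandlerService, self).__init__(transformer=transformer)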
from __future__ import absolute_import -import csv -import json -import os - -import mock -import numpy as np +from mock import Mock, patch import pytest -import torch -import torch.nn as nn -from sagemaker_inference import content_types, errors -from six import StringIO, BytesIO -from torch.autograd import Variable - -from sagemaker_pytorch_serving_container import default_pytorch_inference_handler - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - -class DummyModel(nn.Module): - - def __init__(self, ): - super(DummyModel, self).__init__() - - def forward(self, x): - pass - - def __call__(self, tensor): - return 3 * tensor +from sagemaker_inference import content_types +from sagemaker_inference.default_inference_handler import DefaultInferenceHandler -@pytest.fixture(scope="session", name="tensor") -def fixture_tensor(): - tensor = torch.rand(5, 10, 7, 9) - return tensor.to(device) +@patch("sagemaker_inference.decoder.decode") +def test_default_input_fn(loads): + context = Mock() + assert DefaultInferenceHandler().default_input_fn(42, content_types.JSON, context) -@pytest.fixture() -def inference_handler(): - return default_pytorch_inference_handler.DefaultPytorchInferenceHandler() - - -@pytest.fixture() -def eia_inference_handler(): - return default_pytorch_inference_handler.DefaultPytorchInferenceHandler() - - -def test_default_model_fn(inference_handler): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "true" - mock_os.path.join = os.path.join - mock_os.path.exists.return_value = True - with mock.patch("torch.jit.load") as mock_torch_load: - mock_torch_load.return_value = DummyModel() - model = inference_handler.default_model_fn("model_dir") - assert model is not None - - -def test_default_model_fn_unknown_name(inference_handler): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "false" - mock_os.path.join = os.path.join - mock_os.path.exists.return_value = False - mock_os.path.isfile.return_value = True - mock_os.listdir.return_value = ["abcd.pt", "efgh.txt", "ijkl.bin"] - mock_os.path.splitext = os.path.splitext - with mock.patch("torch.jit.load") as mock_torch_load: - mock_torch_load.return_value = DummyModel() - model = inference_handler.default_model_fn("model_dir") - assert model is not None + loads.assert_called_with(42, content_types.JSON) @pytest.mark.parametrize( - "listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]] + "accept, expected_content_type", + [ + ("text/csv", "text/csv"), + ("text/csv, application/json", "text/csv"), + ("unsupported/type, text/csv", "text/csv"), + ], ) -def test_default_model_fn_no_model_file(inference_handler, listdir_return_value): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "false" - mock_os.path.join = os.path.join - mock_os.path.exists.return_value = False - mock_os.path.isfile.return_value = True - mock_os.listdir.return_value = listdir_return_value - mock_os.path.splitext = os.path.splitext - with mock.patch("torch.jit.load") as mock_torch_load: - mock_torch_load.return_value = DummyModel() - with pytest.raises(ValueError, match=r"Exactly one .pth or .pt file is required for PyTorch models: .*"): - inference_handler.default_model_fn("model_dir") - - -def _produce_runtime_error(x, **kwargs): - raise 
RuntimeError("dummy runtime error") - - -@pytest.mark.parametrize("test_case", ["eia", "non_eia"]) -def test_default_model_fn_non_torchscript_model(inference_handler, test_case): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "true" if test_case == "eia" else "false" - mock_os.path.join = os.path.join - mock_os.path.exists.return_value = True - with mock.patch("torch.jit") as mock_torch_jit: - mock_torch_jit.load = _produce_runtime_error - with pytest.raises(Exception, match=r"Failed to load .*. Please ensure model is saved using torchscript."): - inference_handler.default_model_fn("model_dir") - - -def test_default_input_fn_json(inference_handler, tensor): - json_data = json.dumps(tensor.cpu().numpy().tolist()) - deserialized_np_array = inference_handler.default_input_fn(json_data, content_types.JSON) - - assert deserialized_np_array.is_cuda == torch.cuda.is_available() - assert torch.equal(tensor, deserialized_np_array) - - -def test_default_input_fn_csv(inference_handler): - array = [[1, 2, 3], [4, 5, 6]] - str_io = StringIO() - csv.writer(str_io, delimiter=",").writerows(array) - - deserialized_np_array = inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV) - - tensor = torch.FloatTensor(array).to(device) - assert torch.equal(tensor, deserialized_np_array) - assert deserialized_np_array.is_cuda == torch.cuda.is_available() - - -def test_default_input_fn_csv_bad_columns(inference_handler): - str_io = StringIO() - csv_writer = csv.writer(str_io, delimiter=",") - csv_writer.writerow([1, 2, 3]) - csv_writer.writerow([1, 2, 3, 4]) - - with pytest.raises(ValueError): - inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV) - - -def test_default_input_fn_npy(inference_handler, tensor): - stream = BytesIO() - np.save(stream, tensor.cpu().numpy()) - deserialized_np_array = inference_handler.default_input_fn(stream.getvalue(), content_types.NPY) - - assert deserialized_np_array.is_cuda == torch.cuda.is_available() - assert torch.equal(tensor, deserialized_np_array) - - -def test_default_input_fn_bad_content_type(inference_handler): - with pytest.raises(errors.UnsupportedFormatError): - inference_handler.default_input_fn("", "application/not_supported") - - -def test_default_predict_fn(inference_handler, tensor): - model = DummyModel() - prediction = inference_handler.default_predict_fn(tensor, model) - assert torch.equal(model(Variable(tensor)), prediction) - assert prediction.is_cuda == torch.cuda.is_available() - - -def test_default_predict_fn_cpu_cpu(inference_handler, tensor): - prediction = inference_handler.default_predict_fn(tensor.cpu(), DummyModel().cpu()) - - model = DummyModel().to(device) - assert torch.equal(model(Variable(tensor)), prediction) - assert prediction.is_cuda == torch.cuda.is_available() - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_default_predict_fn_cpu_gpu(inference_handler, tensor): - model = DummyModel().cuda() - prediction = inference_handler.default_predict_fn(tensor.cpu(), model) - assert torch.equal(model(tensor), prediction) - assert prediction.is_cuda is True - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_default_predict_fn_gpu_cpu(inference_handler, tensor): - prediction = inference_handler.default_predict_fn(tensor.cpu(), DummyModel().cpu()) - model = DummyModel().cuda() - assert torch.equal(model(tensor), prediction) - assert 
prediction.is_cuda is True - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_default_predict_fn_gpu_gpu(inference_handler, tensor): - tensor = tensor.cuda() - model = DummyModel().cuda() - prediction = inference_handler.default_predict_fn(tensor, model) - assert torch.equal(model(tensor), prediction) - assert prediction.is_cuda is True - - -def test_default_output_fn_json(inference_handler, tensor): - output = inference_handler.default_output_fn(tensor, content_types.JSON) - - assert json.dumps(tensor.cpu().numpy().tolist()) == output - - -def test_default_output_fn_csv_long(inference_handler): - tensor = torch.LongTensor([[1, 2, 3], [4, 5, 6]]) - output = inference_handler.default_output_fn(tensor, content_types.CSV) - - assert '1,2,3\n4,5,6\n'.encode("utf-8") == output - - -def test_default_output_fn_csv_float(inference_handler): - tensor = torch.FloatTensor([[1, 2, 3], [4, 5, 6]]) - output = inference_handler.default_output_fn(tensor, content_types.CSV) - - assert '1.0,2.0,3.0\n4.0,5.0,6.0\n'.encode("utf-8") == output - - -def test_default_output_fn_multiple_content_types(inference_handler, tensor): - accept = ", ".join(["application/unsupported", content_types.JSON, content_types.CSV]) - output = inference_handler.default_output_fn(tensor, accept) - - assert json.dumps(tensor.cpu().numpy().tolist()) == output - - -def test_default_output_fn_bad_accept(inference_handler): - with pytest.raises(errors.UnsupportedFormatError): - inference_handler.default_output_fn("", "application/not_supported") - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_default_output_fn_gpu(inference_handler): - tensor_gpu = torch.LongTensor([[1, 2, 3], [4, 5, 6]]).cuda() - - output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV) - - assert "1,2,3\n4,5,6\n".encode("utf-8") == output - - -def test_eia_default_model_fn(eia_inference_handler): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "true" - mock_os.path.join.return_value = "model_dir" - mock_os.path.exists.return_value = True - with mock.patch("torch.jit.load") as mock_torch: - mock_torch.return_value = DummyModel() - model = eia_inference_handler.default_model_fn("model_dir") - assert model is not None +@patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction**2) +def test_default_output_fn(accept, expected_content_type): + context = Mock() + result, content_type = DefaultInferenceHandler().default_output_fn(2, accept, context) + assert result == 4 + assert content_type == expected_content_type -def test_eia_default_model_fn_error(eia_inference_handler): - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "true" - mock_os.path.join.return_value = "model_dir" - mock_os.path.exists.return_value = False - with pytest.raises(FileNotFoundError): - eia_inference_handler.default_model_fn("model_dir") +def test_default_model_fn(): + with pytest.raises(NotImplementedError): + DefaultInferenceHandler().default_model_fn("model_dir") -def test_eia_default_predict_fn(eia_inference_handler, tensor): - model = DummyModel() - with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: - mock_os.getenv.return_value = "true" - with mock.patch("torch.jit.optimized_execution") as mock_torch: - 
mock_torch.__enter__.return_value = "dummy" - eia_inference_handler.default_predict_fn(tensor, model) - mock_torch.assert_called_once() +def test_predict_fn(): + with pytest.raises(NotImplementedError): + DefaultInferenceHandler().default_predict_fn("data", "model") diff --git a/test/unit/test_default_pytorch_inference_handler.py b/test/unit/test_default_pytorch_inference_handler.py new file mode 100644 index 0000000..427494a --- /dev/null +++ b/test/unit/test_default_pytorch_inference_handler.py @@ -0,0 +1,271 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import csv +import json +import os + +import mock +import numpy as np +import pytest +import torch +import torch.nn as nn +from sagemaker_inference import content_types, errors +from six import StringIO, BytesIO +from torch.autograd import Variable + +from sagemaker_pytorch_serving_container import default_pytorch_inference_handler + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class DummyModel(nn.Module): + + def __init__(self, ): + super(DummyModel, self).__init__() + + def forward(self, x): + pass + + def __call__(self, tensor): + return 3 * tensor + + +@pytest.fixture(scope="session", name="tensor") +def fixture_tensor(): + tensor = torch.rand(5, 10, 7, 9) + return tensor.to(device) + + +@pytest.fixture() +def inference_handler(): + return default_pytorch_inference_handler.DefaultPytorchInferenceHandler() + + +@pytest.fixture() +def eia_inference_handler(): + return default_pytorch_inference_handler.DefaultPytorchInferenceHandler() + + +def test_default_model_fn(inference_handler): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = "true" + mock_os.path.join = os.path.join + mock_os.path.exists.return_value = True + with mock.patch("torch.jit.load") as mock_torch_load: + mock_torch_load.return_value = DummyModel() + model = inference_handler.default_model_fn("model_dir") + assert model is not None + + +def test_default_model_fn_unknown_name(inference_handler): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = "false" + mock_os.path.join = os.path.join + mock_os.path.exists.return_value = False + mock_os.path.isfile.return_value = True + mock_os.listdir.return_value = ["abcd.pt", "efgh.txt", "ijkl.bin"] + mock_os.path.splitext = os.path.splitext + with mock.patch("torch.jit.load") as mock_torch_load: + mock_torch_load.return_value = DummyModel() + model = inference_handler.default_model_fn("model_dir") + assert model is not None + + +@pytest.mark.parametrize( + "listdir_return_value", [["abcd.py", "efgh.txt", "ijkl.bin"], ["abcd.pt", "efgh.pth"]] +) +def test_default_model_fn_no_model_file(inference_handler, listdir_return_value): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = 
"false" + mock_os.path.join = os.path.join + mock_os.path.exists.return_value = False + mock_os.path.isfile.return_value = True + mock_os.listdir.return_value = listdir_return_value + mock_os.path.splitext = os.path.splitext + with mock.patch("torch.jit.load") as mock_torch_load: + mock_torch_load.return_value = DummyModel() + with pytest.raises(ValueError, match=r"Exactly one .pth or .pt file is required for PyTorch models: .*"): + inference_handler.default_model_fn("model_dir") + + +def _produce_runtime_error(x, **kwargs): + raise RuntimeError("dummy runtime error") + + +@pytest.mark.parametrize("test_case", ["eia", "non_eia"]) +def test_default_model_fn_non_torchscript_model(inference_handler, test_case): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = "true" if test_case == "eia" else "false" + mock_os.path.join = os.path.join + mock_os.path.exists.return_value = True + with mock.patch("torch.jit") as mock_torch_jit: + mock_torch_jit.load = _produce_runtime_error + with pytest.raises(Exception, match=r"Failed to load .*. Please ensure model is saved using torchscript."): + inference_handler.default_model_fn("model_dir") + + +def test_default_input_fn_json(inference_handler, tensor): + json_data = json.dumps(tensor.cpu().numpy().tolist()) + deserialized_np_array = inference_handler.default_input_fn(json_data, content_types.JSON) + + assert deserialized_np_array.is_cuda == torch.cuda.is_available() + assert torch.equal(tensor, deserialized_np_array) + + +def test_default_input_fn_csv(inference_handler): + array = [[1, 2, 3], [4, 5, 6]] + str_io = StringIO() + csv.writer(str_io, delimiter=",").writerows(array) + + deserialized_np_array = inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV) + + tensor = torch.FloatTensor(array).to(device) + assert torch.equal(tensor, deserialized_np_array) + assert deserialized_np_array.is_cuda == torch.cuda.is_available() + + +def test_default_input_fn_csv_bad_columns(inference_handler): + str_io = StringIO() + csv_writer = csv.writer(str_io, delimiter=",") + csv_writer.writerow([1, 2, 3]) + csv_writer.writerow([1, 2, 3, 4]) + + with pytest.raises(ValueError): + inference_handler.default_input_fn(str_io.getvalue(), content_types.CSV) + + +def test_default_input_fn_npy(inference_handler, tensor): + stream = BytesIO() + np.save(stream, tensor.cpu().numpy()) + deserialized_np_array = inference_handler.default_input_fn(stream.getvalue(), content_types.NPY) + + assert deserialized_np_array.is_cuda == torch.cuda.is_available() + assert torch.equal(tensor, deserialized_np_array) + + +def test_default_input_fn_bad_content_type(inference_handler): + with pytest.raises(errors.UnsupportedFormatError): + inference_handler.default_input_fn("", "application/not_supported") + + +def test_default_predict_fn(inference_handler, tensor): + model = DummyModel() + prediction = inference_handler.default_predict_fn(tensor, model) + assert torch.equal(model(Variable(tensor)), prediction) + assert prediction.is_cuda == torch.cuda.is_available() + + +def test_default_predict_fn_cpu_cpu(inference_handler, tensor): + prediction = inference_handler.default_predict_fn(tensor.cpu(), DummyModel().cpu()) + + model = DummyModel().to(device) + assert torch.equal(model(Variable(tensor)), prediction) + assert prediction.is_cuda == torch.cuda.is_available() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def 
test_default_predict_fn_cpu_gpu(inference_handler, tensor): + model = DummyModel().cuda() + prediction = inference_handler.default_predict_fn(tensor.cpu(), model) + assert torch.equal(model(tensor), prediction) + assert prediction.is_cuda is True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_default_predict_fn_gpu_cpu(inference_handler, tensor): + prediction = inference_handler.default_predict_fn(tensor.cpu(), DummyModel().cpu()) + model = DummyModel().cuda() + assert torch.equal(model(tensor), prediction) + assert prediction.is_cuda is True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_default_predict_fn_gpu_gpu(inference_handler, tensor): + tensor = tensor.cuda() + model = DummyModel().cuda() + prediction = inference_handler.default_predict_fn(tensor, model) + assert torch.equal(model(tensor), prediction) + assert prediction.is_cuda is True + + +def test_default_output_fn_json(inference_handler, tensor): + output = inference_handler.default_output_fn(tensor, content_types.JSON) + + assert json.dumps(tensor.cpu().numpy().tolist()) == output + + +def test_default_output_fn_csv_long(inference_handler): + tensor = torch.LongTensor([[1, 2, 3], [4, 5, 6]]) + output = inference_handler.default_output_fn(tensor, content_types.CSV) + + assert '1,2,3\n4,5,6\n'.encode("utf-8") == output + + +def test_default_output_fn_csv_float(inference_handler): + tensor = torch.FloatTensor([[1, 2, 3], [4, 5, 6]]) + output = inference_handler.default_output_fn(tensor, content_types.CSV) + + assert '1.0,2.0,3.0\n4.0,5.0,6.0\n'.encode("utf-8") == output + + +def test_default_output_fn_multiple_content_types(inference_handler, tensor): + accept = ", ".join(["application/unsupported", content_types.JSON, content_types.CSV]) + output = inference_handler.default_output_fn(tensor, accept) + + assert json.dumps(tensor.cpu().numpy().tolist()) == output + + +def test_default_output_fn_bad_accept(inference_handler): + with pytest.raises(errors.UnsupportedFormatError): + inference_handler.default_output_fn("", "application/not_supported") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_default_output_fn_gpu(inference_handler): + tensor_gpu = torch.LongTensor([[1, 2, 3], [4, 5, 6]]).cuda() + + output = inference_handler.default_output_fn(tensor_gpu, content_types.CSV) + + assert "1,2,3\n4,5,6\n".encode("utf-8") == output + + +def test_eia_default_model_fn(eia_inference_handler): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = "true" + mock_os.path.join.return_value = "model_dir" + mock_os.path.exists.return_value = True + with mock.patch("torch.jit.load") as mock_torch: + mock_torch.return_value = DummyModel() + model = eia_inference_handler.default_model_fn("model_dir") + assert model is not None + + +def test_eia_default_model_fn_error(eia_inference_handler): + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + mock_os.getenv.return_value = "true" + mock_os.path.join.return_value = "model_dir" + mock_os.path.exists.return_value = False + with pytest.raises(FileNotFoundError): + eia_inference_handler.default_model_fn("model_dir") + + +def test_eia_default_predict_fn(eia_inference_handler, tensor): + model = DummyModel() + with mock.patch("sagemaker_pytorch_serving_container.default_pytorch_inference_handler.os") as mock_os: + 
mock_os.getenv.return_value = "true"
+        with mock.patch("torch.jit.optimized_execution") as mock_torch:
+            mock_torch.__enter__.return_value = "dummy"
+            eia_inference_handler.default_predict_fn(tensor, model)
+            mock_torch.assert_called_once()
diff --git a/test/unit/test_encoder.py b/test/unit/test_encoder.py
new file mode 100644
index 0000000..07f0547
--- /dev/null
+++ b/test/unit/test_encoder.py
@@ -0,0 +1,88 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from mock import Mock, patch
+import numpy as np
+import pytest
+from six import BytesIO
+
+from sagemaker_inference import content_types, encoder, errors
+
+
+@pytest.mark.parametrize(
+    "target",
+    ([42, 6, 9], [42.0, 6.0, 9.0], ["42", "6", "9"], ["42", "6", "9"], {42: {"6": 9.0}}),
+)
+def test_array_to_npy(target):
+    input_data = np.array(target)
+
+    actual = encoder._array_to_npy(input_data)
+
+    np.testing.assert_equal(np.load(BytesIO(actual), allow_pickle=True), np.array(target))
+
+    actual = encoder._array_to_npy(target)
+
+    np.testing.assert_equal(np.load(BytesIO(actual), allow_pickle=True), np.array(target))
+
+
+@pytest.mark.parametrize(
+    "target, expected",
+    [
+        ([42, 6, 9], "[42, 6, 9]"),
+        ([42.0, 6.0, 9.0], "[42.0, 6.0, 9.0]"),
+        (["42", "6", "9"], '["42", "6", "9"]'),
+        ({42: {"6": 9.0}}, '{"42": {"6": 9.0}}'),
+    ],
+)
+def test_array_to_json(target, expected):
+    actual = encoder._array_to_json(target)
+    np.testing.assert_equal(actual, expected)
+
+    actual = encoder._array_to_json(np.array(target))
+    np.testing.assert_equal(actual, expected)
+
+
+def test_array_to_json_exception():
+    with pytest.raises(TypeError):
+        encoder._array_to_json(lambda x: 3)
+
+
+@pytest.mark.parametrize(
+    "target, expected",
+    [
+        ([42, 6, 9], "42\n6\n9\n"),
+        ([42.0, 6.0, 9.0], "42.0\n6.0\n9.0\n"),
+        (["42", "6", "9"], "42\n6\n9\n"),
+    ],
+)
+def test_array_to_csv(target, expected):
+    actual = encoder._array_to_csv(target)
+    np.testing.assert_equal(actual, expected)
+
+    actual = encoder._array_to_csv(np.array(target))
+    np.testing.assert_equal(actual, expected)
+
+
+@pytest.mark.parametrize("content_type", [content_types.JSON, content_types.CSV, content_types.NPY])
+def test_encode(content_type):
+    mock_encoder = Mock()
+    with patch.dict(encoder._encoder_map, {content_type: mock_encoder}, clear=True):
+        encoder.encode(42, content_type)
+
+    mock_encoder.assert_called_once_with(42)
+
+
+def test_encode_error():
+    with pytest.raises(errors.UnsupportedFormatError):
+        encoder.encode(42, content_types.OCTET_STREAM)
diff --git a/test/unit/test_environment.py b/test/unit/test_environment.py
new file mode 100644
index 0000000..e206df3
--- /dev/null
+++ b/test/unit/test_environment.py
@@ -0,0 +1,58 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+
+from mock import patch
+import pytest
+
+from sagemaker_inference import environment, parameters
+
+
+@patch.dict(
+    os.environ,
+    {
+        parameters.USER_PROGRAM_ENV: "main.py",
+        parameters.MODEL_SERVER_TIMEOUT_ENV: "20",
+        parameters.MODEL_SERVER_WORKERS_ENV: "8",
+        parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
+        parameters.BIND_TO_PORT_ENV: "1738",
+        parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
+    },
+    clear=True,
+)
+def test_env():
+    env = environment.Environment()
+
+    assert environment.base_dir.endswith("/opt/ml")
+    assert environment.model_dir.endswith("/opt/ml/model")
+    assert environment.code_dir.endswith("opt/ml/model/code")
+    assert env.module_name == "main"
+    assert env.model_server_timeout == 20
+    assert env.model_server_workers == "8"
+    assert env.default_accept == "text/html"
+    assert env.inference_http_port == "1738"
+    assert env.management_http_port == "1738"
+    assert env.safe_port_range == "1111-2222"
+
+
+@pytest.mark.parametrize("sagemaker_program", ["program.py", "program"])
+@patch.dict(os.environ, {}, clear=True)
+def test_env_module_name(sagemaker_program):
+    os.environ[parameters.USER_PROGRAM_ENV] = sagemaker_program
+    module_name = environment.Environment().module_name
+
+    del os.environ[parameters.USER_PROGRAM_ENV]
+
+    assert module_name == "program"
diff --git a/test/unit/test_model_server.py b/test/unit/test_model_server.py
index 26d8629..0611960 100644
--- a/test/unit/test_model_server.py
+++ b/test/unit/test_model_server.py
@@ -11,257 +11,80 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
from __future__ import absolute_import + import os -import signal import subprocess -import types +import sys -from mock import Mock, patch +import botocore.session +from botocore.stub import Stubber +from mock import MagicMock, patch import pytest -from sagemaker_inference import environment, model_server -from sagemaker_pytorch_serving_container import torchserve -from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE - -PYTHON_PATH = "python_path" -DEFAULT_CONFIGURATION = "default_configuration" - - -@patch("subprocess.call") -@patch("subprocess.Popen") -@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") -@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") -@patch("sagemaker_inference.model_server._install_requirements") -@patch("os.path.exists", return_value=True) -@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") -@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path") -def test_start_torchserve_default_service_handler( - set_python_path, - create_config, - exists, - install_requirements, - sigterm, - retrieve, - subprocess_popen, - subprocess_call, -): - torchserve.start_torchserve() - - set_python_path.assert_called_once_with() - create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE) - install_requirements.assert_called_once_with() - - ts_model_server_cmd = [ - "torchserve", - "--start", - "--model-store", - torchserve.MODEL_STORE, - "--ts-config", - torchserve.TS_CONFIG_FILE, - "--log-config", - torchserve.DEFAULT_TS_LOG_FILE, - "--models", - "model=/opt/ml/model" - ] - - subprocess_popen.assert_called_once_with(ts_model_server_cmd) - sigterm.assert_called_once_with(retrieve.return_value) - - -@patch("subprocess.call") -@patch("subprocess.Popen") -@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") -@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") -@patch("sagemaker_inference.model_server._install_requirements") -@patch("os.path.exists", return_value=True) -@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") -@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path") -def test_start_torchserve_default_service_handler_multi_model( - set_python_path, - create_config, - exists, - install_requirements, - sigterm, - retrieve, - subprocess_popen, - subprocess_call, -): - torchserve.ENABLE_MULTI_MODEL = True - torchserve.start_torchserve() - torchserve.ENABLE_MULTI_MODEL = False - - set_python_path.assert_called_once_with() - create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE) - exists.assert_called_once_with(model_server.REQUIREMENTS_PATH) - install_requirements.assert_called_once_with() - - ts_model_server_cmd = [ - "torchserve", - "--start", - "--model-store", - torchserve.MODEL_STORE, - "--ts-config", - torchserve.TS_CONFIG_FILE, - "--log-config", - torchserve.DEFAULT_TS_LOG_FILE, - ] - - subprocess_popen.assert_called_once_with(ts_model_server_cmd) - sigterm.assert_called_once_with(retrieve.return_value) - - -@patch.dict(os.environ, {torchserve.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True) -def test_set_existing_python_path(): - torchserve._set_python_path() - - code_dir_path = "{}:{}".format(environment.code_dir, PYTHON_PATH) - - assert os.environ[torchserve.PYTHON_PATH_ENV] == code_dir_path - - -@patch.dict(os.environ, {}, clear=True) -def test_new_python_path(): - torchserve._set_python_path() - - 
code_dir_path = environment.code_dir - - assert os.environ[torchserve.PYTHON_PATH_ENV] == code_dir_path - - -@patch("sagemaker_pytorch_serving_container.torchserve._generate_ts_config_properties") -@patch("sagemaker_inference.utils.write_file") -def test_create_torchserve_config_file(write_file, generate_ts_config_props): - torchserve._create_torchserve_config_file(torchserve.DEFAULT_HANDLER_SERVICE) - - write_file.assert_called_once_with( - torchserve.TS_CONFIG_FILE, generate_ts_config_props.return_value - ) - - -@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) -@patch("sagemaker_inference.environment.Environment") -def test_generate_ts_config_properties(env, read_file): - model_server_timeout = "torchserve_timeout" - model_server_workers = "torchserve_workers" - http_port = "http_port" - - env.return_value.model_server_timeout = model_server_timeout - env.return_value.model_sever_workerse = model_server_workers - env.return_value.inference_http_port = http_port - - ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) - - inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port) - server_timeout = "default_response_timeout={}\n".format(model_server_timeout) - - read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE) - - assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) - assert inference_address in ts_config_properties - assert server_timeout in ts_config_properties - - -@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) -@patch("sagemaker_inference.environment.Environment") -def test_generate_ts_config_properties_default_workers(env, read_file): - env.return_value.model_server_workers = None - - ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) - - workers = "default_workers_per_model={}".format(None) - - read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE) - - assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) - assert workers not in ts_config_properties - - -@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) -@patch("sagemaker_inference.environment.Environment") -def test_generate_ts_config_properties_multi_model(env, read_file): - env.return_value.model_server_workers = None - - torchserve.ENABLE_MULTI_MODEL = True - ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) - torchserve.ENABLE_MULTI_MODEL = False - - workers = "default_workers_per_model={}".format(None) - - read_file.assert_called_once_with(torchserve.MME_TS_CONFIG_FILE) - - assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) - assert workers not in ts_config_properties - - -@patch("signal.signal") -def test_add_sigterm_handler(signal_call): - ts = Mock() - - torchserve._add_sigterm_handler(ts) - - mock_calls = signal_call.mock_calls - first_argument = mock_calls[0][1][0] - second_argument = mock_calls[0][1][1] - - assert len(mock_calls) == 1 - assert first_argument == signal.SIGTERM - assert isinstance(second_argument, types.FunctionType) +from sagemaker_inference import model_server @patch("subprocess.check_call") def test_install_requirements(check_call): model_server._install_requirements() - for i in ['pip', 'install', '-r', '/opt/ml/model/code/requirements.txt']: - assert i in check_call.call_args.args[0] + + install_cmd = [ + sys.executable, + "-m", + "pip", + "install", + "-r", + 
"/opt/ml/model/code/requirements.txt", + ] + check_call.assert_called_once_with(install_cmd) @patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd")) def test_install_requirements_installation_failed(check_call): with pytest.raises(ValueError) as e: model_server._install_requirements() - assert "failed to install required packages" in str(e.value) - - -@patch("retrying.Retrying.should_reject", return_value=False) -@patch("psutil.process_iter") -def test_retrieve_ts_server_process(process_iter, retry): - server = Mock() - server.cmdline.return_value = TS_NAMESPACE - processes = list() - processes.append(server) - - process_iter.return_value = processes - - process = torchserve._retrieve_ts_server_process() - - assert process == server + assert "failed to install required packages" in str(e.value) -@patch("retrying.Retrying.should_reject", return_value=False) -@patch("psutil.process_iter", return_value=list()) -def test_retrieve_ts_server_process_no_server(process_iter, retry): +@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True) +def test_install_requirements_codeartifact_invalid_arn_installation_failed(): with pytest.raises(Exception) as e: - torchserve._retrieve_ts_server_process() - - assert "Torchserve model server was unsuccessfully started" in str(e.value) - + model_server._install_requirements() -@patch("retrying.Retrying.should_reject", return_value=False) -@patch("psutil.process_iter") -def test_retrieve_ts_server_process_too_many_servers(process_iter, retry): - server = Mock() - second_server = Mock() - server.cmdline.return_value = TS_NAMESPACE - second_server.cmdline.return_value = TS_NAMESPACE + assert "invalid CodeArtifact repository arn invalid_arn" in str(e.value) - processes = list() - processes.append(server) - processes.append(second_server) - process_iter.return_value = processes +@patch("subprocess.check_call") +@patch.dict( + os.environ, + { + "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my_region:012345678900:repository/my_domain/my_repo" + }, + clear=True, +) +def test_install_requirements_codeartifact(check_call): + # mock/stub codeartifact client and its responses + endpoint = "https://domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/" + codeartifact = botocore.session.get_session().create_client( + "codeartifact", region_name="myregion" + ) + stubber = Stubber(codeartifact) + stubber.add_response("get_authorization_token", {"authorizationToken": "the-auth-token"}) + stubber.add_response("get_repository_endpoint", {"repositoryEndpoint": endpoint}) + stubber.activate() - with pytest.raises(Exception) as e: - torchserve._retrieve_ts_server_process() + with patch("boto3.client", MagicMock(return_value=codeartifact)): + model_server._install_requirements() - assert "multiple ts model servers are not supported" in str(e.value) + install_cmd = [ + sys.executable, + "-m", + "pip", + "install", + "-r", + "/opt/ml/model/code/requirements.txt", + "-i", + "https://aws:the-auth-token@domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/simple/", + ] + check_call.assert_called_once_with(install_cmd) diff --git a/test/unit/test_torchserve.py b/test/unit/test_torchserve.py new file mode 100644 index 0000000..26d8629 --- /dev/null +++ b/test/unit/test_torchserve.py @@ -0,0 +1,267 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the 'license' file accompanying this file. This file is +# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import os +import signal +import subprocess +import types + +from mock import Mock, patch +import pytest + +from sagemaker_inference import environment, model_server +from sagemaker_pytorch_serving_container import torchserve +from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE + +PYTHON_PATH = "python_path" +DEFAULT_CONFIGURATION = "default_configuration" + + +@patch("subprocess.call") +@patch("subprocess.Popen") +@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") +@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") +@patch("sagemaker_inference.model_server._install_requirements") +@patch("os.path.exists", return_value=True) +@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") +@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path") +def test_start_torchserve_default_service_handler( + set_python_path, + create_config, + exists, + install_requirements, + sigterm, + retrieve, + subprocess_popen, + subprocess_call, +): + torchserve.start_torchserve() + + set_python_path.assert_called_once_with() + create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE) + install_requirements.assert_called_once_with() + + ts_model_server_cmd = [ + "torchserve", + "--start", + "--model-store", + torchserve.MODEL_STORE, + "--ts-config", + torchserve.TS_CONFIG_FILE, + "--log-config", + torchserve.DEFAULT_TS_LOG_FILE, + "--models", + "model=/opt/ml/model" + ] + + subprocess_popen.assert_called_once_with(ts_model_server_cmd) + sigterm.assert_called_once_with(retrieve.return_value) + + +@patch("subprocess.call") +@patch("subprocess.Popen") +@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process") +@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler") +@patch("sagemaker_inference.model_server._install_requirements") +@patch("os.path.exists", return_value=True) +@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file") +@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path") +def test_start_torchserve_default_service_handler_multi_model( + set_python_path, + create_config, + exists, + install_requirements, + sigterm, + retrieve, + subprocess_popen, + subprocess_call, +): + torchserve.ENABLE_MULTI_MODEL = True + torchserve.start_torchserve() + torchserve.ENABLE_MULTI_MODEL = False + + set_python_path.assert_called_once_with() + create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE) + exists.assert_called_once_with(model_server.REQUIREMENTS_PATH) + install_requirements.assert_called_once_with() + + ts_model_server_cmd = [ + "torchserve", + "--start", + "--model-store", + torchserve.MODEL_STORE, + "--ts-config", + torchserve.TS_CONFIG_FILE, + "--log-config", + torchserve.DEFAULT_TS_LOG_FILE, + ] + + subprocess_popen.assert_called_once_with(ts_model_server_cmd) + sigterm.assert_called_once_with(retrieve.return_value) + + +@patch.dict(os.environ, {torchserve.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True) +def test_set_existing_python_path(): + 
torchserve._set_python_path() + + code_dir_path = "{}:{}".format(environment.code_dir, PYTHON_PATH) + + assert os.environ[torchserve.PYTHON_PATH_ENV] == code_dir_path + + +@patch.dict(os.environ, {}, clear=True) +def test_new_python_path(): + torchserve._set_python_path() + + code_dir_path = environment.code_dir + + assert os.environ[torchserve.PYTHON_PATH_ENV] == code_dir_path + + +@patch("sagemaker_pytorch_serving_container.torchserve._generate_ts_config_properties") +@patch("sagemaker_inference.utils.write_file") +def test_create_torchserve_config_file(write_file, generate_ts_config_props): + torchserve._create_torchserve_config_file(torchserve.DEFAULT_HANDLER_SERVICE) + + write_file.assert_called_once_with( + torchserve.TS_CONFIG_FILE, generate_ts_config_props.return_value + ) + + +@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) +@patch("sagemaker_inference.environment.Environment") +def test_generate_ts_config_properties(env, read_file): + model_server_timeout = "torchserve_timeout" + model_server_workers = "torchserve_workers" + http_port = "http_port" + + env.return_value.model_server_timeout = model_server_timeout + env.return_value.model_sever_workerse = model_server_workers + env.return_value.inference_http_port = http_port + + ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) + + inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port) + server_timeout = "default_response_timeout={}\n".format(model_server_timeout) + + read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE) + + assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) + assert inference_address in ts_config_properties + assert server_timeout in ts_config_properties + + +@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) +@patch("sagemaker_inference.environment.Environment") +def test_generate_ts_config_properties_default_workers(env, read_file): + env.return_value.model_server_workers = None + + ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) + + workers = "default_workers_per_model={}".format(None) + + read_file.assert_called_once_with(torchserve.DEFAULT_TS_CONFIG_FILE) + + assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) + assert workers not in ts_config_properties + + +@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION) +@patch("sagemaker_inference.environment.Environment") +def test_generate_ts_config_properties_multi_model(env, read_file): + env.return_value.model_server_workers = None + + torchserve.ENABLE_MULTI_MODEL = True + ts_config_properties = torchserve._generate_ts_config_properties(torchserve.DEFAULT_HANDLER_SERVICE) + torchserve.ENABLE_MULTI_MODEL = False + + workers = "default_workers_per_model={}".format(None) + + read_file.assert_called_once_with(torchserve.MME_TS_CONFIG_FILE) + + assert ts_config_properties.startswith(DEFAULT_CONFIGURATION) + assert workers not in ts_config_properties + + +@patch("signal.signal") +def test_add_sigterm_handler(signal_call): + ts = Mock() + + torchserve._add_sigterm_handler(ts) + + mock_calls = signal_call.mock_calls + first_argument = mock_calls[0][1][0] + second_argument = mock_calls[0][1][1] + + assert len(mock_calls) == 1 + assert first_argument == signal.SIGTERM + assert isinstance(second_argument, types.FunctionType) + + +@patch("subprocess.check_call") +def test_install_requirements(check_call): + 
model_server._install_requirements()
+    for i in ['pip', 'install', '-r', '/opt/ml/model/code/requirements.txt']:
+        assert i in check_call.call_args.args[0]
+
+
+@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
+def test_install_requirements_installation_failed(check_call):
+    with pytest.raises(ValueError) as e:
+        model_server._install_requirements()
+    assert "failed to install required packages" in str(e.value)
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter")
+def test_retrieve_ts_server_process(process_iter, retry):
+    server = Mock()
+    server.cmdline.return_value = TS_NAMESPACE
+
+    processes = list()
+    processes.append(server)
+
+    process_iter.return_value = processes
+
+    process = torchserve._retrieve_ts_server_process()
+
+    assert process == server
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter", return_value=list())
+def test_retrieve_ts_server_process_no_server(process_iter, retry):
+    with pytest.raises(Exception) as e:
+        torchserve._retrieve_ts_server_process()
+
+    assert "Torchserve model server was unsuccessfully started" in str(e.value)
+
+
+@patch("retrying.Retrying.should_reject", return_value=False)
+@patch("psutil.process_iter")
+def test_retrieve_ts_server_process_too_many_servers(process_iter, retry):
+    server = Mock()
+    second_server = Mock()
+    server.cmdline.return_value = TS_NAMESPACE
+    second_server.cmdline.return_value = TS_NAMESPACE
+
+    processes = list()
+    processes.append(server)
+    processes.append(second_server)
+
+    process_iter.return_value = processes
+
+    with pytest.raises(Exception) as e:
+        torchserve._retrieve_ts_server_process()
+
+    assert "multiple ts model servers are not supported" in str(e.value)
diff --git a/test/unit/test_transfomer.py b/test/unit/test_transfomer.py
new file mode 100644
index 0000000..5149646
--- /dev/null
+++ b/test/unit/test_transfomer.py
@@ -0,0 +1,540 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import + +from mock import call, Mock, patch +import pytest + +try: + import http.client as http_client +except ImportError: + import httplib as http_client + +from sagemaker_inference import content_types, environment +from sagemaker_inference.default_inference_handler import DefaultInferenceHandler +from sagemaker_inference.errors import BaseInferenceToolkitError +from sagemaker_inference.transformer import Transformer + +INPUT_DATA = "input_data" +CONTENT_TYPE = "content_type" +ACCEPT = "accept" +DEFAULT_ACCEPT = "default_accept" +RESULT = "result" +MODEL = "foo" + +PREPROCESSED_DATA = "preprocessed_data" +PREDICT_RESULT = "prediction_result" +PROCESSED_RESULT = "processed_result" + + +def test_default_transformer(): + transformer = Transformer() + + assert isinstance(transformer._default_inference_handler, DefaultInferenceHandler) + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +def test_transformer_with_custom_default_inference_handler(): + default_inference_handler = Mock() + + transformer = Transformer(default_inference_handler) + + assert transformer._default_inference_handler == default_inference_handler + assert transformer._initialized is False + assert transformer._environment is None + assert transformer._pre_model_fn is None + assert transformer._model_warmup_fn is None + assert transformer._model is None + assert transformer._model_fn is None + assert transformer._transform_fn is None + assert transformer._input_fn is None + assert transformer._predict_fn is None + assert transformer._output_fn is None + assert transformer._context is None + + +@pytest.mark.parametrize("accept_key", ["Accept", "accept"]) +@patch("sagemaker_inference.transformer.Transformer._run_handler_function", return_value=RESULT) +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_transform(validate, retrieve_content_type_header, run_handler, accept_key): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + + context.request_processor = [request_processor] + request_property = {accept_key: ACCEPT} + request_processor.get_request_properties.return_value = request_property + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + validate.assert_called_once() + retrieve_content_type_header.assert_called_once_with(request_property) + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + ) + context.set_response_content_type.assert_called_once_with(0, ACCEPT) + assert isinstance(result, list) + assert result[0] == RESULT + + +@pytest.mark.parametrize("accept_key", ["Accept", "accept"]) +@patch("sagemaker_inference.transformer.Transformer._run_handler_function", return_value=RESULT) +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) 
+@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_batch_transform(validate, retrieve_content_type_header, run_handler, accept_key): + data = [{"body": INPUT_DATA}, {"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + + context.request_processor = [request_processor] + request_property = {accept_key: ACCEPT} + request_processor.get_request_properties.return_value = request_property + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + validate.assert_called_once() + retrieve_content_type_header.assert_called_with(request_property) + assert retrieve_content_type_header.call_count == 2 + run_handler.assert_called_with( + transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + ) + assert run_handler.call_count == 2 + context.set_response_content_type.assert_called_with(0, ACCEPT) + assert context.set_response_content_type.call_count == 2 + assert isinstance(result, list) + assert result == [RESULT, RESULT] + + +@patch("sagemaker_inference.transformer.Transformer._run_handler_function") +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_transform_no_accept(validate, retrieve_content_type_header, run_handler): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + environment = Mock() + environment.default_accept = DEFAULT_ACCEPT + + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = dict() + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._environment = environment + transformer._context = context + + transformer.transform(data, context) + + validate.assert_called_once() + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT + ) + + +@patch("sagemaker_inference.transformer.Transformer._run_handler_function") +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_transform_any_accept(validate, retrieve_content_type_header, run_handler): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + environment = Mock() + environment.default_accept = DEFAULT_ACCEPT + + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": content_types.ANY} + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._environment = environment + transformer._context = context + + transformer.transform(data, context) + + validate.assert_called_once() + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT + ) + + +@pytest.mark.parametrize("content_type", content_types.UTF8_TYPES) +@patch("sagemaker_inference.transformer.Transformer._run_handler_function") +@patch("sagemaker_inference.utils.retrieve_content_type_header") +@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_transform_decode(validate, retrieve_content_type_header, run_handler, 
content_type): + input_data = Mock() + context = Mock() + request_processor = Mock() + transform_fn = Mock() + data = [{"body": input_data}] + + input_data.decode.return_value = INPUT_DATA + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": ACCEPT} + retrieve_content_type_header.return_value = content_type + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + transformer.transform(data, context) + + input_data.decode.assert_called_once_with("utf-8") + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, INPUT_DATA, content_type, ACCEPT + ) + + +@patch( + "sagemaker_inference.transformer.Transformer._run_handler_function", + return_value=(RESULT, ACCEPT), +) +@patch("sagemaker_inference.utils.retrieve_content_type_header", return_value=CONTENT_TYPE) +@patch("sagemaker_inference.transformer.Transformer.validate_and_initialize") +def test_transform_tuple(validate, retrieve_content_type_header, run_handler): + data = [{"body": INPUT_DATA}] + context = Mock() + request_processor = Mock() + transform_fn = Mock() + + context.request_processor = [request_processor] + request_processor.get_request_properties.return_value = {"accept": ACCEPT} + + transformer = Transformer() + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._context = context + + result = transformer.transform(data, context) + + run_handler.assert_called_once_with( + transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT + ) + context.set_response_content_type.assert_called_once_with(0, run_handler()[1]) + assert isinstance(result, list) + assert result[0] == run_handler()[0] + + +@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_validate_and_initialize(env, validate_user_module): + transformer = Transformer() + + model_fn = Mock() + context = Mock() + transformer._model_fn = model_fn + + assert transformer._initialized is False + assert transformer._context is None + + transformer.validate_and_initialize(context=context) + + assert transformer._initialized is True + assert transformer._context == context + + transformer.validate_and_initialize() + + model_fn.assert_called_once_with(environment.model_dir, context) + env.assert_called_once_with() + validate_user_module.assert_called_once_with() + + +@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_handle_validate_and_initialize_error(env, validate_user_module): + data = [{"body": INPUT_DATA}] + request_processor = Mock() + + context = Mock() + context.request_processor = [request_processor] + + transform_fn = Mock() + model_fn = Mock() + + transformer = Transformer() + + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._model_fn = model_fn + transformer._context = context + + test_error_message = "Foo" + validate_user_module.side_effect = ValueError(test_error_message) + + assert transformer._initialized is False + + response = transformer.transform(data, context) + assert test_error_message in str(response) + assert "Traceback (most recent call last)" in str(response) + context.set_response_status.assert_called_with( + code=http_client.INTERNAL_SERVER_ERROR, phrase=test_error_message + ) + + 
+@patch("sagemaker_inference.transformer.Transformer._validate_user_module_and_set_functions") +@patch("sagemaker_inference.environment.Environment") +def test_handle_validate_and_initialize_user_error(env, validate_user_module): + test_status_code = http_client.FORBIDDEN + test_error_message = "Foo" + + class FooUserError(BaseInferenceToolkitError): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.phrase = "Foo" + + data = [{"body": INPUT_DATA}] + context = Mock() + transform_fn = Mock() + model_fn = Mock() + + transformer = Transformer() + + transformer._model = MODEL + transformer._transform_fn = transform_fn + transformer._model_fn = model_fn + transformer._context = context + + validate_user_module.side_effect = FooUserError(test_status_code, test_error_message) + + assert transformer._initialized is False + + response = transformer.transform(data, context) + assert test_error_message in str(response) + assert "Traceback (most recent call last)" in str(response) + context.set_response_status.assert_called_with( + code=http_client.FORBIDDEN, phrase=test_error_message + ) + + +class UserModuleMock: + def __init__(self, transform_fn=Mock(), input_fn=Mock(), predict_fn=Mock(), output_fn=Mock()): + self.transform_fn = transform_fn + self.input_fn = input_fn + self.predict_fn = predict_fn + self.output_fn = output_fn + + +@patch("importlib.import_module") +@patch("sagemaker_inference.transformer.find_spec", return_value=None) +def test_validate_no_user_module_and_set_functions(find_spec, import_module): + default_inference_handler = Mock() + mock_env = Mock() + mock_env.module_name = "foo_module" + + default_pre_model_fn = object() + default_model_warmup_fn = object() + default_model_fn = object() + default_input_fn = object() + default_predict_fn = object() + default_output_fn = object() + + default_inference_handler.default_pre_model_fn = default_pre_model_fn + default_inference_handler.default_model_warmup_fn = default_model_warmup_fn + default_inference_handler.default_model_fn = default_model_fn + default_inference_handler.default_input_fn = default_input_fn + default_inference_handler.default_predict_fn = default_predict_fn + default_inference_handler.default_output_fn = default_output_fn + + transformer = Transformer(default_inference_handler) + transformer._environment = mock_env + transformer._validate_user_module_and_set_functions() + + find_spec.assert_called_once_with(mock_env.module_name) + import_module.assert_not_called() + assert transformer._default_inference_handler == default_inference_handler + assert transformer._environment == mock_env + assert transformer._pre_model_fn == default_pre_model_fn + assert transformer._model_warmup_fn == default_model_warmup_fn + assert transformer._model_fn == default_model_fn + assert transformer._input_fn == default_input_fn + assert transformer._predict_fn == default_predict_fn + assert transformer._output_fn == default_output_fn + + +@patch("importlib.import_module", return_value=object()) +@patch("sagemaker_inference.transformer.find_spec", return_value=Mock()) +def test_validate_user_module_and_set_functions(find_spec, import_module): + default_inference_handler = Mock() + mock_env = Mock() + mock_env.module_name = "foo_module" + + default_pre_model_fn = object() + default_model_warmup_fn = object() + default_model_fn = object() + default_input_fn = object() + default_predict_fn = object() + default_output_fn = object() + + default_inference_handler.default_pre_model_fn = 
default_pre_model_fn + default_inference_handler.default_model_warmup_fn = default_model_warmup_fn + default_inference_handler.default_model_fn = default_model_fn + default_inference_handler.default_input_fn = default_input_fn + default_inference_handler.default_predict_fn = default_predict_fn + default_inference_handler.default_output_fn = default_output_fn + + transformer = Transformer(default_inference_handler) + transformer._environment = mock_env + transformer._validate_user_module_and_set_functions() + + find_spec.assert_called_once_with(mock_env.module_name) + import_module.assert_called_once_with(mock_env.module_name) + assert transformer._default_inference_handler == default_inference_handler + assert transformer._environment == mock_env + assert transformer._pre_model_fn == default_pre_model_fn + assert transformer._model_warmup_fn == default_model_warmup_fn + assert transformer._model_fn == default_model_fn + assert transformer._input_fn == default_input_fn + assert transformer._predict_fn == default_predict_fn + assert transformer._output_fn == default_output_fn + + +@patch( + "importlib.import_module", + return_value=UserModuleMock(input_fn=None, predict_fn=None, output_fn=None), +) +@patch("sagemaker_inference.transformer.find_spec", return_value=Mock()) +def test_validate_user_module_and_set_functions_transform_fn(find_spec, import_module): + mock_env = Mock() + mock_env.module_name = "foo_module" + + import_module.transform_fn = Mock() + + transformer = Transformer() + transformer._environment = mock_env + + transformer._validate_user_module_and_set_functions() + + find_spec.assert_called_once_with(mock_env.module_name) + import_module.assert_called_once_with(mock_env.module_name) + assert transformer._transform_fn == import_module.return_value.transform_fn + + +def _assert_value_error_raised(): + with pytest.raises(ValueError) as e: + transformer = Transformer() + transformer._environment = Mock() + transformer._validate_user_module_and_set_functions() + + assert ( + "Cannot use transform_fn implementation in conjunction with input_fn, predict_fn, " + "and/or output_fn implementation" in str(e.value) + ) + + +@pytest.mark.parametrize( + "user_module", + [ + UserModuleMock(input_fn=None), + UserModuleMock(predict_fn=None), + UserModuleMock(output_fn=None), + UserModuleMock(output_fn=None, predict_fn=None), + UserModuleMock(input_fn=None, output_fn=None), + UserModuleMock(input_fn=None, predict_fn=None), + UserModuleMock(), + ], +) +@patch("importlib.import_module") +@patch("sagemaker_inference.transformer.find_spec", return_value=Mock()) +def test_validate_user_module_error(find_spec, import_module, user_module): + import_module.return_value = user_module + + _assert_value_error_raised() + + +@patch( + "sagemaker_inference.transformer.Transformer._run_handler_function", + side_effect=[PREPROCESSED_DATA, PREDICT_RESULT, PROCESSED_RESULT], +) +def test_default_transform_fn(run_handle_function): + transformer = Transformer() + context = Mock() + transformer._context = context + + input_fn = Mock() + predict_fn = Mock(return_value=PREDICT_RESULT) + output_fn = Mock(return_value=PROCESSED_RESULT) + + transformer._input_fn = input_fn + transformer._predict_fn = predict_fn + transformer._output_fn = output_fn + + result = transformer._default_transform_fn(MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT, context) + + run_handle_function.assert_has_calls( + [ + call(transformer._input_fn, *(INPUT_DATA, CONTENT_TYPE)), + call(transformer._predict_fn, *(PREPROCESSED_DATA, MODEL)), + 
+            call(transformer._output_fn, *(PREDICT_RESULT, ACCEPT)),
+        ]
+    )
+
+    assert result == PROCESSED_RESULT
+
+
+def dummy_handler_func(a, b):
+    return b
+
+
+def test_run_handler_function():
+    arg1 = Mock()
+    arg2 = Mock()
+    context = Mock()
+    transformer = Transformer()
+    transformer._context = context
+
+    # test the case when handler function takes context
+    assert transformer._run_handler_function(dummy_handler_func, arg1) == context
+
+    # test the case when handler function does not take context
+    assert transformer._run_handler_function(dummy_handler_func, arg1, arg2) == arg2
+
+
+def test_run_handler_function_raise_error():
+    with pytest.raises(TypeError) as e:
+        a = Mock()
+        b = Mock()
+        c = Mock()
+        transformer = Transformer()
+        transformer._context = Mock()
+        transformer._run_handler_function(dummy_handler_func, a, b, c)
+
+    assert "dummy_handler_func takes 2 arguments but 3 were given." in str(e.value)
diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py
new file mode 100644
index 0000000..1a85e03
--- /dev/null
+++ b/test/unit/test_utils.py
@@ -0,0 +1,104 @@
+# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License'). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the 'license' file accompanying this file. This file is
+# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from mock import Mock, mock_open, patch
+import pytest
+
+from sagemaker_inference.utils import (
+    parse_accept,
+    read_file,
+    remove_crlf,
+    retrieve_content_type_header,
+    write_file,
+)
+
+TEXT = "text"
+CONTENT_TYPE = "content_type"
+
+
+@patch("sagemaker_inference.utils.open", new_callable=mock_open, read_data=TEXT)
+def test_read_file(with_open):
+    path = Mock()
+
+    result = read_file(path)
+
+    with_open.assert_called_once_with(path, "r")
+    with_open().read.assert_called_once_with()
+    assert TEXT == result
+
+
+@patch("sagemaker_inference.utils.open", new_callable=mock_open, read_data=TEXT)
+def test_read_file_mode(with_open):
+    path = Mock()
+    mode = Mock()
+
+    result = read_file(path, mode)
+
+    with_open.assert_called_once_with(path, mode)
+    with_open().read.assert_called_once_with()
+    assert result == TEXT
+
+
+@patch("sagemaker_inference.utils.open", new_callable=mock_open)
+def test_write_file(with_open):
+    path = Mock()
+    data = Mock()
+
+    write_file(path, data)
+
+    with_open.assert_called_once_with(path, "w")
+    with_open().write.assert_called_once_with(data)
+
+
+@patch("sagemaker_inference.utils.open", new_callable=mock_open)
+def test_write_file_mode(with_open):
+    path = Mock()
+    data = Mock()
+    mode = Mock()
+
+    write_file(path, data, mode)
+
+    with_open.assert_called_once_with(path, mode)
+    with_open().write.assert_called_once_with(data)
+
+
+@pytest.mark.parametrize(
+    "content_type_key", ["Content-Type", "Content-type", "content-type", "ContentType"]
+)
+def test_content_type_header(content_type_key):
+    request_property = {content_type_key: CONTENT_TYPE}
+
+    result = retrieve_content_type_header(request_property)
+
+    assert result == CONTENT_TYPE
+
+
+@pytest.mark.parametrize(
+    "input, expected",
+    [
+        ("application/json", ["application/json"]),
+        ("application/json, text/csv", ["application/json", "text/csv"]),
+        ("application/json,text/csv", ["application/json", "text/csv"]),
+    ],
+)
+def test_parse_accept(input, expected):
+    actual = parse_accept(input)
+    assert actual == expected
+
+
+def test_remove_crlf():
+    illegal_string = "test:\r\nstring"
+    sanitized_string = "test: string"
+
+    assert sanitized_string == remove_crlf(illegal_string)
diff --git a/tox.ini b/tox.ini
index 3f8befb..ca02278 100644
--- a/tox.ini
+++ b/tox.ini
@@ -51,7 +51,7 @@ passenv =
     AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
     AWS_DEFAULT_REGION
 commands =
-    coverage run --rcfile .coveragerc --source sagemaker_pytorch_serving_container -m pytest {posargs}
+    coverage run --rcfile .coveragerc --source sagemaker_pytorch_serving_container,sagemaker_inference -m pytest {posargs}
     {env:IGNORE_COVERAGE:} coverage report --fail-under=90
 deps =