Add context to handler functions #103

Merged (7 commits) on Sep 30, 2023
82 changes: 64 additions & 18 deletions src/sagemaker_huggingface_inference_toolkit/handler_service.py
@@ -18,6 +18,7 @@
import sys
import time
from abc import ABC
from inspect import signature

from sagemaker_inference import environment, utils
from transformers.pipelines import SUPPORTED_TASKS
@@ -57,6 +58,11 @@ def __init__(self):
self.context = None
self.manifest = None
self.environment = environment.Environment()
self.load_extra_arg = []
self.preprocess_extra_arg = []
self.predict_extra_arg = []
self.postprocess_extra_arg = []
self.transform_extra_arg = []

def initialize(self, context):
"""
@@ -74,7 +80,7 @@ def initialize(self, context):
self.validate_and_initialize_user_module()

self.device = self.get_device()
self.model = self.load(self.model_dir)
self.model = self.load(*([self.model_dir] + self.load_extra_arg))
self.initialized = True
# # Load methods from file
# if (not self._initialized) and ENABLE_MULTI_MODEL:
@@ -92,10 +98,15 @@ def get_device(self):
else:
return -1

def load(self, model_dir):
def load(self, model_dir, context=None):
"""
The Load handler is responsible for loading the Hugging Face transformer model.
It can be overridden to load the model from storage
It can be overridden to load the model from storage.

Args:
model_dir (str): The directory where model files are stored.
context (obj): metadata on the incoming request data (default: None).

Returns:
hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.
"""
@@ -111,14 +122,16 @@
)
return hf_pipeline

def preprocess(self, input_data, content_type):
def preprocess(self, input_data, content_type, context=None):
"""
The preprocess handler is responsible for deserializing the input data into
an object for prediction and can handle JSON.
The preprocess handler can be overridden for data or feature transformation,
The preprocess handler can be overridden for data or feature transformation.

Args:
input_data: the request payload serialized in the content_type format
content_type: the request content_type
input_data: the request payload serialized in the content_type format.
content_type: the request content_type.
context (obj): metadata on the incoming request data (default: None).

Returns:
decoded_input_data (dict): deserialized input_data as a Python dictionary.
@@ -136,13 +149,16 @@
decoded_input_data = decoder_encoder.decode(input_data, content_type)
return decoded_input_data

def predict(self, data, model):
def predict(self, data, model, context=None):
"""The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
on decoded_input_data deserialized in input_fn. Runs prediction on GPU if one is available.
The predict handler can be overridden to implement the model inference.

Args:
data (dict): deserialized decoded_input_data returned by the input_fn
model : Model returned by the `load` method or if it is a custom module `model_fn`.
context (obj): metadata on the incoming request data (default: None).

Returns:
obj (dict): prediction result.
"""
@@ -158,38 +174,42 @@
prediction = model(inputs)
return prediction

def postprocess(self, prediction, accept):
def postprocess(self, prediction, accept, context=None):
"""
The postprocess handler is responsible for serializing the prediction result to
the desired accept type and can handle JSON.
The postprocess handler can be overridden for inference response transformation
The postprocess handler can be overridden for inference response transformation.

Args:
prediction (dict): a prediction result from predict
accept (str): type which the output data needs to be serialized
prediction (dict): a prediction result from predict.
accept (str): type which the output data needs to be serialized.
context (obj): metadata on the incoming request data (default: None).
Returns: the output data serialized in the accept format
"""
return decoder_encoder.encode(prediction, accept)

def transform_fn(self, model, input_data, content_type, accept):
def transform_fn(self, model, input_data, content_type, accept, context=None):
"""
Transform function ("transform_fn") can be used to write one function with pre/post-processing steps and predict step in it.
This function can't be mixed with "input_fn", "output_fn" or "predict_fn"
This function can't be mixed with "input_fn", "output_fn" or "predict_fn".

Args:
model: Model returned by the model_fn above
input_data: Data received for inference
content_type: The content type of the inference data
accept: The response accept type.
context (obj): metadata on the incoming request data (default: None).

Returns: Response in the "accept" format type.

"""
# run pipeline
start_time = time.time()
processed_data = self.preprocess(input_data, content_type)
processed_data = self.preprocess(*([input_data, content_type] + self.preprocess_extra_arg))
preprocess_time = time.time() - start_time
predictions = self.predict(processed_data, model)
predictions = self.predict(*([processed_data, model] + self.predict_extra_arg))
predict_time = time.time() - preprocess_time - start_time
response = self.postprocess(predictions, accept)
response = self.postprocess(*([predictions, accept] + self.postprocess_extra_arg))
postprocess_time = time.time() - predict_time - preprocess_time - start_time

logger.info(
@@ -231,7 +251,7 @@ def handle(self, data, context):
input_data = input_data.decode("utf-8")

predict_start = time.time()
response = self.transform_fn(self.model, input_data, content_type, accept)
response = self.transform_fn(*([self.model, input_data, content_type, accept] + self.transform_extra_arg))
predict_end = time.time()

context.metrics.add_time("Transform Fn", round((predict_end - predict_start) * 1000, 2))
@@ -263,12 +283,38 @@ def validate_and_initialize_user_module(self):
)

if load_fn is not None:
self.load_extra_arg = self.function_extra_arg(self.load, load_fn)
self.load = load_fn
if preprocess_fn is not None:
self.preprocess_extra_arg = self.function_extra_arg(self.preprocess, preprocess_fn)
self.preprocess = preprocess_fn
if predict_fn is not None:
self.predict_extra_arg = self.function_extra_arg(self.predict, predict_fn)
self.predict = predict_fn
if postprocess_fn is not None:
self.postprocess_extra_arg = self.function_extra_arg(self.postprocess, postprocess_fn)
self.postprocess = postprocess_fn
if transform_fn is not None:
self.transform_extra_arg = self.function_extra_arg(self.transform_fn, transform_fn)
self.transform_fn = transform_fn

def function_extra_arg(self, default_func, func):
"""Helper to call the handler function which covers 2 cases:
1. the handle function takes context
2. the handle function does not take context
"""
num_default_func_input = len(signature(default_func).parameters)
num_func_input = len(signature(func).parameters)
if num_default_func_input == num_func_input:
# function takes context
extra_args = [self.context]
elif num_default_func_input == num_func_input + 1:
# function does not take context
extra_args = []
else:
raise TypeError(
"{} definition takes {} or {} arguments but {} were given.".format(
func.__name__, num_default_func_input - 1, num_default_func_input, num_func_input
)
)
return extra_args
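For readers skimming the diff, the dispatch rule implemented by function_extra_arg can be reproduced in a few lines. The sketch below is illustrative only (the toy handlers and the extra_args name are not part of the toolkit): a user handler with the same parameter count as the default handler receives the context, while a handler with one fewer parameter is called without it.

from inspect import signature


def default_predict(data, model, context=None):
    # Stand-in for the toolkit's built-in predict handler (3 parameters).
    return model(data)


def user_predict_new(data, model, context):
    # New-style user handler: same parameter count, so context is forwarded.
    return model(data)


def user_predict_legacy(data, model):
    # Legacy user handler: one fewer parameter, so it is called without context.
    return model(data)


def extra_args(default_func, user_func, context):
    # Mirrors the parameter-count comparison done in function_extra_arg above.
    n_default = len(signature(default_func).parameters)
    n_user = len(signature(user_func).parameters)
    if n_default == n_user:
        return [context]  # user handler accepts the context argument
    if n_default == n_user + 1:
        return []  # user handler keeps the legacy signature
    raise TypeError("handler takes an unexpected number of arguments")


ctx = object()  # stand-in for the MMS context object
print(extra_args(default_predict, user_predict_new, ctx))     # [ctx]
print(extra_args(default_predict, user_predict_legacy, ctx))  # []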
@@ -0,0 +1,14 @@
def model_fn(model_dir, context=None):
return "model"


def input_fn(data, content_type, context=None):
return "data"


def predict_fn(data, model, context=None):
return "output"


def output_fn(prediction, accept, context=None):
return prediction
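By contrast, the legacy resource module that the renamed path in the updated tests further down points to (model_input_predict_output_fn_without_context) is not shown in this diff; assuming it keeps the original two-argument handler signatures, it would look roughly like this:

def model_fn(model_dir):
    return "model"


def input_fn(data, content_type):
    return "data"


def predict_fn(data, model):
    return "output"


def output_fn(prediction, accept):
    return prediction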
@@ -0,0 +1,9 @@
import os


def model_fn(model_dir, context=None):
return f"Loading {os.path.basename(__file__)}"


def transform_fn(a, b, c, d, context=None):
return f"output {b}"
168 changes: 168 additions & 0 deletions tests/unit/test_handler_service_with_context.py
@@ -0,0 +1,168 @@
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile

import pytest
from sagemaker_inference import content_types
from transformers.testing_utils import require_torch, slow

from mms.context import Context, RequestProcessor
from mms.metrics.metrics_store import MetricsStore
from mock import Mock
from sagemaker_huggingface_inference_toolkit import handler_service
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline


TASK = "text-classification"
MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
INPUT = {"inputs": "My name is Wolfgang and I live in Berlin"}
OUTPUT = [
{"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19},
{"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40},
]


@pytest.fixture()
def inference_handler():
return handler_service.HuggingFaceHandlerService()


def test_get_device_cpu(inference_handler):
device = inference_handler.get_device()
assert device == -1


@slow
def test_get_device_gpu(inference_handler):
device = inference_handler.get_device()
assert device > -1


@require_torch
def test_test_initialize(inference_handler):
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
assert inference_handler.initialized is True
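The bare positional arguments passed to Context recur throughout these tests. Assuming the mms.context.Context constructor these tests target takes (model name, model directory, manifest, batch size, GPU id, MMS version), a small hypothetical helper makes the intent explicit:

from mms.context import Context


def make_cpu_test_context(model_name, model_dir):
    # Hypothetical helper mirroring how these tests build a CPU-only context (gpu=-1).
    return Context(
        model_name,  # model_name
        model_dir,   # model_dir
        {},          # manifest (empty for unit tests)
        1,           # batch_size
        -1,          # gpu id; -1 selects CPU
        "1.1.4",     # mms_version
    )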


@require_torch
def test_handle(inference_handler):
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
CONTEXT = Context(MODEL, storage_folder, {}, 1, -1, "1.1.4")
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)

inference_handler.initialize(CONTEXT)
json_data = json.dumps(INPUT)
prediction = inference_handler.handle([{"body": json_data.encode()}], CONTEXT)
loaded_response = json.loads(prediction[0])
assert "entity" in loaded_response[0]
assert "score" in loaded_response[0]


@require_torch
def test_load(inference_handler):
context = Mock()
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
# test with automatic infer
hf_pipeline_without_task = inference_handler.load(storage_folder, context)
assert hf_pipeline_without_task.task == "token-classification"

# test with task explicitly set via the HF_TASK environment variable
os.environ["HF_TASK"] = TASK
hf_pipeline_with_task = inference_handler.load(storage_folder, context)
assert hf_pipeline_with_task.task == TASK


def test_preprocess(inference_handler):
context = Mock()
json_data = json.dumps(INPUT)
decoded_input_data = inference_handler.preprocess(json_data, content_types.JSON, context)
assert "inputs" in decoded_input_data


def test_preprocess_bad_content_type(inference_handler):
context = Mock()
with pytest.raises(json.decoder.JSONDecodeError):
inference_handler.preprocess(b"", content_types.JSON, context)


@require_torch
def test_predict(inference_handler):
context = Mock()
with tempfile.TemporaryDirectory() as tmpdirname:
storage_folder = _load_model_from_hub(
model_id=MODEL,
model_dir=tmpdirname,
)
inference_handler.model = get_pipeline(task=TASK, device=-1, model_dir=storage_folder)
prediction = inference_handler.predict(INPUT, inference_handler.model, context)
assert "label" in prediction[0]
assert "score" in prediction[0]


def test_postprocess(inference_handler):
context = Mock()
output = inference_handler.postprocess(OUTPUT, content_types.JSON, context)
assert isinstance(output, str)


def test_validate_and_initialize_user_module(inference_handler):
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_with_context")
CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)

prediction = inference_handler.handle([{"body": b""}], CONTEXT)
assert "output" in prediction[0]

assert inference_handler.load({}, CONTEXT) == "model"
assert inference_handler.preprocess({}, "", CONTEXT) == "data"
assert inference_handler.predict({}, "model", CONTEXT) == "output"
assert inference_handler.postprocess("output", "", CONTEXT) == "output"


def test_validate_and_initialize_user_module_transform_fn():
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
inference_handler = handler_service.HuggingFaceHandlerService()
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_with_context")
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
CONTEXT.request_processor = [RequestProcessor({"Content-Type": "application/json"})]
CONTEXT.metrics = MetricsStore(1, MODEL)
assert "output" in inference_handler.handle([{"body": b"dummy"}], CONTEXT)[0]
assert inference_handler.load({}, CONTEXT) == "Loading inference_tranform_fn.py"
assert (
inference_handler.transform_fn("model", "dummy", "application/json", "application/json", CONTEXT)
== "output dummy"
)
@@ -129,7 +129,7 @@ def test_postprocess(inference_handler):


def test_validate_and_initialize_user_module(inference_handler):
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn")
model_dir = os.path.join(os.getcwd(), "tests/resources/model_input_predict_output_fn_without_context")
CONTEXT = Context("", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)
@@ -148,7 +148,7 @@ def test_validate_and_initialize_user_module(self):
def test_validate_and_initialize_user_module_transform_fn():
os.environ["SAGEMAKER_PROGRAM"] = "inference_tranform_fn.py"
inference_handler = handler_service.HuggingFaceHandlerService()
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn")
model_dir = os.path.join(os.getcwd(), "tests/resources/model_transform_fn_without_context")
CONTEXT = Context("dummy", model_dir, {}, 1, -1, "1.1.4")

inference_handler.initialize(CONTEXT)