Description
I've trained a custom Hugging Face model based on microsoft/DialoGPT-small and uploaded it to the Hugging Face Hub. I then deployed the model to an endpoint with AWS SageMaker by running the following inside SageMaker Studio (adapted from this AWS blog post):
```python
!pip install "sagemaker" -q --upgrade

import sagemaker

sess = sagemaker.Session()
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
role = sagemaker.get_execution_role()

hub = {
    'HF_MODEL_ID': '[masked model name]',
    'HF_TASK': 'conversational',
}

huggingface_model = sagemaker.huggingface.HuggingFaceModel(
    transformers_version='4.6.1',
    pytorch_version='1.7.1',
    py_version='py36',
    role=role,
    env=hub,
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,     # number of instances
    instance_type='ml.m5.xlarge'  # ec2 instance type
)
print(predictor.endpoint_name)
```
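For reference, the same request can also be sent through the returned predictor object, which JSON-serializes a dict payload by default; a minimal sketch:

```python
# Sketch: the HuggingFacePredictor returned by deploy() uses JSON
# (de)serialization by default, so a plain dict can be passed directly.
result = predictor.predict({
    'inputs': {
        'past_user_inputs': ["Which movie is the best ?"],
        'generated_responses': ["It's Die Hard for sure."],
        'text': "Can you explain why ?",
    }
})
print(result)
```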
Issue:
The Hugging Face transformers ConversationalPipeline does not accept the standard JSON inputs when generating a prediction.
I tried invoking the endpoint with the following (the input payload is copied from the Hugging Face API documentation for conversational pipelines):
```python
import json

import boto3

boto3.client('sagemaker-runtime').invoke_endpoint(
    EndpointName='[model endpoint here]',
    Body=json.dumps({
        'inputs': {
            "past_user_inputs": ["Which movie is the best ?"],
            "generated_responses": ["It's Die Hard for sure."],
            "text": "Can you explain why ?",
        }
    }),
    ContentType='application/json'
)
```
This produces the following error in AWS (as seen in CloudWatch):
```
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 222, in handle
    response = self.transform_fn(self.model, input_data, content_type, accept)
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 181, in transform_fn
    predictions = self.predict(processed_data, model)
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 149, in predict
    prediction = model(inputs)
  File "/opt/conda/lib/python3.6/site-packages/transformers/pipelines/conversational.py", line 241, in __call__
    raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input")
ValueError: ConversationalPipeline expects a Conversation or list of Conversations as an input

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/mms/service.py", line 108, in predict
    ret = self._entry_point(input_batch, self.context)
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_huggingface_inference_toolkit/handler_service.py", line 231, in handle
    raise PredictionException(str(e), 400)
mms.service.PredictionException: ConversationalPipeline expects a Conversation or list of Conversations as an input : 400
```
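For comparison, the pipeline only accepts its inputs when they are wrapped in a Conversation object, which is exactly what the default handler fails to do. A minimal local reproduction (a sketch, assuming transformers 4.6.x to match the container; '[masked model name]' is the placeholder for the actual Hub model id):

```python
from transformers import pipeline
from transformers.pipelines import Conversation

chatbot = pipeline('conversational', model='[masked model name]')

conversation = Conversation(
    "Can you explain why ?",
    past_user_inputs=["Which movie is the best ?"],
    generated_responses=["It's Die Hard for sure."],
)
result = chatbot(conversation)  # works: the input is a Conversation object
print(result.generated_responses[-1])

# Passing the raw dict instead (what the toolkit does with the JSON body)
# raises the ValueError seen in the CloudWatch trace above.
```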
As a workaround, I attempted to add a custom inference.py script overriding the default predict_fn and output_fn as follows:
```python
import json
from typing import Any, Dict

from transformers.pipelines import Conversation, ConversationalPipeline


def predict_fn(data: Dict[str, Any], model: ConversationalPipeline) -> Conversation:
    inputs = data['inputs']
    conversation = Conversation(
        inputs['text'],
        past_user_inputs=inputs.get('past_user_inputs', []),
        generated_responses=inputs.get('generated_responses', []),
    )
    # The model returns the same Conversation object with a new generated
    # response appended to its generated_responses property.
    return model(conversation)


def output_fn(prediction: Conversation, accept: str) -> str:
    return json.dumps({
        'generated_text': prediction.generated_responses[-1],
        'conversation': {
            'past_user_inputs': prediction.past_user_inputs,
            'generated_responses': prediction.generated_responses,
        },
    })
```
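If these handlers were picked up, I'd expect the invoke_endpoint call above to return a body of this shape (a sketch derived from output_fn; the response text is hypothetical):

```python
# Hypothetical response body, assuming the custom output_fn actually runs:
{
    'generated_text': 'It has the best action scenes.',  # hypothetical model output
    'conversation': {
        'past_user_inputs': ["Which movie is the best ?", "Can you explain why ?"],
        'generated_responses': ["It's Die Hard for sure.", "It has the best action scenes."],
    },
}
```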
I followed the documentation for this library and added the inference.py inside code/ in the model repository:

```
|- pytorch_model.bin
|- ....
|- code/
   |- inference.py
   |- requirements.txt
```
But the error above still occurs. You can see from the stack trace that the custom functions I've added are ignored and never called by the HuggingFaceHandlerService.
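For completeness, the other documented deployment path packages the layout above into a model.tar.gz and passes it as model_data instead of HF_MODEL_ID. A hedged sketch (the S3 URI is a placeholder; whether custom scripts are only honoured on this path is exactly what is unclear to me):

```python
# Sketch, not verified: deploy from a packaged model.tar.gz (model weights
# plus code/inference.py) rather than pulling HF_MODEL_ID from the Hub.
huggingface_model = sagemaker.huggingface.HuggingFaceModel(
    model_data='s3://<bucket>/<prefix>/model.tar.gz',  # placeholder URI
    transformers_version='4.6.1',
    pytorch_version='1.7.1',
    py_version='py36',
    role=role,
)
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge',
)
```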