diff --git a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
index 660a362..aff91fa 100644
--- a/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
+++ b/src/sagemaker_huggingface_inference_toolkit/transformers_utils.py
@@ -20,8 +20,7 @@
 from huggingface_hub.file_download import cached_download, hf_hub_url
 from transformers import pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Pipeline
-
+from transformers.pipelines import Pipeline, Conversation
 
 if is_tf_available():
     import tensorflow as tf
@@ -90,6 +89,25 @@
 HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", None)
 
 
+def wrap_conversation_pipeline(pipeline):
+    def wrapped_pipeline(inputs, *args, **kwargs):
+        converted_input = Conversation(
+            inputs["text"],
+            past_user_inputs=inputs.get("past_user_inputs", []),
+            generated_responses=inputs.get("generated_responses", []),
+        )
+        prediction = pipeline(converted_input, *args, **kwargs)
+        return {
+            "generated_text": prediction.generated_responses[-1],
+            "conversation": {
+                "past_user_inputs": prediction.past_user_inputs,
+                "generated_responses": prediction.generated_responses,
+            },
+        }
+
+    return wrapped_pipeline
+
+
 def _is_gpu_available():
     """
     checks if a gpu is available.
@@ -233,4 +251,8 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
 
     hf_pipeline = pipeline(task=task, model=model_dir, tokenizer=model_dir, device=device, **kwargs)
 
+    # wrap specific pipelines to support a better UX
+    if task == "conversational":
+        hf_pipeline = wrap_conversation_pipeline(hf_pipeline)
+
     return hf_pipeline
diff --git a/tests/unit/test_transformers_utils.py b/tests/unit/test_transformers_utils.py
index 304b097..1d3e0d5 100644
--- a/tests/unit/test_transformers_utils.py
+++ b/tests/unit/test_transformers_utils.py
@@ -15,6 +15,7 @@
 import tempfile
 
 from transformers.file_utils import is_torch_available
+from transformers import pipeline
 from transformers.testing_utils import require_tf, require_torch, slow
 
 from sagemaker_huggingface_inference_toolkit.transformers_utils import (
@@ -28,6 +29,7 @@
     get_pipeline,
     infer_task_from_hub,
     infer_task_from_model_architecture,
+    wrap_conversation_pipeline,
 )
 
 
@@ -122,3 +124,37 @@ def test_infer_task_from_model_architecture():
         storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
         task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
         assert task == "token-classification"
+
+
+@require_torch
+def test_wrap_conversation_pipeline():
+    init_pipeline = pipeline(
+        "conversational",
+        model="microsoft/DialoGPT-small",
+        tokenizer="microsoft/DialoGPT-small",
+        framework="pt",
+    )
+    conv_pipe = wrap_conversation_pipeline(init_pipeline)
+    data = {
+        "past_user_inputs": ["Which movie is the best ?"],
+        "generated_responses": ["It's Die Hard for sure."],
+        "text": "Can you explain why?",
+    }
+    res = conv_pipe(data)
+    assert "conversation" in res
+    assert "generated_text" in res
+
+
+@require_torch
+def test_wrapped_pipeline():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_dir = _load_model_from_hub("microsoft/DialoGPT-small", tmpdirname)
+        conv_pipe = get_pipeline("conversational", -1, storage_dir)
+        data = {
+            "past_user_inputs": ["Which movie is the best ?"],
+            "generated_responses": ["It's Die Hard for sure."],
+            "text": "Can you explain why?",
+        }
+        res = conv_pipe(data)
+        assert "conversation" in res
+        assert "generated_text" in res
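For context, a minimal usage sketch of the wrapper this patch introduces (not part of the diff itself): it mirrors the request/response shape the new wrap_conversation_pipeline expects, and assumes transformers with the "conversational" task and the microsoft/DialoGPT-small checkpoint used in the tests are available locally.

# Sketch only: run outside SageMaker, model pulled from the Hub.
from transformers import pipeline

from sagemaker_huggingface_inference_toolkit.transformers_utils import wrap_conversation_pipeline

# Wrap a plain conversational pipeline so it consumes and produces
# JSON-friendly dicts instead of transformers Conversation objects.
conv_pipe = wrap_conversation_pipeline(
    pipeline("conversational", model="microsoft/DialoGPT-small", framework="pt")
)

request = {
    "past_user_inputs": ["Which movie is the best ?"],
    "generated_responses": ["It's Die Hard for sure."],
    "text": "Can you explain why?",
}
response = conv_pipe(request)
# response has the shape:
# {
#     "generated_text": "<latest model reply>",
#     "conversation": {
#         "past_user_inputs": [...],
#         "generated_responses": [...],
#     },
# }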