
Commit 080f05d

committed
added conversational pipeline wrapper for better UX
1 parent da070a2 commit 080f05d

File tree

2 files changed: +61 / -2 lines


src/sagemaker_huggingface_inference_toolkit/transformers_utils.py

Lines changed: 24 additions & 2 deletions
@@ -20,8 +20,7 @@
 from huggingface_hub.file_download import cached_download, hf_hub_url
 from transformers import pipeline
 from transformers.file_utils import is_tf_available, is_torch_available
-from transformers.pipelines import Pipeline
-
+from transformers.pipelines import Pipeline, Conversation

 if is_tf_available():
     import tensorflow as tf
@@ -90,6 +89,25 @@
 HF_MODEL_REVISION = os.environ.get("HF_MODEL_REVISION", None)


+def wrap_conversation_pipeline(pipeline):
+    def wrapped_pipeline(inputs, *args, **kwargs):
+        converted_input = Conversation(
+            inputs["text"],
+            past_user_inputs=inputs.get("past_user_inputs", []),
+            generated_responses=inputs.get("generated_responses", []),
+        )
+        prediction = pipeline(converted_input, *args, **kwargs)
+        return {
+            "generated_text": prediction.generated_responses[-1],
+            "conversation": {
+                "past_user_inputs": prediction.past_user_inputs,
+                "generated_responses": prediction.generated_responses,
+            },
+        }
+
+    return wrapped_pipeline
+
+
 def _is_gpu_available():
     """
     checks if a gpu is available.
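
For orientation, a minimal sketch of what the new wrapper does at call time. The model name and payload values below are the ones used in the unit tests; constructing the pipeline directly like this is purely illustrative, since the toolkit normally builds it through get_pipeline:

from transformers import pipeline

from sagemaker_huggingface_inference_toolkit.transformers_utils import wrap_conversation_pipeline

# build a stock conversational pipeline and wrap it (model choice is illustrative)
conv_pipe = wrap_conversation_pipeline(
    pipeline("conversational", model="microsoft/DialoGPT-small")
)

# plain JSON-style dict in ...
payload = {
    "text": "Can you explain why?",
    "past_user_inputs": ["Which movie is the best ?"],
    "generated_responses": ["It's Die Hard for sure."],
}

# ... plain JSON-style dict out, instead of a transformers Conversation object
result = conv_pipe(payload)
print(result["generated_text"])
print(result["conversation"]["generated_responses"])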
@@ -233,4 +251,8 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:

     hf_pipeline = pipeline(task=task, model=model_dir, tokenizer=model_dir, device=device, **kwargs)

+    # wrap specific pipeline to support better UX
+    if task == "conversational":
+        hf_pipeline = wrap_conversation_pipeline(hf_pipeline)
+
     return hf_pipeline
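
With the wrapping hooked into get_pipeline, a conversational model loaded through the toolkit exchanges plain dicts end to end. A hedged sketch, assuming a CPU device and a model directory that already holds a downloaded conversational model (the /opt/ml/model path is illustrative):

from sagemaker_huggingface_inference_toolkit.transformers_utils import get_pipeline

# device=-1 runs on CPU; model_dir must contain the conversational model artifacts
conv_pipe = get_pipeline(task="conversational", device=-1, model_dir="/opt/ml/model")
response = conv_pipe({
    "text": "Can you explain why?",
    "past_user_inputs": ["Which movie is the best ?"],
    "generated_responses": ["It's Die Hard for sure."],
})
# response is a plain dict: {"generated_text": ..., "conversation": {"past_user_inputs": [...], "generated_responses": [...]}}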

tests/unit/test_transformers_utils.py

Lines changed: 37 additions & 0 deletions
@@ -15,6 +15,7 @@
 import tempfile

 from transformers.file_utils import is_torch_available
+from transformers import pipeline
 from transformers.testing_utils import require_tf, require_torch, slow

 from sagemaker_huggingface_inference_toolkit.transformers_utils import (
@@ -28,6 +29,7 @@
     get_pipeline,
     infer_task_from_hub,
     infer_task_from_model_architecture,
+    wrap_conversation_pipeline,
 )


@@ -122,3 +124,38 @@ def test_infer_task_from_model_architecture():
         storage_dir = _load_model_from_hub(TASK_MODEL, tmpdirname)
         task = infer_task_from_model_architecture(f"{storage_dir}/config.json")
         assert task == "token-classification"
+
+
+@require_torch
+def test_wrap_conversation_pipeline():
+    init_pipeline = pipeline(
+        "conversational",
+        model="microsoft/DialoGPT-small",
+        tokenizer="microsoft/DialoGPT-small",
+        framework="pt",
+        device=0,
+    )
+    conv_pipe = wrap_conversation_pipeline(init_pipeline)
+    data = {
+        "past_user_inputs": ["Which movie is the best ?"],
+        "generated_responses": ["It's Die Hard for sure."],
+        "text": "Can you explain why?",
+    }
+    res = conv_pipe(data)
+    assert "conversation" in res
+    assert "generated_text" in res
+
+
+@require_torch
+def test_wrapped_pipeline():
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        storage_dir = _load_model_from_hub("microsoft/DialoGPT-small", tmpdirname)
+        conv_pipe = get_pipeline("conversational", -1, storage_dir)
+        data = {
+            "past_user_inputs": ["Which movie is the best ?"],
+            "generated_responses": ["It's Die Hard for sure."],
+            "text": "Can you explain why?",
+        }
+        res = conv_pipe(data)
+        assert "conversation" in res
+        assert "generated_text" in res
