Skip to content

Commit 0524290

Browse files
committed
Added tests
1 parent be4efa2 commit 0524290

File tree

4 files changed

+211
-35
lines changed

4 files changed

+211
-35
lines changed

azure/functions/decorators/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
ENTITY_TRIGGER = "entityTrigger"
3535
DURABLE_CLIENT = "durableClient"
3636
ASSISTANT_SKILLS_TRIGGER = "assistantSkillsTrigger"
37-
TEXT_COMPLETION = "TextCompletion"
37+
TEXT_COMPLETION = "textCompletion"
3838
ASSISTANT_QUERY = "assistantQuery"
3939
EMBEDDINGS = "embeddings"
4040
EMBEDDINGS_STORE = "embeddingsStore"

azure/functions/decorators/function_app.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@
3535
parse_iterable_param_to_enums, StringifyEnumJsonEncoder
3636
from azure.functions.http import HttpRequest
3737
from .generic import GenericInputBinding, GenericTrigger, GenericOutputBinding
38-
from .openai import AssistantSkillTrigger, OpenAIModels, TextCompletionInput, AssistantCreateOutput, \
39-
AssistantQueryInput, AssistantPostInput, InputType, EmbeddingsInput, semantic_search_system_prompt, \
38+
from .openai import AssistantSkillTrigger, OpenAIModels, TextCompletionInput, \
39+
AssistantCreateOutput, \
40+
AssistantQueryInput, AssistantPostInput, InputType, EmbeddingsInput, \
41+
semantic_search_system_prompt, \
4042
SemanticSearchInput, EmbeddingsStoreOutput
4143
from .retry_policy import RetryPolicy
4244
from .function_name import FunctionName
@@ -297,7 +299,9 @@ def decorator():
297299
self._function_builders.pop()
298300
self._function_builders.append(function_builder)
299301
return function_builder
302+
300303
return decorator()
304+
301305
return wrap
302306

303307
def _get_durable_blueprint(self):
@@ -310,9 +314,10 @@ def _get_durable_blueprint(self):
310314
df_bp = df.Blueprint()
311315
return df_bp
312316
except ImportError:
313-
error_message = "Attempted to use a Durable Functions decorator, "\
314-
"but the `azure-functions-durable` SDK package could not be "\
315-
"found. Please install `azure-functions-durable` to use "\
317+
error_message = \
318+
"Attempted to use a Durable Functions decorator, " \
319+
"but the `azure-functions-durable` SDK package could not be " \
320+
"found. Please install `azure-functions-durable` to use " \
316321
"Durable Functions."
317322
raise Exception(error_message)
318323

@@ -2772,7 +2777,8 @@ def decorator():
27722777
def text_completion_input(self,
27732778
arg_name: str,
27742779
prompt: str,
2775-
model: Optional[str] = OpenAIModels.DefaultChatModel,
2780+
model: Optional[
2781+
OpenAIModels] = OpenAIModels.DefaultChatModel, # NoQA
27762782
temperature: Optional[str] = "0.5",
27772783
top_p: Optional[str] = None,
27782784
max_tokens: Optional[str] = "100",
@@ -2783,6 +2789,7 @@ def text_completion_input(self,
27832789
"""
27842790
TODO: pydocs
27852791
"""
2792+
27862793
@self._configure_function_builder
27872794
def wrap(fb):
27882795
def decorator():
@@ -2811,6 +2818,7 @@ def assistant_create_output(self, arg_name: str,
28112818
"""
28122819
TODO: pydocs
28132820
"""
2821+
28142822
@self._configure_function_builder
28152823
def wrap(fb):
28162824
def decorator():
@@ -2837,6 +2845,7 @@ def assistant_query_input(self,
28372845
"""
28382846
TODO: pydocs
28392847
"""
2848+
28402849
@self._configure_function_builder
28412850
def wrap(fb):
28422851
def decorator():
@@ -2922,9 +2931,12 @@ def semantic_search_input(self,
29222931
connection_name: str,
29232932
collection: str,
29242933
query: Optional[str] = None,
2925-
embeddings_model: Optional[str] = OpenAIModels.DefaultEmbeddingsModel,
2926-
chat_model: Optional[str] = OpenAIModels.DefaultChatModel,
2927-
system_prompt: Optional[str] = semantic_search_system_prompt,
2934+
embeddings_model: Optional[
2935+
OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, # NoQA
2936+
chat_model: Optional[
2937+
OpenAIModels] = OpenAIModels.DefaultChatModel, # NoQA
2938+
system_prompt: Optional[
2939+
str] = semantic_search_system_prompt,
29282940
max_knowledge_count: Optional[int] = 1,
29292941
data_type: Optional[
29302942
Union[DataType, str]] = None,
@@ -2962,7 +2974,8 @@ def embeddings_store_output(self,
29622974
input_type: InputType,
29632975
connection_name: str,
29642976
collection: str,
2965-
model: Optional[str] = OpenAIModels.DefaultEmbeddingsModel,
2977+
model: Optional[
2978+
OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, # NoQA
29662979
max_chunk_length: Optional[int] = 8 * 1024,
29672980
max_overlap: Optional[int] = 128,
29682981
data_type: Optional[

azure/functions/decorators/openai.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from typing import Optional
22

3-
from azure.functions.decorators.constants import (ASSISTANT_SKILLS_TRIGGER, TEXT_COMPLETION, ASSISTANT_QUERY,
4-
EMBEDDINGS, EMBEDDINGS_STORE, ASSISTANT_CREATE, ASSISTANT_POST,
3+
from azure.functions.decorators.constants import (ASSISTANT_SKILLS_TRIGGER,
4+
TEXT_COMPLETION,
5+
ASSISTANT_QUERY,
6+
EMBEDDINGS, EMBEDDINGS_STORE,
7+
ASSISTANT_CREATE,
8+
ASSISTANT_POST,
59
SEMANTIC_SEARCH)
6-
from azure.functions.decorators.core import Trigger, DataType, InputBinding, OutputBinding
10+
from azure.functions.decorators.core import Trigger, DataType, InputBinding, \
11+
OutputBinding
712
from azure.functions.decorators.utils import StringifyEnum
813

914

@@ -15,7 +20,7 @@ class InputType(StringifyEnum):
1520

1621
class OpenAIModels(StringifyEnum):
1722
DefaultChatModel = "gpt-3.5-turbo"
18-
DefaultEmbeddingsModel = "text-embedding-3-small"
23+
DefaultEmbeddingsModel = "text-embedding-ada-002"
1924

2025

2126
class AssistantSkillTrigger(Trigger):
@@ -42,7 +47,7 @@ def get_binding_name() -> str:
4247
def __init__(self,
4348
name: str,
4449
prompt: str,
45-
model: Optional[str] = OpenAIModels.DefaultChatModel,
50+
model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel,
4651
temperature: Optional[str] = "0.5",
4752
top_p: Optional[str] = None,
4853
max_tokens: Optional[str] = "100",
@@ -85,7 +90,7 @@ def __init__(self,
8590
input_type: InputType,
8691
model: Optional[str] = None,
8792
max_chunk_length: Optional[int] = 8 * 1024,
88-
max_overlap : Optional[int] = 128,
93+
max_overlap: Optional[int] = 128,
8994
data_type: Optional[DataType] = None,
9095
**kwargs):
9196
self.name = name
@@ -97,21 +102,21 @@ def __init__(self,
97102
super().__init__(name=name, data_type=data_type)
98103

99104

100-
semantic_search_system_prompt = \
101-
"""You are a helpful assistant. You are responding to requests
102-
from a user about internal emails and documents. You can and
103-
should refer to the internal documents to help respond to
104-
requests. If a user makes a request that's not covered by the
105-
internal emails and documents, explain that you don't know the
106-
answer or that you don't have access to the information.
105+
semantic_search_system_prompt = \
106+
"""You are a helpful assistant. You are responding to requests
107+
from a user about internal emails and documents. You can and
108+
should refer to the internal documents to help respond to
109+
requests. If a user makes a request that's not covered by the
110+
internal emails and documents, explain that you don't know the
111+
answer or that you don't have access to the information.
107112
108-
The following is a list of documents that you can refer to when
109-
answering questions. The documents are in the format
110-
[filename]: [text] and are separated by newlines. If you answer
111-
a question by referencing any of the documents, please cite the
112-
document in your answer. For example, if you answer a question
113-
by referencing info.txt, you should add "Reference: info.txt"
114-
to the end of your answer on a separate line."""
113+
The following is a list of documents that you can refer to when
114+
answering questions. The documents are in the format
115+
[filename]: [text] and are separated by newlines. If you answer
116+
a question by referencing any of the documents, please cite the
117+
document in your answer. For example, if you answer a question
118+
by referencing info.txt, you should add "Reference: info.txt"
119+
to the end of your answer on a separate line."""
115120

116121

117122
class SemanticSearchInput(InputBinding):
@@ -125,10 +130,12 @@ def __init__(self,
125130
connection_name: str,
126131
collection: str,
127132
query: Optional[str] = None,
128-
embeddings_model: Optional[str] = OpenAIModels.DefaultEmbeddingsModel,
129-
chat_model: Optional[str] = OpenAIModels.DefaultChatModel,
133+
embeddings_model: Optional[
134+
OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel,
135+
chat_model: Optional[
136+
OpenAIModels] = OpenAIModels.DefaultChatModel,
130137
system_prompt: Optional[str] = semantic_search_system_prompt,
131-
max_knowledge_count : Optional[int] = 1,
138+
max_knowledge_count: Optional[int] = 1,
132139
data_type: Optional[DataType] = None,
133140
**kwargs):
134141
self.name = name
@@ -172,7 +179,8 @@ def __init__(self,
172179
input_type: InputType,
173180
connection_name: str,
174181
collection: str,
175-
model: Optional[str] = OpenAIModels.DefaultEmbeddingsModel,
182+
model: Optional[
183+
OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel,
176184
max_chunk_length: Optional[int] = 8 * 1024,
177185
max_overlap: Optional[int] = 128,
178186
data_type: Optional[DataType] = None,

tests/decorators/test_openai.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import unittest
2+
3+
from azure.functions import DataType
4+
from azure.functions.decorators.core import BindingDirection
5+
from azure.functions.decorators.openai import AssistantSkillTrigger, \
6+
TextCompletionInput, OpenAIModels, AssistantQueryInput, EmbeddingsInput, \
7+
AssistantCreateOutput, SemanticSearchInput, EmbeddingsStoreOutput
8+
9+
10+
class TestOpenAI(unittest.TestCase):
11+
12+
def test_assistant_skills_trigger_valid_creation(self):
13+
trigger = AssistantSkillTrigger(name="test",
14+
task_description="test_description",
15+
data_type=DataType.UNDEFINED,
16+
dummy_field="dummy")
17+
self.assertEqual(trigger.get_binding_name(),
18+
"assistantSkillsTrigger")
19+
self.assertEqual(
20+
trigger.get_dict_repr(), {"name": "test",
21+
"taskDescription": "test_description",
22+
"dataType": DataType.UNDEFINED,
23+
'type': 'assistantSkillsTrigger',
24+
'dummyField': 'dummy',
25+
"direction": BindingDirection.IN,
26+
})
27+
28+
def test_text_completion_input_valid_creation(self):
29+
input = TextCompletionInput(name="test",
30+
prompt="test_prompt",
31+
temperature="1",
32+
max_tokens="1",
33+
data_type=DataType.UNDEFINED,
34+
model=OpenAIModels.DefaultChatModel,
35+
dummy_field="dummy")
36+
self.assertEqual(input.get_binding_name(),
37+
"textCompletion")
38+
self.assertEqual(input.get_dict_repr(),
39+
{"name": "test",
40+
"temperature": "1",
41+
"maxTokens": "1",
42+
'type': 'textCompletion',
43+
"dataType": DataType.UNDEFINED,
44+
"dummyField": "dummy",
45+
"prompt": "test_prompt",
46+
"direction": BindingDirection.IN,
47+
"model": OpenAIModels.DefaultChatModel
48+
})
49+
50+
def test_assistant_query_input_valid_creation(self):
51+
input = AssistantQueryInput(name="test",
52+
timestamp_utc="timestamp_utc",
53+
data_type=DataType.UNDEFINED,
54+
id="test_id",
55+
type="assistantQueryInput",
56+
dummy_field="dummy")
57+
self.assertEqual(input.get_binding_name(),
58+
"assistantQuery")
59+
self.assertEqual(input.get_dict_repr(),
60+
{"name": "test",
61+
"timestampUtc": "timestamp_utc",
62+
"dataType": DataType.UNDEFINED,
63+
"direction": BindingDirection.IN,
64+
"type": "assistantQuery",
65+
"id": "test_id",
66+
"dummyField": "dummy"
67+
})
68+
69+
def test_embeddings_input_valid_creation(self):
70+
input = EmbeddingsInput(name="test",
71+
data_type=DataType.UNDEFINED,
72+
input="test_input",
73+
input_type="test_input_type",
74+
model="test_model",
75+
max_overlap=1,
76+
max_chunk_length=1,
77+
dummy_field="dummy")
78+
self.assertEqual(input.get_binding_name(),
79+
"embeddings")
80+
self.assertEqual(input.get_dict_repr(),
81+
{"name": "test",
82+
"type": "embeddings",
83+
"dataType": DataType.UNDEFINED,
84+
"input": "test_input",
85+
"inputType": "test_input_type",
86+
"model": "test_model",
87+
"maxOverlap": 1,
88+
"maxChunkLength": 1,
89+
"direction": BindingDirection.IN,
90+
"dummyField": "dummy"})
91+
92+
def test_assistant_create_output_valid_creation(self):
93+
output = AssistantCreateOutput(name="test",
94+
data_type=DataType.UNDEFINED)
95+
self.assertEqual(output.get_binding_name(),
96+
"assistantCreate")
97+
self.assertEqual(output.get_dict_repr(),
98+
{"name": "test",
99+
"dataType": DataType.UNDEFINED,
100+
"direction": BindingDirection.OUT,
101+
"type": "assistantCreate"})
102+
103+
def test_semantic_search_input_valid_creation(self):
104+
input = SemanticSearchInput(name="test",
105+
data_type=DataType.UNDEFINED,
106+
chat_model=OpenAIModels.DefaultChatModel,
107+
embeddings_model=OpenAIModels.DefaultEmbeddingsModel, # NoQA
108+
collection="test_collection",
109+
connection_name="test_connection",
110+
system_prompt="test_prompt",
111+
query="test_query",
112+
max_knowledge_count=1,
113+
dummy_field="dummy_field")
114+
self.assertEqual(input.get_binding_name(),
115+
"semanticSearch")
116+
self.assertEqual(input.get_dict_repr(),
117+
{"name": "test",
118+
"dataType": DataType.UNDEFINED,
119+
"direction": BindingDirection.IN,
120+
"dummyField": "dummy_field",
121+
"chatModel": OpenAIModels.DefaultChatModel,
122+
"embeddingsModel": OpenAIModels.DefaultEmbeddingsModel, # NoQA
123+
"type": "semanticSearch",
124+
"collection": "test_collection",
125+
"connectionName": "test_connection",
126+
"systemPrompt": "test_prompt",
127+
"maxKnowledgeCount": 1,
128+
"query": "test_query"})
129+
130+
def test_embeddings_store_output_valid_creation(self):
131+
output = EmbeddingsStoreOutput(name="test",
132+
data_type=DataType.UNDEFINED,
133+
input="test_input",
134+
input_type="test_input_type",
135+
connection_name="test_connection",
136+
max_overlap=1,
137+
max_chunk_length=1,
138+
collection="test_collection",
139+
model=OpenAIModels.DefaultChatModel,
140+
dummy_field="dummy_field")
141+
self.assertEqual(output.get_binding_name(),
142+
"embeddingsStore")
143+
self.assertEqual(output.get_dict_repr(),
144+
{"name": "test",
145+
"dataType": DataType.UNDEFINED,
146+
"direction": BindingDirection.OUT,
147+
"dummyField": "dummy_field",
148+
"input": "test_input",
149+
"inputType": "test_input_type",
150+
"collection": "test_collection",
151+
"model": OpenAIModels.DefaultChatModel,
152+
"connectionName": "test_connection",
153+
"maxOverlap": 1,
154+
"maxChunkLength": 1,
155+
"type": "embeddingsStore"})

0 commit comments

Comments
 (0)