diff --git a/azure/functions/decorators/function_app.py b/azure/functions/decorators/function_app.py index f0feb470..488f997f 100644 --- a/azure/functions/decorators/function_app.py +++ b/azure/functions/decorators/function_app.py @@ -1694,7 +1694,6 @@ def assistant_skill_trigger(self, function_description: str, function_name: Optional[str] = None, parameter_description_json: Optional[str] = None, # NoQA - model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel, # NoQA data_type: Optional[ Union[DataType, str]] = None, **kwargs: Any) -> Callable[..., Any]: @@ -1723,7 +1722,6 @@ def assistant_skill_trigger(self, :param parameter_description_json: A JSON description of the function parameter, which is provided to the LLM. If no description is provided, the description will be autogenerated. - :param model: The OpenAI chat model to use. :param data_type: Defines how Functions runtime should treat the parameter value. :param kwargs: Keyword arguments for specifying additional binding @@ -1741,7 +1739,6 @@ def decorator(): function_description=function_description, function_name=function_name, parameter_description_json=parameter_description_json, - model=model, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3220,10 +3217,13 @@ def decorator(): def text_completion_input(self, arg_name: str, prompt: str, - model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel, # NoQA + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, temperature: Optional[str] = "0.5", top_p: Optional[str] = None, max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[Union[DataType, str]] = None, **kwargs) \ -> Callable[..., Any]: @@ -3243,7 +3243,10 @@ def text_completion_input(self, :param arg_name: The name of binding parameter in the function code. :param prompt: The prompt to generate completions for, encoded as a string. - :param model: the ID of the model to use. 
+ :param model: @deprecated. Use chat_model instead. The model parameter + is unused and will be removed in future versions. + :param chat_model: The deployment name or model name of OpenAI Chat + Completion API. The default value is "gpt-3.5-turbo". :param temperature: The sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. @@ -3255,7 +3258,10 @@ def text_completion_input(self, :param max_tokens: The maximum number of tokens to generate in the completion. The token count of your prompt plus max_tokens cannot exceed the model's context length. Most models have a context length of - 2048 tokens (except for the newest models, which support 4096). + 2048 tokens (except for the newest models, which support 4096) + :param is_reasoning_model: Whether the configured chat completion model + is a reasoning model or not. Properties max_tokens and temperature are not + supported for reasoning models. 
:param data_type: Defines how Functions runtime should treat the parameter value :param kwargs: Keyword arguments for specifying additional binding @@ -3271,10 +3277,11 @@ def decorator(): binding=TextCompletionInput( name=arg_name, prompt=prompt, - model=model, + chat_model=chat_model, temperature=temperature, top_p=top_p, max_tokens=max_tokens, + is_reasoning_model=is_reasoning_model, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3371,9 +3378,15 @@ def decorator(): def assistant_post_input(self, arg_name: str, id: str, user_message: str, - model: Optional[str] = None, + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 - collection_name: Optional[str] = "ChatState", # noqa: E501 + collection_name: Optional[str] = "ChatState", # noqa: E501 + temperature: Optional[str] = "0.5", + top_p: Optional[str] = None, + max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3386,12 +3399,30 @@ def assistant_post_input(self, arg_name: str, :param id: The ID of the assistant to update. :param user_message: The user message that user has entered for assistant to respond to. - :param model: The OpenAI chat model to use. + :param model: @deprecated. Use chat_model instead. The model parameter + is unused and will be removed in future versions. + :param chat_model: The deployment name or model name of OpenAI Chat + Completion API. The default value is "gpt-3.5-turbo". :param chat_storage_connection_setting: The configuration section name for the table settings for assistant chat storage. The default value is "AzureWebJobsStorage". :param collection_name: The table collection name for assistant chat storage. The default value is "ChatState". + :param temperature: The sampling temperature to use, between 0 and 2. 
+ Higher values like 0.8 will make the output more random, while lower + values like 0.2 will make it more focused and deterministic. + :param top_p: An alternative to sampling with temperature, called + nucleus sampling, where the model considers the results of the tokens + with top_p probability mass. So 0.1 means only the tokens comprising + the top 10% probability mass are considered. It's generally recommend + to use this or temperature + :param max_tokens: The maximum number of tokens to generate in the + completion. The token count of your prompt plus max_tokens cannot + exceed the model's context length. Most models have a context length of + 2048 tokens (except for the newest models, which support 4096) + :param is_reasoning_model: Whether the configured chat completion model + is a reasoning model or not. Properties max_tokens and temperature are + not supported for reasoning models. :param data_type: Defines how Functions runtime should treat the parameter value :param kwargs: Keyword arguments for specifying additional binding @@ -3408,9 +3439,13 @@ def decorator(): name=arg_name, id=id, user_message=user_message, - model=model, + chat_model=chat_model, chat_storage_connection_setting=chat_storage_connection_setting, # noqa: E501 collection_name=collection_name, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + is_reasoning_model=is_reasoning_model, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3424,7 +3459,9 @@ def embeddings_input(self, arg_name: str, input: str, input_type: InputType, - model: Optional[str] = None, + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, max_chunk_length: Optional[int] = 8 * 1024, max_overlap: Optional[int] = 128, data_type: Optional[ @@ -3441,7 +3478,10 @@ def embeddings_input(self, :param input: The input source containing the data to generate embeddings for. :param input_type: The type of the input. 
- :param model: The ID of the model to use. + :param model: @deprecated. Use embeddings_model instead. The model + parameter is unused and will be removed in future versions. + :param embeddings_model: The deployment name or model name for OpenAI + Embeddings. The default value is "text-embedding-ada-002". :param max_chunk_length: The maximum number of characters to chunk the input into. Default value: 8 * 1024 :param max_overlap: The maximum number of characters to overlap @@ -3462,7 +3502,7 @@ def decorator(): name=arg_name, input=input, input_type=input_type, - model=model, + embeddings_model=embeddings_model, max_chunk_length=max_chunk_length, max_overlap=max_overlap, data_type=parse_singular_param_to_enum(data_type, @@ -3476,13 +3516,21 @@ def decorator(): def semantic_search_input(self, arg_name: str, - connection_name: str, + search_connection_name: str, collection: str, query: Optional[str] = None, - embeddings_model: Optional[OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, # NoQA - chat_model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel, # NoQA + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, system_prompt: Optional[str] = semantic_search_system_prompt, # NoQA max_knowledge_count: Optional[int] = 1, + temperature: Optional[str] = "0.5", + top_p: Optional[str] = None, + max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[ Union[DataType, str]] = None, **kwargs) \ @@ -3499,19 +3547,34 @@ def semantic_search_input(self, Ref: https://platform.openai.com/docs/guides/embeddings :param arg_name: The name of binding parameter in the function code. - :param connection_name: app setting or environment variable which - contains a connection string value. 
+ :param search_connection_name: app setting or environment variable + which contains a vector search connection setting value. :param collection: The name of the collection or table to search or store. :param query: The semantic query text to use for searching. - :param embeddings_model: The ID of the model to use for embeddings. - The default value is "text-embedding-ada-002". - :param chat_model: The name of the Large Language Model to invoke for - chat responses. The default value is "gpt-3.5-turbo". + :param embeddings_model: The deployment name or model name for OpenAI + Embeddings. The default value is "text-embedding-ada-002". + :param chat_model: The deployment name or model name of OpenAI Chat + Completion API. The default value is "gpt-3.5-turbo". :param system_prompt: Optional. The system prompt to use for prompting the large language model. :param max_knowledge_count: Optional. The number of knowledge items to inject into the SystemPrompt. Default value: 1 + :param temperature: The sampling temperature to use, between 0 and 2. + Higher values like 0.8 will make the output more random, while lower + values like 0.2 will make it more focused and deterministic. + :param top_p: An alternative to sampling with temperature, called + nucleus sampling, where the model considers the results of the tokens + with top_p probability mass. So 0.1 means only the tokens comprising + the top 10% probability mass are considered. It's generally recommend + to use this or temperature + :param max_tokens: The maximum number of tokens to generate in the + completion. The token count of your prompt plus max_tokens cannot + exceed the model's context length. Most models have a context length of + 2048 tokens (except for the newest models, which support 4096) + :param is_reasoning_model: Whether the configured chat completion model + is a reasoning model or not. Properties max_tokens and temperature are + not supported for reasoning models. :param data_type: Optional. 
Defines how Functions runtime should treat the parameter value. Default value: None :param kwargs: Keyword arguments for specifying additional binding @@ -3526,13 +3589,17 @@ def decorator(): fb.add_binding( binding=SemanticSearchInput( name=arg_name, - connection_name=connection_name, + search_connection_name=search_connection_name, collection=collection, query=query, embeddings_model=embeddings_model, chat_model=chat_model, system_prompt=system_prompt, max_knowledge_count=max_knowledge_count, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + is_reasoning_model=is_reasoning_model, data_type=parse_singular_param_to_enum(data_type, DataType), **kwargs)) @@ -3546,9 +3613,11 @@ def embeddings_store_output(self, arg_name: str, input: str, input_type: InputType, - connection_name: str, + store_connection_name: str, collection: str, - model: Optional[OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, # NoQA + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, max_chunk_length: Optional[int] = 8 * 1024, max_overlap: Optional[int] = 128, data_type: Optional[ @@ -3568,10 +3637,13 @@ def embeddings_store_output(self, :param arg_name: The name of binding parameter in the function code. :param input: The input to generate embeddings for. :param input_type: The type of the input. - :param connection_name: The name of an app setting or environment - variable which contains a connection string value + :param store_connection_name: The name of an app setting or environment + variable which contains a vector store connection setting value :param collection: The collection or table to search. - :param model: The ID of the model to use. + :param model: @deprecated. Use embeddings_model instead. The model + parameter is unused and will be removed in future versions. + :param embeddings_model: The deployment name or model name for OpenAI + Embeddings. The default value is "text-embedding-ada-002". 
:param max_chunk_length: The maximum number of characters to chunk the input into. :param max_overlap: The maximum number of characters to overlap between @@ -3592,9 +3664,9 @@ def decorator(): name=arg_name, input=input, input_type=input_type, - connection_name=connection_name, + store_connection_name=store_connection_name, collection=collection, - model=model, + embeddings_model=embeddings_model, max_chunk_length=max_chunk_length, max_overlap=max_overlap, data_type=parse_singular_param_to_enum(data_type, diff --git a/azure/functions/decorators/openai.py b/azure/functions/decorators/openai.py index 2563a78e..31306c67 100644 --- a/azure/functions/decorators/openai.py +++ b/azure/functions/decorators/openai.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from azure.functions.decorators.constants import (ASSISTANT_SKILL_TRIGGER, TEXT_COMPLETION, @@ -34,13 +34,11 @@ def __init__(self, function_description: str, function_name: Optional[str] = None, parameter_description_json: Optional[str] = None, - model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel, data_type: Optional[DataType] = None, **kwargs): self.function_description = function_description self.function_name = function_name self.parameter_description_json = parameter_description_json - self.model = model super().__init__(name=name, data_type=data_type) @@ -53,17 +51,21 @@ def get_binding_name() -> str: def __init__(self, name: str, prompt: str, - model: Optional[OpenAIModels] = OpenAIModels.DefaultChatModel, + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, temperature: Optional[str] = "0.5", top_p: Optional[str] = None, max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[DataType] = None, **kwargs): self.prompt = prompt - self.model = model + self.chat_model = chat_model self.temperature = temperature self.top_p = top_p self.max_tokens = max_tokens + self.is_reasoning_model = 
is_reasoning_model super().__init__(name=name, data_type=data_type) @@ -98,7 +100,9 @@ def __init__(self, name: str, input: str, input_type: InputType, - model: Optional[str] = None, + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, max_chunk_length: Optional[int] = 8 * 1024, max_overlap: Optional[int] = 128, data_type: Optional[DataType] = None, @@ -106,7 +110,7 @@ def __init__(self, self.name = name self.input = input self.input_type = input_type - self.model = model + self.embeddings_model = embeddings_model self.max_chunk_length = max_chunk_length self.max_overlap = max_overlap super().__init__(name=name, data_type=data_type) @@ -137,25 +141,35 @@ def get_binding_name() -> str: def __init__(self, name: str, - connection_name: str, + search_connection_name: str, collection: str, query: Optional[str] = None, - embeddings_model: Optional[ - OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, - chat_model: Optional[ - OpenAIModels] = OpenAIModels.DefaultChatModel, + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, system_prompt: Optional[str] = semantic_search_system_prompt, max_knowledge_count: Optional[int] = 1, + temperature: Optional[str] = "0.5", + top_p: Optional[str] = None, + max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[DataType] = None, **kwargs): self.name = name - self.connection_name = connection_name + self.search_connection_name = search_connection_name self.collection = collection self.query = query self.embeddings_model = embeddings_model self.chat_model = chat_model self.system_prompt = system_prompt self.max_knowledge_count = max_knowledge_count + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.is_reasoning_model = is_reasoning_model super().__init__(name=name, 
data_type=data_type) @@ -168,17 +182,27 @@ def get_binding_name(): def __init__(self, name: str, id: str, user_message: str, - model: Optional[str] = None, + chat_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultChatModel, chat_storage_connection_setting: Optional[str] = "AzureWebJobsStorage", # noqa: E501 collection_name: Optional[str] = "ChatState", + temperature: Optional[str] = "0.5", + top_p: Optional[str] = None, + max_tokens: Optional[str] = "100", + is_reasoning_model: Optional[bool] = False, data_type: Optional[DataType] = None, **kwargs): self.name = name self.id = id self.user_message = user_message - self.model = model + self.chat_model = chat_model self.chat_storage_connection_setting = chat_storage_connection_setting self.collection_name = collection_name + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.is_reasoning_model = is_reasoning_model super().__init__(name=name, data_type=data_type) @@ -192,10 +216,11 @@ def __init__(self, name: str, input: str, input_type: InputType, - connection_name: str, + store_connection_name: str, collection: str, - model: Optional[ - OpenAIModels] = OpenAIModels.DefaultEmbeddingsModel, + embeddings_model: Optional + [Union[str, OpenAIModels]] + = OpenAIModels.DefaultEmbeddingsModel, max_chunk_length: Optional[int] = 8 * 1024, max_overlap: Optional[int] = 128, data_type: Optional[DataType] = None, @@ -203,9 +228,9 @@ def __init__(self, self.name = name self.input = input self.input_type = input_type - self.connection_name = connection_name + self.store_connection_name = store_connection_name self.collection = collection - self.model = model + self.embeddings_model = embeddings_model self.max_chunk_length = max_chunk_length self.max_overlap = max_overlap super().__init__(name=name, data_type=data_type) diff --git a/tests/decorators/test_openai.py b/tests/decorators/test_openai.py index c2009c72..78b4e0da 100644 --- a/tests/decorators/test_openai.py +++ 
b/tests/decorators/test_openai.py @@ -2,188 +2,403 @@ from azure.functions import DataType from azure.functions.decorators.core import BindingDirection -from azure.functions.decorators.openai import AssistantSkillTrigger, \ - TextCompletionInput, OpenAIModels, AssistantQueryInput, EmbeddingsInput, \ - AssistantCreateOutput, SemanticSearchInput, EmbeddingsStoreOutput, \ - AssistantPostInput +from azure.functions.decorators.openai import ( + AssistantSkillTrigger, + TextCompletionInput, + OpenAIModels, + AssistantQueryInput, + EmbeddingsInput, + AssistantCreateOutput, + SemanticSearchInput, + EmbeddingsStoreOutput, + AssistantPostInput, +) class TestOpenAI(unittest.TestCase): def test_assistant_skill_trigger_valid_creation(self): - trigger = AssistantSkillTrigger(name="test", - function_description="description", - function_name="test_function_name", - parameter_description_json="test_json", - model=OpenAIModels.DefaultChatModel, - data_type=DataType.UNDEFINED, - dummy_field="dummy") - self.assertEqual(trigger.get_binding_name(), - "assistantSkillTrigger") + trigger = AssistantSkillTrigger( + name="test", + function_description="description", + function_name="test_function_name", + parameter_description_json="test_json", + data_type=DataType.UNDEFINED, + dummy_field="dummy", + ) + self.assertEqual(trigger.get_binding_name(), "assistantSkillTrigger") self.assertEqual( - trigger.get_dict_repr(), {"name": "test", - "functionDescription": "description", - "functionName": "test_function_name", - "parameterDescriptionJson": "test_json", - "model": OpenAIModels.DefaultChatModel, - "dataType": DataType.UNDEFINED, - 'type': 'assistantSkillTrigger', - 'dummyField': 'dummy', - "direction": BindingDirection.IN, - }) + trigger.get_dict_repr(), + { + "name": "test", + "functionDescription": "description", + "functionName": "test_function_name", + "parameterDescriptionJson": "test_json", + "dataType": DataType.UNDEFINED, + "type": "assistantSkillTrigger", + "dummyField": "dummy", + 
"direction": BindingDirection.IN, + }, + ) def test_text_completion_input_valid_creation(self): - input = TextCompletionInput(name="test", - prompt="test_prompt", - temperature="1", - max_tokens="1", - data_type=DataType.UNDEFINED, - model=OpenAIModels.DefaultChatModel, - dummy_field="dummy") - self.assertEqual(input.get_binding_name(), - "textCompletion") - self.assertEqual(input.get_dict_repr(), - {"name": "test", - "temperature": "1", - "maxTokens": "1", - 'type': 'textCompletion', - "dataType": DataType.UNDEFINED, - "dummyField": "dummy", - "prompt": "test_prompt", - "direction": BindingDirection.IN, - "model": OpenAIModels.DefaultChatModel - }) + input = TextCompletionInput( + name="test", + prompt="test_prompt", + temperature="1", + max_tokens="1", + is_reasoning_model=False, + data_type=DataType.UNDEFINED, + chat_model=OpenAIModels.DefaultChatModel, + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "textCompletion") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "temperature": "1", + "maxTokens": "1", + "type": "textCompletion", + "dataType": DataType.UNDEFINED, + "dummyField": "dummy", + "prompt": "test_prompt", + "direction": BindingDirection.IN, + "chatModel": OpenAIModels.DefaultChatModel, + "isReasoningModel": False, + }, + ) + + def test_text_completion_input_with_string_chat_model(self): + input = TextCompletionInput( + name="test", + prompt="test_prompt", + temperature="1", + max_tokens="1", + is_reasoning_model=True, + data_type=DataType.UNDEFINED, + chat_model="gpt-4o", + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "textCompletion") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "temperature": "1", + "maxTokens": "1", + "type": "textCompletion", + "dataType": DataType.UNDEFINED, + "dummyField": "dummy", + "prompt": "test_prompt", + "direction": BindingDirection.IN, + "chatModel": "gpt-4o", + "isReasoningModel": True, + }, + ) def 
test_assistant_query_input_valid_creation(self): - input = AssistantQueryInput(name="test", - timestamp_utc="timestamp_utc", - chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 - collection_name="ChatState", - data_type=DataType.UNDEFINED, - id="test_id", - type="assistantQueryInput", - dummy_field="dummy") - self.assertEqual(input.get_binding_name(), - "assistantQuery") - self.assertEqual(input.get_dict_repr(), - {"name": "test", - "timestampUtc": "timestamp_utc", - "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 - "collectionName": "ChatState", - "dataType": DataType.UNDEFINED, - "direction": BindingDirection.IN, - "type": "assistantQuery", - "id": "test_id", - "dummyField": "dummy" - }) + input = AssistantQueryInput( + name="test", + timestamp_utc="timestamp_utc", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", + data_type=DataType.UNDEFINED, + id="test_id", + type="assistantQueryInput", + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "assistantQuery") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "timestampUtc": "timestamp_utc", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.IN, + "type": "assistantQuery", + "id": "test_id", + "dummyField": "dummy", + }, + ) def test_embeddings_input_valid_creation(self): - input = EmbeddingsInput(name="test", - data_type=DataType.UNDEFINED, - input="test_input", - input_type="test_input_type", - model="test_model", - max_overlap=1, - max_chunk_length=1, - dummy_field="dummy") - self.assertEqual(input.get_binding_name(), - "embeddings") - self.assertEqual(input.get_dict_repr(), - {"name": "test", - "type": "embeddings", - "dataType": DataType.UNDEFINED, - "input": "test_input", - "inputType": "test_input_type", - "model": "test_model", - "maxOverlap": 1, - 
"maxChunkLength": 1, - "direction": BindingDirection.IN, - "dummyField": "dummy"}) + input = EmbeddingsInput( + name="test", + data_type=DataType.UNDEFINED, + input="test_input", + input_type="test_input_type", + embeddings_model="test_model", + max_overlap=1, + max_chunk_length=1, + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "embeddings") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "type": "embeddings", + "dataType": DataType.UNDEFINED, + "input": "test_input", + "inputType": "test_input_type", + "embeddingsModel": "test_model", + "maxOverlap": 1, + "maxChunkLength": 1, + "direction": BindingDirection.IN, + "dummyField": "dummy", + }, + ) + + def test_embeddings_input_with_enum_embeddings_model(self): + input = EmbeddingsInput( + name="test", + data_type=DataType.UNDEFINED, + input="test_input", + input_type="test_input_type", + embeddings_model=OpenAIModels.DefaultEmbeddingsModel, + max_overlap=1, + max_chunk_length=1, + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "embeddings") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "type": "embeddings", + "dataType": DataType.UNDEFINED, + "input": "test_input", + "inputType": "test_input_type", + "embeddingsModel": OpenAIModels.DefaultEmbeddingsModel, + "maxOverlap": 1, + "maxChunkLength": 1, + "direction": BindingDirection.IN, + "dummyField": "dummy", + }, + ) def test_assistant_create_output_valid_creation(self): - output = AssistantCreateOutput(name="test", - data_type=DataType.UNDEFINED) - self.assertEqual(output.get_binding_name(), - "assistantCreate") - self.assertEqual(output.get_dict_repr(), - {"name": "test", - "dataType": DataType.UNDEFINED, - "direction": BindingDirection.OUT, - "type": "assistantCreate"}) + output = AssistantCreateOutput( + name="test", data_type=DataType.UNDEFINED + ) + self.assertEqual(output.get_binding_name(), "assistantCreate") + self.assertEqual( + output.get_dict_repr(), + { + "name": 
"test", + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.OUT, + "type": "assistantCreate", + }, + ) def test_assistant_post_input_valid_creation(self): - input = AssistantPostInput(name="test", - id="test_id", - model="test_model", - chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 - collection_name="ChatState", - user_message="test_message", - data_type=DataType.UNDEFINED, - dummy_field="dummy") - self.assertEqual(input.get_binding_name(), - "assistantPost") - self.assertEqual(input.get_dict_repr(), - {"name": "test", - "id": "test_id", - "model": "test_model", - "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 - "collectionName": "ChatState", - "userMessage": "test_message", - "dataType": DataType.UNDEFINED, - "direction": BindingDirection.IN, - "dummyField": "dummy", - "type": "assistantPost"}) + input = AssistantPostInput( + name="test", + id="test_id", + chat_model="test_model", + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", + user_message="test_message", + temperature="1", + max_tokens="1", + is_reasoning_model=False, + data_type=DataType.UNDEFINED, + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "assistantPost") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "id": "test_id", + "chatModel": "test_model", + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", + "userMessage": "test_message", + "temperature": "1", + "maxTokens": "1", + "isReasoningModel": False, + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.IN, + "dummyField": "dummy", + "type": "assistantPost", + }, + ) + + def test_assistant_post_input_with_enum_chat_model(self): + input = AssistantPostInput( + name="test", + id="test_id", + chat_model=OpenAIModels.DefaultChatModel, + chat_storage_connection_setting="AzureWebJobsStorage", # noqa: E501 + collection_name="ChatState", + 
user_message="test_message", + temperature="1", + max_tokens="1", + is_reasoning_model=False, + data_type=DataType.UNDEFINED, + dummy_field="dummy", + ) + self.assertEqual(input.get_binding_name(), "assistantPost") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "id": "test_id", + "chatModel": OpenAIModels.DefaultChatModel, + "chatStorageConnectionSetting": "AzureWebJobsStorage", # noqa: E501 + "collectionName": "ChatState", + "userMessage": "test_message", + "temperature": "1", + "maxTokens": "1", + "isReasoningModel": False, + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.IN, + "dummyField": "dummy", + "type": "assistantPost", + }, + ) def test_semantic_search_input_valid_creation(self): - input = SemanticSearchInput(name="test", - data_type=DataType.UNDEFINED, - chat_model=OpenAIModels.DefaultChatModel, - embeddings_model=OpenAIModels.DefaultEmbeddingsModel, # NoQA - collection="test_collection", - connection_name="test_connection", - system_prompt="test_prompt", - query="test_query", - max_knowledge_count=1, - dummy_field="dummy_field") - self.assertEqual(input.get_binding_name(), - "semanticSearch") - self.assertEqual(input.get_dict_repr(), - {"name": "test", - "dataType": DataType.UNDEFINED, - "direction": BindingDirection.IN, - "dummyField": "dummy_field", - "chatModel": OpenAIModels.DefaultChatModel, - "embeddingsModel": OpenAIModels.DefaultEmbeddingsModel, # NoQA - "type": "semanticSearch", - "collection": "test_collection", - "connectionName": "test_connection", - "systemPrompt": "test_prompt", - "maxKnowledgeCount": 1, - "query": "test_query"}) + input = SemanticSearchInput( + name="test", + data_type=DataType.UNDEFINED, + chat_model=OpenAIModels.DefaultChatModel, + embeddings_model=OpenAIModels.DefaultEmbeddingsModel, # NoQA + collection="test_collection", + search_connection_name="test_connection", + system_prompt="test_prompt", + query="test_query", + max_knowledge_count=1, + temperature="1", + max_tokens="1", + 
is_reasoning_model=False, + dummy_field="dummy_field", + ) + self.assertEqual(input.get_binding_name(), "semanticSearch") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.IN, + "dummyField": "dummy_field", + "chatModel": OpenAIModels.DefaultChatModel, + "embeddingsModel": OpenAIModels.DefaultEmbeddingsModel, # NoQA + "type": "semanticSearch", + "collection": "test_collection", + "searchConnectionName": "test_connection", + "systemPrompt": "test_prompt", + "maxKnowledgeCount": 1, + "temperature": "1", + "maxTokens": "1", + "isReasoningModel": False, + "query": "test_query", + }, + ) + + def test_semantic_search_input_with_string_models(self): + input = SemanticSearchInput( + name="test", + data_type=DataType.UNDEFINED, + chat_model="gpt-4o", + embeddings_model="text-embedding-3-large", + collection="test_collection", + search_connection_name="test_connection", + system_prompt="test_prompt", + query="test_query", + max_knowledge_count=1, + temperature="1", + max_tokens="1", + is_reasoning_model=True, + dummy_field="dummy_field", + ) + self.assertEqual(input.get_binding_name(), "semanticSearch") + self.assertEqual( + input.get_dict_repr(), + { + "name": "test", + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.IN, + "dummyField": "dummy_field", + "chatModel": "gpt-4o", + "embeddingsModel": "text-embedding-3-large", + "type": "semanticSearch", + "collection": "test_collection", + "searchConnectionName": "test_connection", + "systemPrompt": "test_prompt", + "maxKnowledgeCount": 1, + "temperature": "1", + "maxTokens": "1", + "isReasoningModel": True, + "query": "test_query", + }, + ) def test_embeddings_store_output_valid_creation(self): - output = EmbeddingsStoreOutput(name="test", - data_type=DataType.UNDEFINED, - input="test_input", - input_type="test_input_type", - connection_name="test_connection", - max_overlap=1, - max_chunk_length=1, - collection="test_collection", 
- model=OpenAIModels.DefaultChatModel, - dummy_field="dummy_field") - self.assertEqual(output.get_binding_name(), - "embeddingsStore") - self.assertEqual(output.get_dict_repr(), - {"name": "test", - "dataType": DataType.UNDEFINED, - "direction": BindingDirection.OUT, - "dummyField": "dummy_field", - "input": "test_input", - "inputType": "test_input_type", - "collection": "test_collection", - "model": OpenAIModels.DefaultChatModel, - "connectionName": "test_connection", - "maxOverlap": 1, - "maxChunkLength": 1, - "type": "embeddingsStore"}) + output = EmbeddingsStoreOutput( + name="test", + data_type=DataType.UNDEFINED, + input="test_input", + input_type="test_input_type", + store_connection_name="test_connection", + max_overlap=1, + max_chunk_length=1, + collection="test_collection", + embeddings_model=OpenAIModels.DefaultEmbeddingsModel, # noqa: E501 + dummy_field="dummy_field", + ) + self.assertEqual(output.get_binding_name(), "embeddingsStore") + self.assertEqual( + output.get_dict_repr(), + { + "name": "test", + "dataType": DataType.UNDEFINED, + "direction": BindingDirection.OUT, + "dummyField": "dummy_field", + "input": "test_input", + "inputType": "test_input_type", + "collection": "test_collection", + "embeddingsModel": OpenAIModels.DefaultEmbeddingsModel, # noqa: E501 + "storeConnectionName": "test_connection", + "maxOverlap": 1, + "maxChunkLength": 1, + "type": "embeddingsStore", + }, + ) + + def test_embeddings_store_output_with_string_embeddings_model(self): + output = EmbeddingsStoreOutput( + name="test", + data_type=DataType.UNDEFINED, + input="test_input", + input_type="test_input_type", + store_connection_name="test_connection", + max_overlap=1, + max_chunk_length=1, + collection="test_collection", + embeddings_model="text-embedding-3-small", + dummy_field="dummy_field", + ) + self.assertEqual(output.get_binding_name(), "embeddingsStore") + self.assertEqual( + output.get_dict_repr(), + { + "name": "test", + "dataType": DataType.UNDEFINED, + 
"direction": BindingDirection.OUT, + "dummyField": "dummy_field", + "input": "test_input", + "inputType": "test_input_type", + "collection": "test_collection", + "embeddingsModel": "text-embedding-3-small", + "storeConnectionName": "test_connection", + "maxOverlap": 1, + "maxChunkLength": 1, + "type": "embeddingsStore", + }, + )