
Commit 1c2091e

lilyydu and aacebo authored
[PY] feat: embeddings (#1217)
## Linked issues

closes: #1070

## Details

Implemented the logic and tests for embeddings for Python, based on the JS implementation.

1. Client API calls and types use the [OpenAI Python API library](https://github.com/openai/openai-python/tree/main).
2. I restructured the interfaces and classes for AzureOpenAI and OpenAI into their own classes, for two reasons: (a) some logic differs, and (b) this allows for better maintainability and faster response to bugs, since they depend on different clients and configurations from the API library.
3. To replicate `setTimeout` in retries, I've used `asyncio.sleep`, but I'm open to better preferences/suggestions if any (see the condensed sketch below).
4. Logging follows the C# implementation (no colorization).
5. Had to disable two _pylint_/_mypy_ type/argument checks: `unused-argument`, because `mock.patch` requires the same method signature, and `arg-type`, because `openai.types` doesn't yet export the Usage class ([filed a request here](openai/openai-python#1135)).

## Attestation Checklist

- [x] My code follows the style guidelines of this project
  - I have checked for/fixed spelling, linting, and other errors
  - I have commented my code for clarity
  - I have made corresponding changes to the documentation (updating the doc strings in the code is sufficient)
  - My changes generate no new warnings
  - I have added tests that validate my changes and provide sufficient test coverage. I have tested with:
    - Local testing
    - E2E testing in Teams
  - New and existing unit tests pass locally with my changes

---------

Co-authored-by: Alex Acebo <[email protected]>
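As a condensed illustration of the retry approach in item 3 (not the committed code; `create_with_retries` and `create_fn` are placeholder names), the pattern replaces JS `setTimeout` with `asyncio.sleep` between attempts:

```python
import asyncio


async def create_with_retries(create_fn, inputs, retry_policy=(2, 5)):
    # Hypothetical sketch: retry_policy lists per-attempt delays in seconds,
    # mirroring the default [2, 5] used by the embeddings classes below.
    for delay in retry_policy:
        try:
            return await create_fn(inputs)
        except Exception:  # the committed code catches openai.RateLimitError specifically
            await asyncio.sleep(delay)
    return await create_fn(inputs)  # final attempt; errors propagate to the caller
```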
1 parent d4d3349 commit 1c2091e

16 files changed · +1140 −176 lines changed

python/packages/ai/poetry.lock
+411 −138 (generated file; diff not rendered by default)

python/packages/ai/pyproject.toml
+3 −1

@@ -17,6 +17,8 @@ aiohttp = "^3.8.5"
 jsonschema = "^4.21.1"
 types-pyyaml = "^6.0.12.12"
 pyyaml = "^6.0.1"
+dataclasses-json = "^0.6.4"
+openai = "^1.11.1"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
@@ -26,7 +28,7 @@ pytest-asyncio = "^0.21.1"
 black = "^23.7.0"
 isort = "^5.12.0"
 mypy = "^1.5.0"
-dataclasses-json = "^0.6.4"
+httpx = "^0.26.0"
 
 [tool.poetry.scripts]
 lint = "scripts:lint"

python/packages/ai/teams/ai/augmentations/monologue_augmentation.py
+7 −8

@@ -185,11 +185,9 @@ async def validate_response(
         if validation_result.value:
             # pylint:disable=no-member, line-too-long
             # from_dict provided from @dataclass_json decorator
-            monologue = InnerMonologue.from_dict(validation_result.value) # type: ignore[attr-defined]
+            monologue = InnerMonologue.from_dict(validation_result.value)  # type: ignore[attr-defined]
             parameters = (
-                json.dumps(monologue.action.parameters)
-                if monologue.action.parameters
-                else ""
+                json.dumps(monologue.action.parameters) if monologue.action.parameters else ""
             )
             message = Message[str](
                 role="assistant",
@@ -227,20 +225,21 @@ async def create_plan_from_response(
             Plan: The created plan.
         """
         # Identify the action to perform
-        command: PredictedCommand
         if response.message and response.message.content:
+            command: PredictedCommand
             monologue: InnerMonologue = response.message.content
 
             if monologue.action.name == "SAY":
-                params = monologue.action.parameters
+                params = monologue.action.parameters
                 response_val = cast(str, params.get("text")) if params else ""
-                command = PredictedSayCommand(response= response_val)
+                command = PredictedSayCommand(response=response_val)
             else:
                 command = PredictedDoCommand(
                     action=monologue.action.name,
                     parameters=monologue.action.parameters if monologue.action.parameters else {},
                 )
-        return Plan(commands=[command])
+            return Plan(commands=[command])
+        return Plan()
 
     def _append_say_action(self, actions: List[ChatCompletionAction]) -> List[ChatCompletionAction]:
         clone = actions

python/packages/ai/teams/ai/augmentations/sequence_augmentation.py
+13 −9

@@ -118,7 +118,7 @@ async def validate_response(
         for index, command in enumerate(plan.commands):
             if command.type == CommandType.DO:
                 # Ensure that the model specified an action
-                if not command.action: # type: ignore[attr-defined]
+                if not command.action:  # type: ignore[attr-defined]
                     return Validation(
                         valid=False,
                         feedback='The plan JSON is missing the DO "action" for '
@@ -127,25 +127,29 @@ async def validate_response(
 
                 # Ensure that the action is valid
                 parameters: str = ""
-                if command.parameters: # type: ignore[attr-defined]
-                    parameters = json.dumps(command.parameters) # type: ignore[attr-defined]
+                if command.parameters:  # type: ignore[attr-defined]
+                    parameters = json.dumps(command.parameters)  # type: ignore[attr-defined]
                 message = Message[str](
                     role="assistant",
                     content=None,
-                    function_call=FunctionCall(name=command.action, # type: ignore[attr-defined]
-                                               arguments=parameters),
+                    function_call=FunctionCall(
+                        name=command.action, arguments=parameters  # type: ignore[attr-defined]
+                    ),
                 )
                 action_validation = await self._action_validator.validate_response(
-                    context, memory, tokenizer,
-                    PromptResponse(message=message), remaining_attempts
+                    context,
+                    memory,
+                    tokenizer,
+                    PromptResponse(message=message),
+                    remaining_attempts,
                 )
 
                 if not action_validation.valid:
                     return cast(Any, action_validation)
 
             elif command.type == CommandType.SAY:
                 # Ensure that the model specified a response
-                if not command.response: # type: ignore[attr-defined]
+                if not command.response:  # type: ignore[attr-defined]
                     return Validation(
                         valid=False,
                         feedback='The plan JSON is missing the SAY "response" '
@@ -155,7 +159,7 @@ async def validate_response(
                 return Validation(
                     valid=False,
                     feedback="The plan JSON contains an unknown command"
-                    + f'type of ${command.type}. Only use DO or SAY commands.',
+                    + f"type of ${command.type}. Only use DO or SAY commands.",
                 )
 
         # Return the validated monologue
python/packages/ai/teams/ai/embeddings/__init__.py (new file, +10)

"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""
from .azure_openai_embeddings import AzureOpenAIEmbeddings
from .azure_openai_embeddings_options import AzureOpenAIEmbeddingsOptions
from .embeddings_model import EmbeddingsModel
from .embeddings_response import EmbeddingsResponse
from .openai_embeddings import OpenAIEmbeddings
from .openai_embeddings_options import OpenAIEmbeddingsOptions
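For orientation, a minimal (assumed) import sketch showing how downstream code can pull these types from the package root exposed by this `__init__.py`:

```python
from teams.ai.embeddings import (
    AzureOpenAIEmbeddings,
    AzureOpenAIEmbeddingsOptions,
    EmbeddingsResponse,
)
```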
python/packages/ai/teams/ai/embeddings/azure_openai_embeddings.py (new file, +122)

"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

import asyncio
from datetime import datetime
from logging import Logger
from operator import attrgetter
from typing import List, Union

import openai

from teams.ai.embeddings.azure_openai_embeddings_options import (
    AzureOpenAIEmbeddingsOptions,
)
from teams.ai.embeddings.embeddings_model import EmbeddingsModel
from teams.ai.embeddings.embeddings_response import EmbeddingsResponse


class AzureOpenAIEmbeddings(EmbeddingsModel):
    """
    An `EmbeddingsModel` for calling the AzureOpenAI hosted model.
    """

    _user_agent = "@microsoft/teams-ai-v1"
    _log: Logger

    options: AzureOpenAIEmbeddingsOptions
    "Options the client was configured with."

    def __init__(self, options: AzureOpenAIEmbeddingsOptions, log=Logger("teams.ai")) -> None:
        """
        Creates a new `AzureOpenAIEmbeddings` instance.

        Args:
            options (AzureOpenAIEmbeddingsOptions): Options for configuring the embeddings client.
            log (Logger): Logger to use.
        """

        self.options = options
        self._log = log

        if not self.options.retry_policy:
            self.options.retry_policy = [2, 5]
        if not self.options.azure_api_version:
            self.options.azure_api_version = "2023-05-15"

        endpoint = self.options.azure_endpoint.strip()
        if endpoint[-1] == "/":
            endpoint = endpoint[0 : (len(endpoint) - 1)]

        if not endpoint.lower().startswith("https://"):
            raise ValueError(
                f"""
                Client created with an invalid endpoint of \"{endpoint}\".
                The endpoint must be a valid HTTPS url.
                """
            )

        self.options.azure_endpoint = endpoint

    async def create_embeddings(
        self, inputs: Union[str, List[str], List[int], List[List[int]]], retry_count=0
    ) -> EmbeddingsResponse:
        """
        Creates embeddings for the given inputs.

        Args:
            inputs (Union[str, List[str]]): Text inputs to create embeddings for.

        Returns:
            EmbeddingsResponse: A status and embeddings/message when an error occurs.
        """

        if self.options.log_requests:
            self._log.info("Embeddings REQUEST: inputs=%s", inputs)

        if not self.options.request_config:
            self.options.request_config = {"api-key": self.options.azure_api_key}
        else:
            self.options.request_config.update({"api-key": self.options.azure_api_key})

        if not self.options.request_config.get("Content-Type"):
            self.options.request_config.update({"Content-Type": "application/json"})

        if not self.options.request_config.get("User-Agent"):
            self.options.request_config.update({"User-Agent": self._user_agent})

        client = openai.AsyncAzureOpenAI(
            api_key=self.options.azure_api_key,
            api_version=self.options.azure_api_version,
            azure_endpoint=self.options.azure_endpoint,
            default_headers=self.options.request_config,
        )
        try:
            start_time = datetime.now()
            res = await client.embeddings.create(input=inputs, model=self.options.azure_deployment)

            data = list(map(attrgetter("embedding"), sorted(res.data, key=lambda x: x.index)))

            if self.options.log_requests:
                duration = datetime.now() - start_time
                self._log.info(
                    "Embeddings SUCCEEDED: duration=%s response=%s", duration.total_seconds(), data
                )

            return EmbeddingsResponse(status="success", output=data)
        except openai.RateLimitError:
            if self.options.retry_policy:
                if retry_count < len(self.options.retry_policy):
                    delay = self.options.retry_policy[retry_count]
                    await asyncio.sleep(delay)
                    return await self.create_embeddings(inputs, retry_count + 1)
            return EmbeddingsResponse(
                status="rate_limited", output="The embeddings API returned a rate limit error."
            )
        except openai.APIError as err:
            return EmbeddingsResponse(
                status="error",
                output=f"The embeddings API returned an error status of {err.code}: {err.message}",
            )
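A minimal usage sketch of the class above; the endpoint, key, and deployment values are placeholders, and the `asyncio.run` driver is only for illustration:

```python
import asyncio

from teams.ai.embeddings import AzureOpenAIEmbeddings, AzureOpenAIEmbeddingsOptions


async def main() -> None:
    embeddings = AzureOpenAIEmbeddings(
        AzureOpenAIEmbeddingsOptions(
            azure_api_key="<api-key>",  # placeholder
            azure_endpoint="https://<resource>.openai.azure.com",  # placeholder
            azure_deployment="<embeddings-deployment-name>",  # placeholder
            log_requests=True,
        )
    )
    res = await embeddings.create_embeddings(["hello", "world"])
    if res.status == "success":
        print(f"created {len(res.output)} embedding vectors")
    else:
        print(f"embeddings failed ({res.status}): {res.output}")


asyncio.run(main())
```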
python/packages/ai/teams/ai/embeddings/azure_openai_embeddings_options.py (new file, +35)

"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class AzureOpenAIEmbeddingsOptions:
    """
    Options for configuring an `AzureOpenAIEmbeddings` to generate embeddings.
    """

    azure_api_key: str
    "API key to use when making requests to Azure OpenAI."

    azure_endpoint: str
    "Deployment endpoint to use."

    azure_deployment: str
    "Name of the Azure OpenAI deployment (model) to use."

    azure_api_version: Optional[str] = None
    "Optional. Version of the API being called. Defaults to `2023-05-15`."

    log_requests: Optional[bool] = False
    "Whether to log requests to the console. Useful for debugging; defaults to `False`."

    retry_policy: Optional[List[int]] = None
    "Optional. Retry policy to use in seconds. The default retry policy is `[2, 5]`."

    request_config: Optional[Dict[str, str]] = None
    "Request options to use."
python/packages/ai/teams/ai/embeddings/embeddings_model.py (new file, +30)

"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from abc import ABC, abstractmethod
from typing import List, Union

from teams.ai.embeddings.embeddings_response import EmbeddingsResponse


class EmbeddingsModel(ABC):
    """
    An AI model that can be used to create embeddings.
    """

    @abstractmethod
    async def create_embeddings(
        self, inputs: Union[str, List[str], List[int], List[List[int]]]
    ) -> EmbeddingsResponse:
        """
        Creates embeddings for the given inputs.

        Args:
            inputs (Union[str, List[str], List[int], List[List[int]]]):
                Text inputs to create embeddings for.

        Returns:
            EmbeddingsResponse: A status and embeddings/message when an error occurs.
        """
python/packages/ai/teams/ai/embeddings/embeddings_response.py (new file, +29)

"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from dataclasses import dataclass
from typing import List, Literal, Optional, Union

EmbeddingsResponseStatus = Literal[
    "success",  # The embeddings were successfully created.
    "error",  # An error occurred while creating the embeddings.
    "rate_limited",  # The request was rate limited.
]


@dataclass
class EmbeddingsResponse:
    """
    Response returned for embeddings.
    """

    status: EmbeddingsResponseStatus
    "Status of the embeddings response."

    output: Optional[Union[List[List[int]], str]] = None
    "Optional. Embeddings for the given inputs or the error string."

    message: Optional[str] = None
    "Optional. Status message."
