Commit e1ceaa5
feat(parsing): add support for pydantic dataclasses
1 parent: 4e83b57

File tree

3 files changed: +99 -14 lines changed

src/openai/lib/_parsing/_completions.py

Lines changed: 19 additions & 9 deletions
@@ -9,9 +9,9 @@
 from .._tools import PydanticFunctionTool
 from ..._types import NOT_GIVEN, NotGiven
 from ..._utils import is_dict, is_given
-from ..._compat import model_parse_json
+from ..._compat import PYDANTIC_V2, model_parse_json
 from ..._models import construct_type_unchecked
-from .._pydantic import to_strict_json_schema
+from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
 from ...types.chat import (
     ParsedChoice,
     ChatCompletion,

@@ -216,14 +216,16 @@ def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
     return cast(FunctionDefinition, input_fn).get("strict") or False


-def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
-    return issubclass(typ, pydantic.BaseModel)
-
-
 def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
     if is_basemodel_type(response_format):
         return cast(ResponseFormatT, model_parse_json(response_format, content))

+    if is_dataclass_like_type(response_format):
+        if not PYDANTIC_V2:
+            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
+
+        return pydantic.TypeAdapter(response_format).validate_json(content)
+
     raise TypeError(f"Unable to automatically parse response format type {response_format}")

@@ -241,14 +243,22 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)

-    if not is_basemodel_type(response_format):
+    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
+
+    if is_basemodel_type(response_format):
+        name = response_format.__name__
+        json_schema_type = response_format
+    elif is_dataclass_like_type(response_format):
+        name = response_format.__name__
+        json_schema_type = pydantic.TypeAdapter(response_format)
+    else:
         raise TypeError(f"Unsupported response_format type - {response_format}")

     return {
         "type": "json_schema",
         "json_schema": {
-            "schema": to_strict_json_schema(response_format),
-            "name": response_format.__name__,
+            "schema": to_strict_json_schema(json_schema_type),
+            "name": name,
             "strict": True,
         },
     }
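Taken together, these changes route any non-`BaseModel` response format through `pydantic.TypeAdapter`, both for schema generation and for parsing the returned JSON content. A minimal sketch of that parsing path, assuming Pydantic v2 (the `CalendarEvent` dataclass and the JSON payload mirror the new test below and are illustrative only):

# Sketch of the TypeAdapter parsing path used by the new is_dataclass_like_type branch.
# Assumes Pydantic v2; CalendarEvent and the JSON payload are illustrative only.
from typing import List

import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class CalendarEvent:
    name: str
    date: str
    participants: List[str]


# Pydantic dataclasses are not BaseModel subclasses, so model_parse_json() does not
# apply; TypeAdapter validates arbitrary supported types instead.
adapter = pydantic.TypeAdapter(CalendarEvent)
event = adapter.validate_json('{"name": "Science Fair", "date": "Friday", "participants": ["Alice", "Bob"]}')
assert event == CalendarEvent(name="Science Fair", date="Friday", participants=["Alice", "Bob"])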

src/openai/lib/_pydantic.py

Lines changed: 22 additions & 4 deletions
@@ -1,16 +1,25 @@
 from __future__ import annotations

-from typing import Any
+import inspect
+from typing import Any, TypeVar
 from typing_extensions import TypeGuard

 import pydantic

 from .._utils import is_dict as _is_dict, is_list
-from .._compat import model_json_schema
+from .._compat import PYDANTIC_V2, model_json_schema

+_T = TypeVar("_T")
+
+
+def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
+    if inspect.isclass(model) and is_basemodel_type(model):
+        schema = model_json_schema(model)
+    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
+        schema = model.json_schema()
+    else:
+        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")

-def to_strict_json_schema(model: type[pydantic.BaseModel]) -> dict[str, Any]:
-    schema = model_json_schema(model)
     return _ensure_strict_json_schema(schema, path=(), root=schema)

@@ -110,6 +119,15 @@ def resolve_ref(*, root: dict[str, object], ref: str) -> object:
     return resolved


+def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
+    return issubclass(typ, pydantic.BaseModel)
+
+
+def is_dataclass_like_type(typ: type) -> bool:
+    """Returns True if the given type likely used `@pydantic.dataclass`"""
+    return hasattr(typ, "__pydantic_config__")
+
+
 def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
     # just pretend that we know there are only `str` keys
     # as that check is not worth the performance cost
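`is_dataclass_like_type` keys off the `__pydantic_config__` attribute that Pydantic attaches to classes decorated with `@pydantic.dataclasses.dataclass`, and `to_strict_json_schema` now accepts either a `BaseModel` subclass or a `TypeAdapter`. A small illustration of both checks, assuming Pydantic v2 (the `Location`, `Event`, and `PlainEvent` classes are hypothetical examples):

# Illustration of the detection and schema paths handled above.
# Assumes Pydantic v2; Location, Event, and PlainEvent are hypothetical examples.
import dataclasses

import pydantic
from pydantic.dataclasses import dataclass


class Location(pydantic.BaseModel):
    city: str


@dataclass
class Event:
    name: str


@dataclasses.dataclass
class PlainEvent:
    name: str


# Only the Pydantic dataclass carries __pydantic_config__; as of this commit a
# plain stdlib dataclass is not recognised.
assert hasattr(Event, "__pydantic_config__")
assert not hasattr(PlainEvent, "__pydantic_config__")

# BaseModel subclasses keep using model_json_schema(); other types go through a
# TypeAdapter, which can build a JSON schema for any supported type.
print(Location.model_json_schema())
print(pydantic.TypeAdapter(Event).json_schema())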

tests/lib/chat/test_completions.py

Lines changed: 58 additions & 1 deletion
@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable
+from typing import Any, List, Callable
 from typing_extensions import Literal, TypeVar

 import httpx

@@ -260,6 +260,63 @@ class Location(BaseModel):
     )


+@pytest.mark.respx(base_url=base_url)
+@pytest.mark.skipif(not PYDANTIC_V2, reason="dataclasses only supported in v2")
+def test_parse_pydantic_dataclass(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    from pydantic.dataclasses import dataclass
+
+    @dataclass
+    class CalendarEvent:
+        name: str
+        date: str
+        participants: List[str]
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {"role": "system", "content": "Extract the event information."},
+                {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+            ],
+            response_format=CalendarEvent,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3", "object": "chat.completion", "created": 1723761008, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"name\\":\\"Science Fair\\",\\"date\\":\\"Friday\\",\\"participants\\":[\\"Alice\\",\\"Bob\\"]}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 32, "completion_tokens": 17, "total_tokens": 49}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[CalendarEvent](
+    choices=[
+        ParsedChoice[CalendarEvent](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[CalendarEvent](
+                content='{"name":"Science Fair","date":"Friday","participants":["Alice","Bob"]}',
+                function_call=None,
+                parsed=CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1723761008,
+    id='chatcmpl-9wdGqXkJJARAz7rOrLH5u5FBwLjF3',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=17, prompt_tokens=32, total_tokens=49)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     completion = _make_snapshot_request(