Skip to content

Commit f3e7e76

Browse files
feat(client): support passing BaseModels to request params at runtime (#166)
1 parent 6322114 commit f3e7e76

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

src/finch/_utils/_transform.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import date, datetime
55
from typing_extensions import Literal, get_args, override, get_type_hints
66

7+
import pydantic
8+
79
from ._utils import (
810
is_list,
911
is_mapping,
@@ -14,7 +16,7 @@
1416
is_annotated_type,
1517
strip_annotated_type,
1618
)
17-
from .._compat import is_typeddict
19+
from .._compat import model_dump, is_typeddict
1820

1921
_T = TypeVar("_T")
2022

@@ -165,6 +167,9 @@ def _transform_recursive(
165167
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
166168
return data
167169

170+
if isinstance(data, pydantic.BaseModel):
171+
return model_dump(data, exclude_unset=True, exclude_defaults=True)
172+
168173
return _transform_value(data, annotation)
169174

170175

tests/test_transform.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import List, Union, Optional
3+
from typing import Any, List, Union, Optional
44
from datetime import date, datetime
55
from typing_extensions import Required, Annotated, TypedDict
66

7+
import pytest
8+
79
from finch._utils import PropertyInfo, transform, parse_datetime
10+
from finch._models import BaseModel
811

912

1013
class Foo1(TypedDict):
@@ -186,3 +189,44 @@ class DateDictWithRequiredAlias(TypedDict, total=False):
186189
def test_datetime_with_alias() -> None:
187190
assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap]
188191
assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap]
192+
193+
194+
class MyModel(BaseModel):
195+
foo: str
196+
197+
198+
def test_pydantic_model_to_dictionary() -> None:
199+
assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"}
200+
assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"}
201+
202+
203+
def test_pydantic_empty_model() -> None:
204+
assert transform(MyModel.construct(), Any) == {}
205+
206+
207+
def test_pydantic_unknown_field() -> None:
208+
assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True}
209+
210+
211+
def test_pydantic_mismatched_types() -> None:
212+
model = MyModel.construct(foo=True)
213+
with pytest.warns(UserWarning):
214+
params = transform(model, Any)
215+
assert params == {"foo": True}
216+
217+
218+
def test_pydantic_mismatched_object_type() -> None:
219+
model = MyModel.construct(foo=MyModel.construct(hello="world"))
220+
with pytest.warns(UserWarning):
221+
params = transform(model, Any)
222+
assert params == {"foo": {"hello": "world"}}
223+
224+
225+
class ModelNestedObjects(BaseModel):
226+
nested: MyModel
227+
228+
229+
def test_pydantic_nested_objects() -> None:
230+
model = ModelNestedObjects.construct(nested={"foo": "stainless"})
231+
assert isinstance(model.nested, MyModel)
232+
assert transform(model, Any) == {"nested": {"foo": "stainless"}}

0 commit comments

Comments
 (0)