Skip to content

feat(client): support passing BaseModels to request params at runtime #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/finch/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints

import pydantic

from ._utils import (
is_list,
is_mapping,
Expand All @@ -14,7 +16,7 @@
is_annotated_type,
strip_annotated_type,
)
from .._compat import is_typeddict
from .._compat import model_dump, is_typeddict

_T = TypeVar("_T")

Expand Down Expand Up @@ -165,6 +167,9 @@ def _transform_recursive(
data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
return data

if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, exclude_defaults=True)

return _transform_value(data, annotation)


Expand Down
46 changes: 45 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from typing import List, Union, Optional
from typing import Any, List, Union, Optional
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict

import pytest

from finch._utils import PropertyInfo, transform, parse_datetime
from finch._models import BaseModel


class Foo1(TypedDict):
Expand Down Expand Up @@ -186,3 +189,44 @@ class DateDictWithRequiredAlias(TypedDict, total=False):
def test_datetime_with_alias() -> None:
assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap]
assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap]


class MyModel(BaseModel):
foo: str


def test_pydantic_model_to_dictionary() -> None:
assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"}
assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"}


def test_pydantic_empty_model() -> None:
assert transform(MyModel.construct(), Any) == {}


def test_pydantic_unknown_field() -> None:
assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True}


def test_pydantic_mismatched_types() -> None:
model = MyModel.construct(foo=True)
with pytest.warns(UserWarning):
params = transform(model, Any)
assert params == {"foo": True}


def test_pydantic_mismatched_object_type() -> None:
model = MyModel.construct(foo=MyModel.construct(hello="world"))
with pytest.warns(UserWarning):
params = transform(model, Any)
assert params == {"foo": {"hello": "world"}}


class ModelNestedObjects(BaseModel):
nested: MyModel


def test_pydantic_nested_objects() -> None:
model = ModelNestedObjects.construct(nested={"foo": "stainless"})
assert isinstance(model.nested, MyModel)
assert transform(model, Any) == {"nested": {"foo": "stainless"}}