Skip to content

Commit c36ff77

Browse files
feat(client): support parsing custom response types (#277)
1 parent 85c0e90 commit c36ff77

File tree

7 files changed

+393
-78
lines changed

7 files changed

+393
-78
lines changed

src/finch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._types import NoneType, Transport, ProxiesTypes
55
from ._utils import file_from_path
66
from ._client import Finch, Client, Stream, Timeout, Transport, AsyncFinch, AsyncClient, AsyncStream, RequestOptions
7+
from ._models import BaseModel
78
from ._version import __title__, __version__
89
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
910
from ._exceptions import (
@@ -54,6 +55,7 @@
5455
"Finch",
5556
"AsyncFinch",
5657
"file_from_path",
58+
"BaseModel",
5759
]
5860

5961
_setup_logging()

src/finch/_legacy_response.py

+70-32
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,28 @@
55
import logging
66
import datetime
77
import functools
8-
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast
9-
from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin
8+
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
9+
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
1010

1111
import anyio
1212
import httpx
13+
import pydantic
1314

1415
from ._types import NoneType
1516
from ._utils import is_given
1617
from ._models import BaseModel, is_basemodel
1718
from ._constants import RAW_RESPONSE_HEADER
19+
from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type
1820
from ._exceptions import APIResponseValidationError
1921

2022
if TYPE_CHECKING:
2123
from ._models import FinalRequestOptions
22-
from ._base_client import Stream, BaseClient, AsyncStream
24+
from ._base_client import BaseClient
2325

2426

2527
P = ParamSpec("P")
2628
R = TypeVar("R")
29+
_T = TypeVar("_T")
2730

2831
log: logging.Logger = logging.getLogger(__name__)
2932

@@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):
4346

4447
_cast_to: type[R]
4548
_client: BaseClient[Any, Any]
46-
_parsed: R | None
49+
_parsed_by_type: dict[type[Any], Any]
4750
_stream: bool
4851
_stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None
4952
_options: FinalRequestOptions
@@ -62,27 +65,60 @@ def __init__(
6265
) -> None:
6366
self._cast_to = cast_to
6467
self._client = client
65-
self._parsed = None
68+
self._parsed_by_type = {}
6669
self._stream = stream
6770
self._stream_cls = stream_cls
6871
self._options = options
6972
self.http_response = raw
7073

74+
@overload
75+
def parse(self, *, to: type[_T]) -> _T:
76+
...
77+
78+
@overload
7179
def parse(self) -> R:
80+
...
81+
82+
def parse(self, *, to: type[_T] | None = None) -> R | _T:
7283
"""Returns the rich python representation of this response's data.
7384
85+
NOTE: For the async client: this will become a coroutine in the next major version.
86+
7487
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
7588
76-
NOTE: For the async client: this will become a coroutine in the next major version.
89+
You can customise the type that the response is parsed into through
90+
the `to` argument, e.g.
91+
92+
```py
93+
from finch import BaseModel
94+
95+
96+
class MyModel(BaseModel):
97+
foo: str
98+
99+
100+
obj = response.parse(to=MyModel)
101+
print(obj.foo)
102+
```
103+
104+
We support parsing:
105+
- `BaseModel`
106+
- `dict`
107+
- `list`
108+
- `Union`
109+
- `str`
110+
- `httpx.Response`
77111
"""
78-
if self._parsed is not None:
79-
return self._parsed
112+
cache_key = to if to is not None else self._cast_to
113+
cached = self._parsed_by_type.get(cache_key)
114+
if cached is not None:
115+
return cached # type: ignore[no-any-return]
80116

81-
parsed = self._parse()
117+
parsed = self._parse(to=to)
82118
if is_given(self._options.post_parser):
83119
parsed = self._options.post_parser(parsed)
84120

85-
self._parsed = parsed
121+
self._parsed_by_type[cache_key] = parsed
86122
return parsed
87123

88124
@property
@@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta:
135171
"""The time taken for the complete request/response cycle to complete."""
136172
return self.http_response.elapsed
137173

138-
def _parse(self) -> R:
174+
def _parse(self, *, to: type[_T] | None = None) -> R | _T:
139175
if self._stream:
176+
if to:
177+
if not is_stream_class_type(to):
178+
raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}")
179+
180+
return cast(
181+
_T,
182+
to(
183+
cast_to=extract_stream_chunk_type(
184+
to,
185+
failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]",
186+
),
187+
response=self.http_response,
188+
client=cast(Any, self._client),
189+
),
190+
)
191+
140192
if self._stream_cls:
141193
return cast(
142194
R,
143195
self._stream_cls(
144-
cast_to=_extract_stream_chunk_type(self._stream_cls),
196+
cast_to=extract_stream_chunk_type(self._stream_cls),
145197
response=self.http_response,
146198
client=cast(Any, self._client),
147199
),
@@ -160,7 +212,7 @@ def _parse(self) -> R:
160212
),
161213
)
162214

163-
cast_to = self._cast_to
215+
cast_to = to if to is not None else self._cast_to
164216
if cast_to is NoneType:
165217
return cast(R, None)
166218

@@ -186,14 +238,9 @@ def _parse(self) -> R:
186238
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
187239
return cast(R, response)
188240

189-
# The check here is necessary as we are subverting the the type system
190-
# with casts as the relationship between TypeVars and Types are very strict
191-
# which means we must return *exactly* what was input or transform it in a
192-
# way that retains the TypeVar state. As we cannot do that in this function
193-
# then we have to resort to using `cast`. At the time of writing, we know this
194-
# to be safe as we have handled all the types that could be bound to the
195-
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
196-
# this function would become unsafe but a type checker would not report an error.
241+
if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
242+
raise TypeError("Pydantic models must subclass our base model type, e.g. `from finch import BaseModel`")
243+
197244
if (
198245
cast_to is not object
199246
and not origin is list
@@ -202,12 +249,12 @@ def _parse(self) -> R:
202249
and not issubclass(origin, BaseModel)
203250
):
204251
raise RuntimeError(
205-
f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
252+
f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."
206253
)
207254

208255
# split is required to handle cases where additional information is included
209256
# in the response, e.g. application/json; charset=utf-8
210-
content_type, *_ = response.headers.get("content-type").split(";")
257+
content_type, *_ = response.headers.get("content-type", "*").split(";")
211258
if content_type != "application/json":
212259
if is_basemodel(cast_to):
213260
try:
@@ -253,15 +300,6 @@ def __init__(self) -> None:
253300
)
254301

255302

256-
def _extract_stream_chunk_type(stream_cls: type) -> type:
257-
args = get_args(stream_cls)
258-
if not args:
259-
raise TypeError(
260-
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
261-
)
262-
return cast(type, args[0])
263-
264-
265303
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
266304
"""Higher order function that takes one of our bound API methods and wraps it
267305
to support returning the raw `APIResponse` object directly.

0 commit comments

Comments
 (0)