5
5
import logging
6
6
import datetime
7
7
import functools
8
- from typing import TYPE_CHECKING , Any , Union , Generic , TypeVar , Callable , Iterator , AsyncIterator , cast
9
- from typing_extensions import Awaitable , ParamSpec , get_args , override , deprecated , get_origin
8
+ from typing import TYPE_CHECKING , Any , Union , Generic , TypeVar , Callable , Iterator , AsyncIterator , cast , overload
9
+ from typing_extensions import Awaitable , ParamSpec , override , deprecated , get_origin
10
10
11
11
import anyio
12
12
import httpx
13
+ import pydantic
13
14
14
15
from ._types import NoneType
15
16
from ._utils import is_given
16
17
from ._models import BaseModel , is_basemodel
17
18
from ._constants import RAW_RESPONSE_HEADER
19
+ from ._streaming import Stream , AsyncStream , is_stream_class_type , extract_stream_chunk_type
18
20
from ._exceptions import APIResponseValidationError
19
21
20
22
if TYPE_CHECKING :
21
23
from ._models import FinalRequestOptions
22
- from ._base_client import Stream , BaseClient , AsyncStream
24
+ from ._base_client import BaseClient
23
25
24
26
25
27
P = ParamSpec ("P" )
26
28
R = TypeVar ("R" )
29
+ _T = TypeVar ("_T" )
27
30
28
31
log : logging .Logger = logging .getLogger (__name__ )
29
32
@@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]):
43
46
44
47
_cast_to : type [R ]
45
48
_client : BaseClient [Any , Any ]
46
- _parsed : R | None
49
+ _parsed_by_type : dict [ type [ Any ], Any ]
47
50
_stream : bool
48
51
_stream_cls : type [Stream [Any ]] | type [AsyncStream [Any ]] | None
49
52
_options : FinalRequestOptions
@@ -62,27 +65,60 @@ def __init__(
62
65
) -> None :
63
66
self ._cast_to = cast_to
64
67
self ._client = client
65
- self ._parsed = None
68
+ self ._parsed_by_type = {}
66
69
self ._stream = stream
67
70
self ._stream_cls = stream_cls
68
71
self ._options = options
69
72
self .http_response = raw
70
73
74
+ @overload
75
+ def parse (self , * , to : type [_T ]) -> _T :
76
+ ...
77
+
78
+ @overload
71
79
def parse (self ) -> R :
80
+ ...
81
+
82
+ def parse (self , * , to : type [_T ] | None = None ) -> R | _T :
72
83
"""Returns the rich python representation of this response's data.
73
84
85
+ NOTE: For the async client: this will become a coroutine in the next major version.
86
+
74
87
For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.
75
88
76
- NOTE: For the async client: this will become a coroutine in the next major version.
89
+ You can customise the type that the response is parsed into through
90
+ the `to` argument, e.g.
91
+
92
+ ```py
93
+ from finch import BaseModel
94
+
95
+
96
+ class MyModel(BaseModel):
97
+ foo: str
98
+
99
+
100
+ obj = response.parse(to=MyModel)
101
+ print(obj.foo)
102
+ ```
103
+
104
+ We support parsing:
105
+ - `BaseModel`
106
+ - `dict`
107
+ - `list`
108
+ - `Union`
109
+ - `str`
110
+ - `httpx.Response`
77
111
"""
78
- if self ._parsed is not None :
79
- return self ._parsed
112
+ cache_key = to if to is not None else self ._cast_to
113
+ cached = self ._parsed_by_type .get (cache_key )
114
+ if cached is not None :
115
+ return cached # type: ignore[no-any-return]
80
116
81
- parsed = self ._parse ()
117
+ parsed = self ._parse (to = to )
82
118
if is_given (self ._options .post_parser ):
83
119
parsed = self ._options .post_parser (parsed )
84
120
85
- self ._parsed = parsed
121
+ self ._parsed_by_type [ cache_key ] = parsed
86
122
return parsed
87
123
88
124
@property
@@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta:
135
171
"""The time taken for the complete request/response cycle to complete."""
136
172
return self .http_response .elapsed
137
173
138
- def _parse (self ) -> R :
174
+ def _parse (self , * , to : type [ _T ] | None = None ) -> R | _T :
139
175
if self ._stream :
176
+ if to :
177
+ if not is_stream_class_type (to ):
178
+ raise TypeError (f"Expected custom parse type to be a subclass of { Stream } or { AsyncStream } " )
179
+
180
+ return cast (
181
+ _T ,
182
+ to (
183
+ cast_to = extract_stream_chunk_type (
184
+ to ,
185
+ failure_message = "Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]" ,
186
+ ),
187
+ response = self .http_response ,
188
+ client = cast (Any , self ._client ),
189
+ ),
190
+ )
191
+
140
192
if self ._stream_cls :
141
193
return cast (
142
194
R ,
143
195
self ._stream_cls (
144
- cast_to = _extract_stream_chunk_type (self ._stream_cls ),
196
+ cast_to = extract_stream_chunk_type (self ._stream_cls ),
145
197
response = self .http_response ,
146
198
client = cast (Any , self ._client ),
147
199
),
@@ -160,7 +212,7 @@ def _parse(self) -> R:
160
212
),
161
213
)
162
214
163
- cast_to = self ._cast_to
215
+ cast_to = to if to is not None else self ._cast_to
164
216
if cast_to is NoneType :
165
217
return cast (R , None )
166
218
@@ -186,14 +238,9 @@ def _parse(self) -> R:
186
238
raise ValueError (f"Subclasses of httpx.Response cannot be passed to `cast_to`" )
187
239
return cast (R , response )
188
240
189
- # The check here is necessary as we are subverting the the type system
190
- # with casts as the relationship between TypeVars and Types are very strict
191
- # which means we must return *exactly* what was input or transform it in a
192
- # way that retains the TypeVar state. As we cannot do that in this function
193
- # then we have to resort to using `cast`. At the time of writing, we know this
194
- # to be safe as we have handled all the types that could be bound to the
195
- # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
196
- # this function would become unsafe but a type checker would not report an error.
241
+ if inspect .isclass (origin ) and not issubclass (origin , BaseModel ) and issubclass (origin , pydantic .BaseModel ):
242
+ raise TypeError ("Pydantic models must subclass our base model type, e.g. `from finch import BaseModel`" )
243
+
197
244
if (
198
245
cast_to is not object
199
246
and not origin is list
@@ -202,12 +249,12 @@ def _parse(self) -> R:
202
249
and not issubclass (origin , BaseModel )
203
250
):
204
251
raise RuntimeError (
205
- f"Invalid state , expected { cast_to } to be a subclass type of { BaseModel } , { dict } , { list } or { Union } ."
252
+ f"Unsupported type , expected { cast_to } to be a subclass of { BaseModel } , { dict } , { list } , { Union } , { NoneType } , { str } or { httpx . Response } ."
206
253
)
207
254
208
255
# split is required to handle cases where additional information is included
209
256
# in the response, e.g. application/json; charset=utf-8
210
- content_type , * _ = response .headers .get ("content-type" ).split (";" )
257
+ content_type , * _ = response .headers .get ("content-type" , "*" ).split (";" )
211
258
if content_type != "application/json" :
212
259
if is_basemodel (cast_to ):
213
260
try :
@@ -253,15 +300,6 @@ def __init__(self) -> None:
253
300
)
254
301
255
302
256
- def _extract_stream_chunk_type (stream_cls : type ) -> type :
257
- args = get_args (stream_cls )
258
- if not args :
259
- raise TypeError (
260
- f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received { stream_cls } " ,
261
- )
262
- return cast (type , args [0 ])
263
-
264
-
265
303
def to_raw_response_wrapper (func : Callable [P , R ]) -> Callable [P , LegacyAPIResponse [R ]]:
266
304
"""Higher order function that takes one of our bound API methods and wraps it
267
305
to support returning the raw `APIResponse` object directly.
0 commit comments