29
29
overload ,
30
30
)
31
31
from functools import lru_cache
32
- from typing_extensions import Literal , get_args , override , get_origin
32
+ from typing_extensions import Literal , override
33
33
34
34
import anyio
35
35
import httpx
49
49
ModelT ,
50
50
Headers ,
51
51
Timeout ,
52
- NoneType ,
53
52
NotGiven ,
54
53
ResponseT ,
55
54
Transport ,
56
55
AnyMapping ,
56
+ PostParser ,
57
57
ProxiesTypes ,
58
58
RequestFiles ,
59
59
AsyncTransport ,
63
63
)
64
64
from ._utils import is_dict , is_given , is_mapping
65
65
from ._compat import model_copy , model_dump
66
- from ._models import (
67
- BaseModel ,
68
- GenericModel ,
69
- FinalRequestOptions ,
70
- validate_type ,
71
- construct_type ,
66
+ from ._models import GenericModel , FinalRequestOptions , validate_type , construct_type
67
+ from ._response import APIResponse
68
+ from ._constants import (
69
+ DEFAULT_LIMITS ,
70
+ DEFAULT_TIMEOUT ,
71
+ DEFAULT_MAX_RETRIES ,
72
+ RAW_RESPONSE_HEADER ,
72
73
)
73
74
from ._streaming import Stream , AsyncStream
74
- from ._exceptions import (
75
- APIStatusError ,
76
- APITimeoutError ,
77
- APIConnectionError ,
78
- APIResponseValidationError ,
79
- )
75
+ from ._exceptions import APIStatusError , APITimeoutError , APIConnectionError
80
76
81
77
log : logging .Logger = logging .getLogger (__name__ )
82
78
101
97
HTTPX_DEFAULT_TIMEOUT = Timeout (5.0 )
102
98
103
99
104
- # default timeout is 1 minute
105
- DEFAULT_TIMEOUT = Timeout (timeout = 60.0 , connect = 5.0 )
106
- DEFAULT_MAX_RETRIES = 2
107
- DEFAULT_LIMITS = Limits (max_connections = 100 , max_keepalive_connections = 20 )
108
-
109
-
110
- class MissingStreamClassError (TypeError ):
111
- def __init__ (self ) -> None :
112
- super ().__init__ (
113
- "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `finch._streaming` for reference" ,
114
- )
115
-
116
-
117
100
class PageInfo :
118
101
"""Stores the necesary information to build the request to retrieve the next page.
119
102
@@ -182,6 +165,7 @@ def _params_from_url(self, url: URL) -> httpx.QueryParams:
182
165
183
166
def _info_to_options (self , info : PageInfo ) -> FinalRequestOptions :
184
167
options = model_copy (self ._options )
168
+ options ._strip_raw_response_header ()
185
169
186
170
if not isinstance (info .params , NotGiven ):
187
171
options .params = {** options .params , ** info .params }
@@ -260,13 +244,17 @@ def __await__(self) -> Generator[Any, None, AsyncPageT]:
260
244
return self ._get_page ().__await__ ()
261
245
262
246
async def _get_page (self ) -> AsyncPageT :
263
- page = await self ._client .request (self ._page_cls , self ._options )
264
- page ._set_private_attributes ( # pyright: ignore[reportPrivateUsage]
265
- model = self ._model ,
266
- options = self ._options ,
267
- client = self ._client ,
268
- )
269
- return page
247
+ def _parser (resp : AsyncPageT ) -> AsyncPageT :
248
+ resp ._set_private_attributes (
249
+ model = self ._model ,
250
+ options = self ._options ,
251
+ client = self ._client ,
252
+ )
253
+ return resp
254
+
255
+ self ._options .post_parser = _parser
256
+
257
+ return await self ._client .request (self ._page_cls , self ._options )
270
258
271
259
async def __aiter__ (self ) -> AsyncIterator [ModelT ]:
272
260
# https://github.com/microsoft/pyright/issues/3464
@@ -317,9 +305,10 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT:
317
305
318
306
319
307
_HttpxClientT = TypeVar ("_HttpxClientT" , bound = Union [httpx .Client , httpx .AsyncClient ])
308
+ _DefaultStreamT = TypeVar ("_DefaultStreamT" , bound = Union [Stream [Any ], AsyncStream [Any ]])
320
309
321
310
322
- class BaseClient (Generic [_HttpxClientT ]):
311
+ class BaseClient (Generic [_HttpxClientT , _DefaultStreamT ]):
323
312
_client : _HttpxClientT
324
313
_version : str
325
314
_base_url : URL
@@ -330,6 +319,7 @@ class BaseClient(Generic[_HttpxClientT]):
330
319
_transport : Transport | AsyncTransport | None
331
320
_strict_response_validation : bool
332
321
_idempotency_header : str | None
322
+ _default_stream_cls : type [_DefaultStreamT ] | None = None
333
323
334
324
def __init__ (
335
325
self ,
@@ -504,80 +494,28 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
504
494
serialized [key ] = value
505
495
return serialized
506
496
507
- def _extract_stream_chunk_type (self , stream_cls : type ) -> type :
508
- args = get_args (stream_cls )
509
- if not args :
510
- raise TypeError (
511
- f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received { stream_cls } " ,
512
- )
513
- return cast (type , args [0 ])
514
-
515
497
def _process_response (
516
498
self ,
517
499
* ,
518
500
cast_to : Type [ResponseT ],
519
- options : FinalRequestOptions , # noqa: ARG002
501
+ options : FinalRequestOptions ,
520
502
response : httpx .Response ,
503
+ stream : bool ,
504
+ stream_cls : type [Stream [Any ]] | type [AsyncStream [Any ]] | None ,
521
505
) -> ResponseT :
522
- if cast_to is NoneType :
523
- return cast (ResponseT , None )
524
-
525
- if cast_to == str :
526
- return cast (ResponseT , response .text )
527
-
528
- origin = get_origin (cast_to ) or cast_to
529
-
530
- if inspect .isclass (origin ) and issubclass (origin , httpx .Response ):
531
- # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
532
- # and pass that class to our request functions. We cannot change the variance to be either
533
- # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
534
- # the response class ourselves but that is something that should be supported directly in httpx
535
- # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
536
- if cast_to != httpx .Response :
537
- raise ValueError (f"Subclasses of httpx.Response cannot be passed to `cast_to`" )
538
- return cast (ResponseT , response )
539
-
540
- # The check here is necessary as we are subverting the the type system
541
- # with casts as the relationship between TypeVars and Types are very strict
542
- # which means we must return *exactly* what was input or transform it in a
543
- # way that retains the TypeVar state. As we cannot do that in this function
544
- # then we have to resort to using `cast`. At the time of writing, we know this
545
- # to be safe as we have handled all the types that could be bound to the
546
- # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
547
- # this function would become unsafe but a type checker would not report an error.
548
- if (
549
- cast_to is not UnknownResponse
550
- and not origin is list
551
- and not origin is dict
552
- and not origin is Union
553
- and not issubclass (origin , BaseModel )
554
- ):
555
- raise RuntimeError (
556
- f"Invalid state, expected { cast_to } to be a subclass type of { BaseModel } , { dict } , { list } or { Union } ."
557
- )
558
-
559
- # split is required to handle cases where additional information is included
560
- # in the response, e.g. application/json; charset=utf-8
561
- content_type , * _ = response .headers .get ("content-type" ).split (";" )
562
- if content_type != "application/json" :
563
- if self ._strict_response_validation :
564
- raise APIResponseValidationError (
565
- response = response ,
566
- message = f"Expected Content-Type response header to be `application/json` but received `{ content_type } ` instead." ,
567
- body = response .text ,
568
- )
569
-
570
- # If the API responds with content that isn't JSON then we just return
571
- # the (decoded) text without performing any parsing so that you can still
572
- # handle the response however you need to.
573
- return response .text # type: ignore
506
+ api_response = APIResponse (
507
+ raw = response ,
508
+ client = self ,
509
+ cast_to = cast_to ,
510
+ stream = stream ,
511
+ stream_cls = stream_cls ,
512
+ options = options ,
513
+ )
574
514
575
- data = response .json ()
515
+ if response .request .headers .get (RAW_RESPONSE_HEADER ) == "true" :
516
+ return cast (ResponseT , api_response )
576
517
577
- try :
578
- return self ._process_response_data (data = data , cast_to = cast_to , response = response )
579
- except pydantic .ValidationError as err :
580
- raise APIResponseValidationError (response = response , body = data ) from err
518
+ return api_response .parse ()
581
519
582
520
def _process_response_data (
583
521
self ,
@@ -734,7 +672,7 @@ def _idempotency_key(self) -> str:
734
672
return f"stainless-python-retry-{ uuid .uuid4 ()} "
735
673
736
674
737
- class SyncAPIClient (BaseClient [httpx .Client ]):
675
+ class SyncAPIClient (BaseClient [httpx .Client , Stream [ Any ] ]):
738
676
_client : httpx .Client
739
677
_has_custom_http_client : bool
740
678
_default_stream_cls : type [Stream [Any ]] | None = None
@@ -930,23 +868,32 @@ def _request(
930
868
raise self ._make_status_error_from_response (err .response ) from None
931
869
except httpx .TimeoutException as err :
932
870
if retries > 0 :
933
- return self ._retry_request (options , cast_to , retries , stream = stream , stream_cls = stream_cls )
871
+ return self ._retry_request (
872
+ options ,
873
+ cast_to ,
874
+ retries ,
875
+ stream = stream ,
876
+ stream_cls = stream_cls ,
877
+ )
934
878
raise APITimeoutError (request = request ) from err
935
879
except Exception as err :
936
880
if retries > 0 :
937
- return self ._retry_request (options , cast_to , retries , stream = stream , stream_cls = stream_cls )
881
+ return self ._retry_request (
882
+ options ,
883
+ cast_to ,
884
+ retries ,
885
+ stream = stream ,
886
+ stream_cls = stream_cls ,
887
+ )
938
888
raise APIConnectionError (request = request ) from err
939
889
940
- if stream :
941
- if stream_cls :
942
- return stream_cls (cast_to = self ._extract_stream_chunk_type (stream_cls ), response = response , client = self )
943
-
944
- stream_cls = cast ("type[_StreamT] | None" , self ._default_stream_cls )
945
- if stream_cls is None :
946
- raise MissingStreamClassError ()
947
- return stream_cls (cast_to = cast_to , response = response , client = self )
948
-
949
- return self ._process_response (cast_to = cast_to , options = options , response = response )
890
+ return self ._process_response (
891
+ cast_to = cast_to ,
892
+ options = options ,
893
+ response = response ,
894
+ stream = stream ,
895
+ stream_cls = stream_cls ,
896
+ )
950
897
951
898
def _retry_request (
952
899
self ,
@@ -980,13 +927,17 @@ def _request_api_list(
980
927
page : Type [SyncPageT ],
981
928
options : FinalRequestOptions ,
982
929
) -> SyncPageT :
983
- resp = self .request (page , options , stream = False )
984
- resp ._set_private_attributes ( # pyright: ignore[reportPrivateUsage]
985
- client = self ,
986
- model = model ,
987
- options = options ,
988
- )
989
- return resp
930
+ def _parser (resp : SyncPageT ) -> SyncPageT :
931
+ resp ._set_private_attributes (
932
+ client = self ,
933
+ model = model ,
934
+ options = options ,
935
+ )
936
+ return resp
937
+
938
+ options .post_parser = _parser
939
+
940
+ return self .request (page , options , stream = False )
990
941
991
942
@overload
992
943
def get (
@@ -1144,7 +1095,7 @@ def get_api_list(
1144
1095
return self ._request_api_list (model , page , opts )
1145
1096
1146
1097
1147
- class AsyncAPIClient (BaseClient [httpx .AsyncClient ]):
1098
+ class AsyncAPIClient (BaseClient [httpx .AsyncClient , AsyncStream [ Any ] ]):
1148
1099
_client : httpx .AsyncClient
1149
1100
_has_custom_http_client : bool
1150
1101
_default_stream_cls : type [AsyncStream [Any ]] | None = None
@@ -1354,16 +1305,13 @@ async def _request(
1354
1305
return await self ._retry_request (options , cast_to , retries , stream = stream , stream_cls = stream_cls )
1355
1306
raise APIConnectionError (request = request ) from err
1356
1307
1357
- if stream :
1358
- if stream_cls :
1359
- return stream_cls (cast_to = self ._extract_stream_chunk_type (stream_cls ), response = response , client = self )
1360
-
1361
- stream_cls = cast ("type[_AsyncStreamT] | None" , self ._default_stream_cls )
1362
- if stream_cls is None :
1363
- raise MissingStreamClassError ()
1364
- return stream_cls (cast_to = cast_to , response = response , client = self )
1365
-
1366
- return self ._process_response (cast_to = cast_to , options = options , response = response )
1308
+ return self ._process_response (
1309
+ cast_to = cast_to ,
1310
+ options = options ,
1311
+ response = response ,
1312
+ stream = stream ,
1313
+ stream_cls = stream_cls ,
1314
+ )
1367
1315
1368
1316
async def _retry_request (
1369
1317
self ,
@@ -1560,6 +1508,7 @@ def make_request_options(
1560
1508
extra_body : Body | None = None ,
1561
1509
idempotency_key : str | None = None ,
1562
1510
timeout : float | None | NotGiven = NOT_GIVEN ,
1511
+ post_parser : PostParser | NotGiven = NOT_GIVEN ,
1563
1512
) -> RequestOptions :
1564
1513
"""Create a dict of type RequestOptions without keys of NotGiven values."""
1565
1514
options : RequestOptions = {}
@@ -1581,6 +1530,10 @@ def make_request_options(
1581
1530
if idempotency_key is not None :
1582
1531
options ["idempotency_key" ] = idempotency_key
1583
1532
1533
+ if is_given (post_parser ):
1534
+ # internal
1535
+ options ["post_parser" ] = post_parser # type: ignore
1536
+
1584
1537
return options
1585
1538
1586
1539
0 commit comments