Skip to content

Commit 222e949

Browse files
chore(internal): minor utils restructuring (#232)
1 parent f435145 commit 222e949

File tree

8 files changed

+183
-67
lines changed

8 files changed

+183
-67
lines changed

src/finch/_response.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import datetime
66
import functools
77
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
8-
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
8+
from typing_extensions import Awaitable, ParamSpec, override, get_origin
99

1010
import httpx
1111

1212
from ._types import NoneType, UnknownResponse, BinaryResponseContent
13-
from ._utils import is_given
13+
from ._utils import is_given, extract_type_var_from_base
1414
from ._models import BaseModel, is_basemodel
1515
from ._constants import RAW_RESPONSE_HEADER
1616
from ._exceptions import APIResponseValidationError
@@ -221,12 +221,13 @@ def __init__(self) -> None:
221221

222222

223223
def _extract_stream_chunk_type(stream_cls: type) -> type:
224-
args = get_args(stream_cls)
225-
if not args:
226-
raise TypeError(
227-
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
228-
)
229-
return cast(type, args[0])
224+
from ._base_client import Stream, AsyncStream
225+
226+
return extract_type_var_from_base(
227+
stream_cls,
228+
index=0,
229+
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
230+
)
230231

231232

232233
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:

src/finch/_streaming.py

+56-16
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,28 @@
22
from __future__ import annotations
33

44
import json
5-
from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
6-
from typing_extensions import override
5+
from types import TracebackType
6+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
7+
from typing_extensions import Self, override
78

89
import httpx
910

10-
from ._types import ResponseT
11-
1211
if TYPE_CHECKING:
1312
from ._client import Finch, AsyncFinch
1413

1514

16-
class Stream(Generic[ResponseT]):
15+
_T = TypeVar("_T")
16+
17+
18+
class Stream(Generic[_T]):
1719
"""Provides the core interface to iterate over a synchronous stream response."""
1820

1921
response: httpx.Response
2022

2123
def __init__(
2224
self,
2325
*,
24-
cast_to: type[ResponseT],
26+
cast_to: type[_T],
2527
response: httpx.Response,
2628
client: Finch,
2729
) -> None:
@@ -31,18 +33,18 @@ def __init__(
3133
self._decoder = SSEDecoder()
3234
self._iterator = self.__stream__()
3335

34-
def __next__(self) -> ResponseT:
36+
def __next__(self) -> _T:
3537
return self._iterator.__next__()
3638

37-
def __iter__(self) -> Iterator[ResponseT]:
39+
def __iter__(self) -> Iterator[_T]:
3840
for item in self._iterator:
3941
yield item
4042

4143
def _iter_events(self) -> Iterator[ServerSentEvent]:
4244
yield from self._decoder.iter(self.response.iter_lines())
4345

44-
def __stream__(self) -> Iterator[ResponseT]:
45-
cast_to = self._cast_to
46+
def __stream__(self) -> Iterator[_T]:
47+
cast_to = cast(Any, self._cast_to)
4648
response = self.response
4749
process_data = self._client._process_response_data
4850
iterator = self._iter_events()
@@ -54,16 +56,35 @@ def __stream__(self) -> Iterator[ResponseT]:
5456
for _sse in iterator:
5557
...
5658

59+
def __enter__(self) -> Self:
60+
return self
61+
62+
def __exit__(
63+
self,
64+
exc_type: type[BaseException] | None,
65+
exc: BaseException | None,
66+
exc_tb: TracebackType | None,
67+
) -> None:
68+
self.close()
69+
70+
def close(self) -> None:
71+
"""
72+
Close the response and release the connection.
73+
74+
Automatically called if the response body is read to completion.
75+
"""
76+
self.response.close()
5777

58-
class AsyncStream(Generic[ResponseT]):
78+
79+
class AsyncStream(Generic[_T]):
5980
"""Provides the core interface to iterate over an asynchronous stream response."""
6081

6182
response: httpx.Response
6283

6384
def __init__(
6485
self,
6586
*,
66-
cast_to: type[ResponseT],
87+
cast_to: type[_T],
6788
response: httpx.Response,
6889
client: AsyncFinch,
6990
) -> None:
@@ -73,19 +94,19 @@ def __init__(
7394
self._decoder = SSEDecoder()
7495
self._iterator = self.__stream__()
7596

76-
async def __anext__(self) -> ResponseT:
97+
async def __anext__(self) -> _T:
7798
return await self._iterator.__anext__()
7899

79-
async def __aiter__(self) -> AsyncIterator[ResponseT]:
100+
async def __aiter__(self) -> AsyncIterator[_T]:
80101
async for item in self._iterator:
81102
yield item
82103

83104
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
84105
async for sse in self._decoder.aiter(self.response.aiter_lines()):
85106
yield sse
86107

87-
async def __stream__(self) -> AsyncIterator[ResponseT]:
88-
cast_to = self._cast_to
108+
async def __stream__(self) -> AsyncIterator[_T]:
109+
cast_to = cast(Any, self._cast_to)
89110
response = self.response
90111
process_data = self._client._process_response_data
91112
iterator = self._iter_events()
@@ -97,6 +118,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
97118
async for _sse in iterator:
98119
...
99120

121+
async def __aenter__(self) -> Self:
122+
return self
123+
124+
async def __aexit__(
125+
self,
126+
exc_type: type[BaseException] | None,
127+
exc: BaseException | None,
128+
exc_tb: TracebackType | None,
129+
) -> None:
130+
await self.close()
131+
132+
async def close(self) -> None:
133+
"""
134+
Close the response and release the connection.
135+
136+
Automatically called if the response body is read to completion.
137+
"""
138+
await self.response.aclose()
139+
100140

101141
class ServerSentEvent:
102142
def __init__(

src/finch/_types.py

+14
Original file line numberDiff line numberDiff line change
@@ -353,3 +353,17 @@ def get(self, __key: str) -> str | None:
353353
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"
354354

355355
PostParser = Callable[[Any], Any]
356+
357+
358+
@runtime_checkable
359+
class InheritsGeneric(Protocol):
360+
"""Represents a type that has inherited from `Generic`
361+
The `__orig_bases__` property can be used to determine the resolved
362+
type variable for a given base class.
363+
"""
364+
365+
__orig_bases__: tuple[_GenericAlias]
366+
367+
368+
class _GenericAlias(Protocol):
369+
__origin__: type[object]

src/finch/_utils/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,32 @@
99
from ._utils import parse_date as parse_date
1010
from ._utils import is_sequence as is_sequence
1111
from ._utils import coerce_float as coerce_float
12-
from ._utils import is_list_type as is_list_type
1312
from ._utils import is_mapping_t as is_mapping_t
1413
from ._utils import removeprefix as removeprefix
1514
from ._utils import removesuffix as removesuffix
1615
from ._utils import extract_files as extract_files
1716
from ._utils import is_sequence_t as is_sequence_t
18-
from ._utils import is_union_type as is_union_type
1917
from ._utils import required_args as required_args
2018
from ._utils import coerce_boolean as coerce_boolean
2119
from ._utils import coerce_integer as coerce_integer
2220
from ._utils import file_from_path as file_from_path
2321
from ._utils import parse_datetime as parse_datetime
2422
from ._utils import strip_not_given as strip_not_given
2523
from ._utils import deepcopy_minimal as deepcopy_minimal
26-
from ._utils import extract_type_arg as extract_type_arg
27-
from ._utils import is_required_type as is_required_type
2824
from ._utils import get_async_library as get_async_library
29-
from ._utils import is_annotated_type as is_annotated_type
3025
from ._utils import maybe_coerce_float as maybe_coerce_float
3126
from ._utils import get_required_header as get_required_header
3227
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
3328
from ._utils import maybe_coerce_integer as maybe_coerce_integer
34-
from ._utils import strip_annotated_type as strip_annotated_type
29+
from ._typing import is_list_type as is_list_type
30+
from ._typing import is_union_type as is_union_type
31+
from ._typing import extract_type_arg as extract_type_arg
32+
from ._typing import is_required_type as is_required_type
33+
from ._typing import is_annotated_type as is_annotated_type
34+
from ._typing import strip_annotated_type as strip_annotated_type
35+
from ._typing import extract_type_var_from_base as extract_type_var_from_base
36+
from ._streams import consume_sync_iterator as consume_sync_iterator
37+
from ._streams import consume_async_iterator as consume_async_iterator
3538
from ._transform import PropertyInfo as PropertyInfo
3639
from ._transform import transform as transform
3740
from ._transform import maybe_transform as maybe_transform

src/finch/_utils/_streams.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any
2+
from typing_extensions import Iterator, AsyncIterator
3+
4+
5+
def consume_sync_iterator(iterator: Iterator[Any]) -> None:
6+
for _ in iterator:
7+
...
8+
9+
10+
async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
11+
async for _ in iterator:
12+
...

src/finch/_utils/_transform.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66

77
import pydantic
88

9-
from ._utils import (
10-
is_list,
11-
is_mapping,
9+
from ._utils import is_list, is_mapping
10+
from ._typing import (
1211
is_list_type,
1312
is_union_type,
1413
extract_type_arg,

src/finch/_utils/_typing.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
from typing_extensions import Required, Annotated, get_args, get_origin
5+
6+
from .._types import InheritsGeneric
7+
from .._compat import is_union as _is_union
8+
9+
10+
def is_annotated_type(typ: type) -> bool:
11+
return get_origin(typ) == Annotated
12+
13+
14+
def is_list_type(typ: type) -> bool:
15+
return (get_origin(typ) or typ) == list
16+
17+
18+
def is_union_type(typ: type) -> bool:
19+
return _is_union(get_origin(typ))
20+
21+
22+
def is_required_type(typ: type) -> bool:
23+
return get_origin(typ) == Required
24+
25+
26+
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
27+
def strip_annotated_type(typ: type) -> type:
28+
if is_required_type(typ) or is_annotated_type(typ):
29+
return strip_annotated_type(cast(type, get_args(typ)[0]))
30+
31+
return typ
32+
33+
34+
def extract_type_arg(typ: type, index: int) -> type:
35+
args = get_args(typ)
36+
try:
37+
return cast(type, args[index])
38+
except IndexError as err:
39+
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
40+
41+
42+
def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
43+
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
44+
45+
This also handles the case where a concrete subclass is given, e.g.
46+
```py
47+
class MyResponse(Foo[bytes]):
48+
...
49+
50+
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
51+
```
52+
"""
53+
cls = cast(object, get_origin(typ) or typ)
54+
if cls in generic_bases:
55+
# we're given the class directly
56+
return extract_type_arg(typ, index)
57+
58+
# if a subclass is given
59+
# ---
60+
# this is needed as __orig_bases__ is not present in the typeshed stubs
61+
# because it is intended to be for internal use only, however there does
62+
# not seem to be a way to resolve generic TypeVars for inherited subclasses
63+
# without using it.
64+
if isinstance(cls, InheritsGeneric):
65+
target_base_class: Any | None = None
66+
for base in cls.__orig_bases__:
67+
if base.__origin__ in generic_bases:
68+
target_base_class = base
69+
break
70+
71+
if target_base_class is None:
72+
raise RuntimeError(
73+
"Could not find the generic base class;\n"
74+
"This should never happen;\n"
75+
f"Does {cls} inherit from one of {generic_bases} ?"
76+
)
77+
78+
return extract_type_arg(target_base_class, index)
79+
80+
raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")

src/finch/_utils/_utils.py

+1-34
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
overload,
1717
)
1818
from pathlib import Path
19-
from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
19+
from typing_extensions import TypeGuard
2020

2121
import sniffio
2222

2323
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
24-
from .._compat import is_union as _is_union
2524
from .._compat import parse_date as parse_date
2625
from .._compat import parse_datetime as parse_datetime
2726

@@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
166165
return isinstance(obj, list)
167166

168167

169-
def is_annotated_type(typ: type) -> bool:
170-
return get_origin(typ) == Annotated
171-
172-
173-
def is_list_type(typ: type) -> bool:
174-
return (get_origin(typ) or typ) == list
175-
176-
177-
def is_union_type(typ: type) -> bool:
178-
return _is_union(get_origin(typ))
179-
180-
181-
def is_required_type(typ: type) -> bool:
182-
return get_origin(typ) == Required
183-
184-
185-
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
186-
def strip_annotated_type(typ: type) -> type:
187-
if is_required_type(typ) or is_annotated_type(typ):
188-
return strip_annotated_type(cast(type, get_args(typ)[0]))
189-
190-
return typ
191-
192-
193-
def extract_type_arg(typ: type, index: int) -> type:
194-
args = get_args(typ)
195-
try:
196-
return cast(type, args[index])
197-
except IndexError as err:
198-
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err
199-
200-
201168
def deepcopy_minimal(item: _T) -> _T:
202169
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
203170

0 commit comments

Comments
 (0)