Skip to content

chore(internal): loosen type var restrictions #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
Body,
Omit,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
Expand All @@ -61,7 +60,6 @@
HttpxSendArgs,
AsyncTransport,
RequestOptions,
UnknownResponse,
ModelBuilderProtocol,
BinaryResponseContent,
)
Expand Down Expand Up @@ -142,7 +140,7 @@ def __init__(
self.params = params


class BasePage(GenericModel, Generic[ModelT]):
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.

Expand All @@ -155,7 +153,7 @@ class BasePage(GenericModel, Generic[ModelT]):
"""

_options: FinalRequestOptions = PrivateAttr()
_model: Type[ModelT] = PrivateAttr()
_model: Type[_T] = PrivateAttr()

def has_next_page(self) -> bool:
items = self._get_page_items()
Expand All @@ -166,7 +164,7 @@ def has_next_page(self) -> bool:
def next_page_info(self) -> Optional[PageInfo]:
...

def _get_page_items(self) -> Iterable[ModelT]: # type: ignore[empty-body]
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...

def _params_from_url(self, url: URL) -> httpx.QueryParams:
Expand All @@ -191,13 +189,13 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
raise ValueError("Unexpected PageInfo state")


class BaseSyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: SyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
client: SyncAPIClient,
model: Type[ModelT],
model: Type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
Expand All @@ -212,7 +210,7 @@ def _set_private_attributes(
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[ModelT]: # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand All @@ -237,13 +235,13 @@ def get_next_page(self: SyncPageT) -> SyncPageT:
return self._client._request_api_list(self._model, page=self.__class__, options=options)


class AsyncPaginator(Generic[ModelT, AsyncPageT]):
class AsyncPaginator(Generic[_T, AsyncPageT]):
def __init__(
self,
client: AsyncAPIClient,
options: FinalRequestOptions,
page_cls: Type[AsyncPageT],
model: Type[ModelT],
model: Type[_T],
) -> None:
self._model = model
self._client = client
Expand All @@ -266,7 +264,7 @@ def _parser(resp: AsyncPageT) -> AsyncPageT:

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
# https://github.com/microsoft/pyright/issues/3464
page = cast(
AsyncPageT,
Expand All @@ -276,20 +274,20 @@ async def __aiter__(self) -> AsyncIterator[ModelT]:
yield item


class BaseAsyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseAsyncPage(BasePage[_T], Generic[_T]):
_client: AsyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
model: Type[ModelT],
model: Type[_T],
client: AsyncAPIClient,
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand Down Expand Up @@ -528,7 +526,7 @@ def _process_response_data(
if data is None:
return cast(ResponseT, None)

if cast_to is UnknownResponse:
if cast_to is object:
return cast(ResponseT, data)

try:
Expand Down Expand Up @@ -970,7 +968,7 @@ def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
Expand Down Expand Up @@ -1132,7 +1130,7 @@ def get_api_list(
self,
path: str,
*,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
body: Body | None = None,
options: RequestOptions = {},
Expand Down Expand Up @@ -1434,10 +1432,10 @@ async def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
options: FinalRequestOptions,
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
return AsyncPaginator(client=self, options=options, page_cls=page, model=model)

@overload
Expand Down Expand Up @@ -1584,13 +1582,12 @@ def get_api_list(
self,
path: str,
*,
# TODO: support paginating `str`
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
body: Body | None = None,
options: RequestOptions = {},
method: str = "get",
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)

Expand Down
4 changes: 2 additions & 2 deletions src/finch/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._types import NoneType, BinaryResponseContent
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
Expand Down Expand Up @@ -162,7 +162,7 @@ def _parse(self) -> R:
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
Expand Down
17 changes: 11 additions & 6 deletions src/finch/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,6 @@ class RequestOptions(TypedDict, total=False):
idempotency_key: str


# Sentinel class used when the response type is an object with an unknown schema
class UnknownResponse:
...


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
Expand Down Expand Up @@ -339,7 +334,17 @@ def get(self, __key: str) -> str | None:

ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
BinaryResponseContent,
],
)

StrBytesIntFloat = Union[str, bytes, int, float]
Expand Down
51 changes: 26 additions & 25 deletions src/finch/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from httpx import Response

from ._types import ModelT
from ._utils import is_mapping
from ._models import BaseModel
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage
Expand All @@ -24,12 +23,14 @@

_BaseModelT = TypeVar("_BaseModelT", bound=BaseModel)

_T = TypeVar("_T")

class SyncSinglePage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
items: List[ModelT]

class SyncSinglePage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
items: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.items

@override
Expand All @@ -50,11 +51,11 @@ def build(cls: Type[_BaseModelT], *, response: Response, data: object) -> _BaseM
)


class AsyncSinglePage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
items: List[ModelT]
class AsyncSinglePage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
items: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.items

@override
Expand All @@ -75,11 +76,11 @@ def build(cls: Type[_BaseModelT], *, response: Response, data: object) -> _BaseM
)


class SyncResponsesPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
responses: List[ModelT]
class SyncResponsesPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
responses: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.responses

@override
Expand All @@ -91,11 +92,11 @@ def next_page_info(self) -> None:
return None


class AsyncResponsesPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
responses: List[ModelT]
class AsyncResponsesPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
responses: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.responses

@override
Expand All @@ -107,12 +108,12 @@ def next_page_info(self) -> None:
return None


class SyncIndividualsPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
individuals: List[ModelT]
class SyncIndividualsPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
individuals: List[_T]
paging: Paging

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.individuals

@override
Expand All @@ -134,12 +135,12 @@ def next_page_info(self) -> Optional[PageInfo]:
return None


class AsyncIndividualsPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
individuals: List[ModelT]
class AsyncIndividualsPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
individuals: List[_T]
paging: Paging

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.individuals

@override
Expand All @@ -161,12 +162,12 @@ def next_page_info(self) -> Optional[PageInfo]:
return None


class SyncPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
paging: Paging
data: List[ModelT]
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.data

@override
Expand All @@ -188,12 +189,12 @@ def next_page_info(self) -> Optional[PageInfo]:
return None


class AsyncPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class AsyncPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
paging: Paging
data: List[ModelT]
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
return self.data

@override
Expand Down
8 changes: 1 addition & 7 deletions src/finch/resources/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,7 @@
import httpx

from ..types import Introspection, DisconnectResponse
from .._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
Expand Down
8 changes: 1 addition & 7 deletions src/finch/resources/hris/benefits/benefits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import httpx

from ...._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._compat import cached_property
from .individuals import Individuals, AsyncIndividuals, IndividualsWithRawResponse, AsyncIndividualsWithRawResponse
Expand Down
Loading