Skip to content

feat(client): support accessing raw response objects #154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,25 @@ if response.my_field is None:
print('Got json like {"my_field": null}.')
```

### Accessing raw response data (e.g. headers)

The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call.

```py
from finch import Finch

client = Finch()
page = client.hris.directory.with_raw_response.list()
response = page.individuals[0]

print(response.headers.get('X-My-Header'))

directory = response.parse() # get the object that `hris.directory.list()` would have returned
print(directory.first_name)
```

These methods return an [`APIResponse`](https://github.com/Finch-API/finch-api-python/src/finch/_response.py) object.

### Configuring the HTTP client

You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including:
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ format = { chain = [

typecheck = { chain = [
"typecheck:pyright",
"typecheck:verify-types",
"typecheck:mypy"
]}
"typecheck:pyright" = "pyright"
Expand Down
217 changes: 85 additions & 132 deletions src/finch/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
overload,
)
from functools import lru_cache
from typing_extensions import Literal, get_args, override, get_origin
from typing_extensions import Literal, override

import anyio
import httpx
Expand All @@ -49,11 +49,11 @@
ModelT,
Headers,
Timeout,
NoneType,
NotGiven,
ResponseT,
Transport,
AnyMapping,
PostParser,
ProxiesTypes,
RequestFiles,
AsyncTransport,
Expand All @@ -63,20 +63,16 @@
)
from ._utils import is_dict, is_given, is_mapping
from ._compat import model_copy, model_dump
from ._models import (
BaseModel,
GenericModel,
FinalRequestOptions,
validate_type,
construct_type,
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import APIResponse
from ._constants import (
DEFAULT_LIMITS,
DEFAULT_TIMEOUT,
DEFAULT_MAX_RETRIES,
RAW_RESPONSE_HEADER,
)
from ._streaming import Stream, AsyncStream
from ._exceptions import (
APIStatusError,
APITimeoutError,
APIConnectionError,
APIResponseValidationError,
)
from ._exceptions import APIStatusError, APITimeoutError, APIConnectionError

log: logging.Logger = logging.getLogger(__name__)

Expand All @@ -101,19 +97,6 @@
HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)


# default timeout is 1 minute
DEFAULT_TIMEOUT = Timeout(timeout=60.0, connect=5.0)
DEFAULT_MAX_RETRIES = 2
DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20)


class MissingStreamClassError(TypeError):
def __init__(self) -> None:
super().__init__(
"The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `finch._streaming` for reference",
)


class PageInfo:
"""Stores the necesary information to build the request to retrieve the next page.

Expand Down Expand Up @@ -182,6 +165,7 @@ def _params_from_url(self, url: URL) -> httpx.QueryParams:

def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options = model_copy(self._options)
options._strip_raw_response_header()

if not isinstance(info.params, NotGiven):
options.params = {**options.params, **info.params}
Expand Down Expand Up @@ -260,13 +244,17 @@ def __await__(self) -> Generator[Any, None, AsyncPageT]:
return self._get_page().__await__()

async def _get_page(self) -> AsyncPageT:
page = await self._client.request(self._page_cls, self._options)
page._set_private_attributes( # pyright: ignore[reportPrivateUsage]
model=self._model,
options=self._options,
client=self._client,
)
return page
def _parser(resp: AsyncPageT) -> AsyncPageT:
resp._set_private_attributes(
model=self._model,
options=self._options,
client=self._client,
)
return resp

self._options.post_parser = _parser

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
# https://github.com/microsoft/pyright/issues/3464
Expand Down Expand Up @@ -317,9 +305,10 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT:


_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])


class BaseClient(Generic[_HttpxClientT]):
class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_client: _HttpxClientT
_version: str
_base_url: URL
Expand All @@ -330,6 +319,7 @@ class BaseClient(Generic[_HttpxClientT]):
_transport: Transport | AsyncTransport | None
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None

def __init__(
self,
Expand Down Expand Up @@ -504,80 +494,28 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
serialized[key] = value
return serialized

def _extract_stream_chunk_type(self, stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])

def _process_response(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions, # noqa: ARG002
options: FinalRequestOptions,
response: httpx.Response,
stream: bool,
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
) -> ResponseT:
if cast_to is NoneType:
return cast(ResponseT, None)

if cast_to == str:
return cast(ResponseT, response.text)

origin = get_origin(cast_to) or cast_to

if inspect.isclass(origin) and issubclass(origin, httpx.Response):
# Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
# and pass that class to our request functions. We cannot change the variance to be either
# covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
# the response class ourselves but that is something that should be supported directly in httpx
# as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
if cast_to != httpx.Response:
raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`")
return cast(ResponseT, response)

# The check here is necessary as we are subverting the the type system
# with casts as the relationship between TypeVars and Types are very strict
# which means we must return *exactly* what was input or transform it in a
# way that retains the TypeVar state. As we cannot do that in this function
# then we have to resort to using `cast`. At the time of writing, we know this
# to be safe as we have handled all the types that could be bound to the
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
and not origin is list
and not origin is dict
and not origin is Union
and not issubclass(origin, BaseModel)
):
raise RuntimeError(
f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}."
)

# split is required to handle cases where additional information is included
# in the response, e.g. application/json; charset=utf-8
content_type, *_ = response.headers.get("content-type").split(";")
if content_type != "application/json":
if self._strict_response_validation:
raise APIResponseValidationError(
response=response,
message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",
body=response.text,
)

# If the API responds with content that isn't JSON then we just return
# the (decoded) text without performing any parsing so that you can still
# handle the response however you need to.
return response.text # type: ignore
api_response = APIResponse(
raw=response,
client=self,
cast_to=cast_to,
stream=stream,
stream_cls=stream_cls,
options=options,
)

data = response.json()
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
return cast(ResponseT, api_response)

try:
return self._process_response_data(data=data, cast_to=cast_to, response=response)
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err
return api_response.parse()

def _process_response_data(
self,
Expand Down Expand Up @@ -734,7 +672,7 @@ def _idempotency_key(self) -> str:
return f"stainless-python-retry-{uuid.uuid4()}"


class SyncAPIClient(BaseClient[httpx.Client]):
class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
_client: httpx.Client
_has_custom_http_client: bool
_default_stream_cls: type[Stream[Any]] | None = None
Expand Down Expand Up @@ -930,23 +868,32 @@ def _request(
raise self._make_status_error_from_response(err.response) from None
except httpx.TimeoutException as err:
if retries > 0:
return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
)
raise APITimeoutError(request=request) from err
except Exception as err:
if retries > 0:
return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
return self._retry_request(
options,
cast_to,
retries,
stream=stream,
stream_cls=stream_cls,
)
raise APIConnectionError(request=request) from err

if stream:
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_StreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)

return self._process_response(cast_to=cast_to, options=options, response=response)
return self._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
)

def _retry_request(
self,
Expand Down Expand Up @@ -980,13 +927,17 @@ def _request_api_list(
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
resp = self.request(page, options, stream=False)
resp._set_private_attributes( # pyright: ignore[reportPrivateUsage]
client=self,
model=model,
options=options,
)
return resp
def _parser(resp: SyncPageT) -> SyncPageT:
resp._set_private_attributes(
client=self,
model=model,
options=options,
)
return resp

options.post_parser = _parser

return self.request(page, options, stream=False)

@overload
def get(
Expand Down Expand Up @@ -1144,7 +1095,7 @@ def get_api_list(
return self._request_api_list(model, page, opts)


class AsyncAPIClient(BaseClient[httpx.AsyncClient]):
class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
_client: httpx.AsyncClient
_has_custom_http_client: bool
_default_stream_cls: type[AsyncStream[Any]] | None = None
Expand Down Expand Up @@ -1354,16 +1305,13 @@ async def _request(
return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls)
raise APIConnectionError(request=request) from err

if stream:
if stream_cls:
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)

stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls)
if stream_cls is None:
raise MissingStreamClassError()
return stream_cls(cast_to=cast_to, response=response, client=self)

return self._process_response(cast_to=cast_to, options=options, response=response)
return self._process_response(
cast_to=cast_to,
options=options,
response=response,
stream=stream,
stream_cls=stream_cls,
)

async def _retry_request(
self,
Expand Down Expand Up @@ -1560,6 +1508,7 @@ def make_request_options(
extra_body: Body | None = None,
idempotency_key: str | None = None,
timeout: float | None | NotGiven = NOT_GIVEN,
post_parser: PostParser | NotGiven = NOT_GIVEN,
) -> RequestOptions:
"""Create a dict of type RequestOptions without keys of NotGiven values."""
options: RequestOptions = {}
Expand All @@ -1581,6 +1530,10 @@ def make_request_options(
if idempotency_key is not None:
options["idempotency_key"] = idempotency_key

if is_given(post_parser):
# internal
options["post_parser"] = post_parser # type: ignore

return options


Expand Down
Loading