Skip to content

Commit 6322114

Browse files
feat(client): allow binary returns (#164)
1 parent 39d20a2 commit 6322114

File tree

4 files changed

+273
-7
lines changed

4 files changed

+273
-7
lines changed

src/finch/_base_client.py

+93
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import json
45
import time
56
import uuid
@@ -60,6 +61,7 @@
6061
RequestOptions,
6162
UnknownResponse,
6263
ModelBuilderProtocol,
64+
BinaryResponseContent,
6365
)
6466
from ._utils import is_dict, is_given, is_mapping
6567
from ._compat import model_copy, model_dump
@@ -1672,3 +1674,94 @@ def _merge_mappings(
16721674
"""
16731675
merged = {**obj1, **obj2}
16741676
return {key: value for key, value in merged.items() if not isinstance(value, Omit)}
1677+
1678+
1679+
class HttpxBinaryResponseContent(BinaryResponseContent):
1680+
response: httpx.Response
1681+
1682+
def __init__(self, response: httpx.Response) -> None:
1683+
self.response = response
1684+
1685+
@property
1686+
@override
1687+
def content(self) -> bytes:
1688+
return self.response.content
1689+
1690+
@property
1691+
@override
1692+
def text(self) -> str:
1693+
return self.response.text
1694+
1695+
@property
1696+
@override
1697+
def encoding(self) -> Optional[str]:
1698+
return self.response.encoding
1699+
1700+
@property
1701+
@override
1702+
def charset_encoding(self) -> Optional[str]:
1703+
return self.response.charset_encoding
1704+
1705+
@override
1706+
def json(self, **kwargs: Any) -> Any:
1707+
return self.response.json(**kwargs)
1708+
1709+
@override
1710+
def read(self) -> bytes:
1711+
return self.response.read()
1712+
1713+
@override
1714+
def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
1715+
return self.response.iter_bytes(chunk_size)
1716+
1717+
@override
1718+
def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
1719+
return self.response.iter_text(chunk_size)
1720+
1721+
@override
1722+
def iter_lines(self) -> Iterator[str]:
1723+
return self.response.iter_lines()
1724+
1725+
@override
1726+
def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
1727+
return self.response.iter_raw(chunk_size)
1728+
1729+
@override
1730+
def stream_to_file(self, file: str | os.PathLike[str]) -> None:
1731+
with open(file, mode="wb") as f:
1732+
for data in self.response.iter_bytes():
1733+
f.write(data)
1734+
1735+
@override
1736+
def close(self) -> None:
1737+
return self.response.close()
1738+
1739+
@override
1740+
async def aread(self) -> bytes:
1741+
return await self.response.aread()
1742+
1743+
@override
1744+
async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
1745+
return self.response.aiter_bytes(chunk_size)
1746+
1747+
@override
1748+
async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
1749+
return self.response.aiter_text(chunk_size)
1750+
1751+
@override
1752+
async def aiter_lines(self) -> AsyncIterator[str]:
1753+
return self.response.aiter_lines()
1754+
1755+
@override
1756+
async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
1757+
return self.response.aiter_raw(chunk_size)
1758+
1759+
@override
1760+
async def astream_to_file(self, file: str | os.PathLike[str]) -> None:
1761+
with open(file, mode="wb") as f:
1762+
async for data in self.response.aiter_bytes():
1763+
f.write(data)
1764+
1765+
@override
1766+
async def aclose(self) -> None:
1767+
return await self.response.aclose()

src/finch/_response.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import httpx
1010
import pydantic
1111

12-
from ._types import NoneType, UnknownResponse
12+
from ._types import NoneType, UnknownResponse, BinaryResponseContent
1313
from ._utils import is_given
1414
from ._models import BaseModel
1515
from ._constants import RAW_RESPONSE_HEADER
@@ -135,6 +135,9 @@ def _parse(self) -> R:
135135

136136
origin = get_origin(cast_to) or cast_to
137137

138+
if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent):
139+
return cast(R, cast_to(response)) # type: ignore
140+
138141
if origin == APIResponse:
139142
raise RuntimeError("Unexpected state - cast_to is `APIResponse`")
140143

src/finch/_types.py

+149-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from os import PathLike
4+
from abc import ABC, abstractmethod
45
from typing import (
56
IO,
67
TYPE_CHECKING,
@@ -13,8 +14,10 @@
1314
Mapping,
1415
TypeVar,
1516
Callable,
17+
Iterator,
1618
Optional,
1719
Sequence,
20+
AsyncIterator,
1821
)
1922
from typing_extensions import (
2023
Literal,
@@ -25,7 +28,6 @@
2528
runtime_checkable,
2629
)
2730

28-
import httpx
2931
import pydantic
3032
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
3133

@@ -40,6 +42,151 @@
4042
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
4143
_T = TypeVar("_T")
4244

45+
46+
class BinaryResponseContent(ABC):
47+
def __init__(
48+
self,
49+
response: Any,
50+
) -> None:
51+
...
52+
53+
@property
54+
@abstractmethod
55+
def content(self) -> bytes:
56+
pass
57+
58+
@property
59+
@abstractmethod
60+
def text(self) -> str:
61+
pass
62+
63+
@property
64+
@abstractmethod
65+
def encoding(self) -> Optional[str]:
66+
"""
67+
Return an encoding to use for decoding the byte content into text.
68+
The priority for determining this is given by...
69+
70+
* `.encoding = <>` has been set explicitly.
71+
* The encoding as specified by the charset parameter in the Content-Type header.
72+
* The encoding as determined by `default_encoding`, which may either be
73+
a string like "utf-8" indicating the encoding to use, or may be a callable
74+
which enables charset autodetection.
75+
"""
76+
pass
77+
78+
@property
79+
@abstractmethod
80+
def charset_encoding(self) -> Optional[str]:
81+
"""
82+
Return the encoding, as specified by the Content-Type header.
83+
"""
84+
pass
85+
86+
@abstractmethod
87+
def json(self, **kwargs: Any) -> Any:
88+
pass
89+
90+
@abstractmethod
91+
def read(self) -> bytes:
92+
"""
93+
Read and return the response content.
94+
"""
95+
pass
96+
97+
@abstractmethod
98+
def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
99+
"""
100+
A byte-iterator over the decoded response content.
101+
This allows us to handle gzip, deflate, and brotli encoded responses.
102+
"""
103+
pass
104+
105+
@abstractmethod
106+
def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]:
107+
"""
108+
A str-iterator over the decoded response content
109+
that handles both gzip, deflate, etc but also detects the content's
110+
string encoding.
111+
"""
112+
pass
113+
114+
@abstractmethod
115+
def iter_lines(self) -> Iterator[str]:
116+
pass
117+
118+
@abstractmethod
119+
def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]:
120+
"""
121+
A byte-iterator over the raw response content.
122+
"""
123+
pass
124+
125+
@abstractmethod
126+
def stream_to_file(self, file: str | PathLike[str]) -> None:
127+
"""
128+
Stream the output to the given file.
129+
"""
130+
pass
131+
132+
@abstractmethod
133+
def close(self) -> None:
134+
"""
135+
Close the response and release the connection.
136+
Automatically called if the response body is read to completion.
137+
"""
138+
pass
139+
140+
@abstractmethod
141+
async def aread(self) -> bytes:
142+
"""
143+
Read and return the response content.
144+
"""
145+
pass
146+
147+
@abstractmethod
148+
async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
149+
"""
150+
A byte-iterator over the decoded response content.
151+
This allows us to handle gzip, deflate, and brotli encoded responses.
152+
"""
153+
pass
154+
155+
@abstractmethod
156+
async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]:
157+
"""
158+
A str-iterator over the decoded response content
159+
that handles both gzip, deflate, etc but also detects the content's
160+
string encoding.
161+
"""
162+
pass
163+
164+
@abstractmethod
165+
async def aiter_lines(self) -> AsyncIterator[str]:
166+
pass
167+
168+
@abstractmethod
169+
async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]:
170+
"""
171+
A byte-iterator over the raw response content.
172+
"""
173+
pass
174+
175+
async def astream_to_file(self, file: str | PathLike[str]) -> None:
176+
"""
177+
Stream the output to the given file.
178+
"""
179+
pass
180+
181+
@abstractmethod
182+
async def aclose(self) -> None:
183+
"""
184+
Close the response and release the connection.
185+
Automatically called if the response body is read to completion.
186+
"""
187+
pass
188+
189+
43190
# Approximates httpx internal ProxiesTypes and RequestFiles types
44191
# while adding support for `PathLike` instances
45192
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
@@ -181,7 +328,7 @@ def get(self, __key: str) -> str | None:
181328

182329
ResponseT = TypeVar(
183330
"ResponseT",
184-
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], httpx.Response, UnknownResponse, ModelBuilderProtocol]",
331+
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
185332
)
186333

187334
StrBytesIntFloat = Union[str, bytes, int, float]

tests/test_client.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,23 @@ class TestFinch:
4141

4242
@pytest.mark.respx(base_url=base_url)
4343
def test_raw_response(self, respx_mock: MockRouter) -> None:
44-
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
44+
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))
4545

4646
response = self.client.post("/foo", cast_to=httpx.Response)
4747
assert response.status_code == 200
4848
assert isinstance(response, httpx.Response)
49-
assert response.json() == {"foo": "bar"}
49+
assert response.json() == '{"foo": "bar"}'
50+
51+
@pytest.mark.respx(base_url=base_url)
52+
def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
53+
respx_mock.post("/foo").mock(
54+
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
55+
)
56+
57+
response = self.client.post("/foo", cast_to=httpx.Response)
58+
assert response.status_code == 200
59+
assert isinstance(response, httpx.Response)
60+
assert response.json() == '{"foo": "bar"}'
5061

5162
def test_copy(self) -> None:
5263
copied = self.client.copy()
@@ -672,12 +683,24 @@ class TestAsyncFinch:
672683
@pytest.mark.respx(base_url=base_url)
673684
@pytest.mark.asyncio
674685
async def test_raw_response(self, respx_mock: MockRouter) -> None:
675-
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
686+
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))
687+
688+
response = await self.client.post("/foo", cast_to=httpx.Response)
689+
assert response.status_code == 200
690+
assert isinstance(response, httpx.Response)
691+
assert response.json() == '{"foo": "bar"}'
692+
693+
@pytest.mark.respx(base_url=base_url)
694+
@pytest.mark.asyncio
695+
async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
696+
respx_mock.post("/foo").mock(
697+
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
698+
)
676699

677700
response = await self.client.post("/foo", cast_to=httpx.Response)
678701
assert response.status_code == 200
679702
assert isinstance(response, httpx.Response)
680-
assert response.json() == {"foo": "bar"}
703+
assert response.json() == '{"foo": "bar"}'
681704

682705
def test_copy(self) -> None:
683706
copied = self.client.copy()

0 commit comments

Comments
 (0)