diff --git a/README.md b/README.md index f7c97805..9313c47e 100644 --- a/README.md +++ b/README.md @@ -236,7 +236,7 @@ If you need to, you can override it by setting default headers per-request or on from finch import Finch finch = Finch( - default_headers={"Finch-API-Version": My - Custom - Value}, + default_headers={"Finch-API-Version": "My-Custom-Value"}, ) ``` diff --git a/src/finch/_base_client.py b/src/finch/_base_client.py index 2cc8772d..b7953442 100644 --- a/src/finch/_base_client.py +++ b/src/finch/_base_client.py @@ -9,7 +9,6 @@ from typing import ( Any, Dict, - List, Type, Union, Generic, @@ -45,6 +44,7 @@ Timeout, NoneType, NotGiven, + ResponseT, Transport, AnyMapping, ProxiesTypes, @@ -61,6 +61,7 @@ validate_type, construct_type, ) +from ._streaming import Stream, AsyncStream from ._base_exceptions import ( APIStatusError, APITimeoutError, @@ -73,128 +74,23 @@ AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") -ResponseT = TypeVar( - "ResponseT", - bound=Union[ - str, - None, - BaseModel, - List[Any], - Dict[str, Any], - httpx.Response, - UnknownResponse, - ModelBuilderProtocol, - ], -) - _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) +_StreamT = TypeVar("_StreamT", bound=Stream[Any]) +_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any]) + + DEFAULT_TIMEOUT = Timeout(timeout=60.0, connect=5.0) DEFAULT_MAX_RETRIES = 2 DEFAULT_LIMITS = Limits(max_connections=100, max_keepalive_connections=20) -class StopStreaming(Exception): - """Raised internally when processing of a streamed response should be stopped.""" - - -class Stream(Generic[ResponseT]): - response: httpx.Response - - def __init__( - self, - *, - cast_to: type[ResponseT], - response: httpx.Response, - client: SyncAPIClient, - ) -> None: - self.response = response - self._cast_to = cast_to - self._client = client - self._iterator = self.__iter() - - def __next__(self) -> ResponseT: - return self._iterator.__next__() - - def __iter__(self) -> Iterator[ResponseT]: - for item in self._iterator: - yield item - - def __iter(self) -> Iterator[ResponseT]: - cast_to = self._cast_to - response = self.response - process_line = self._client._process_stream_line - process_data = self._client._process_response_data - - awaiting_ping_data = False - for raw_line in response.iter_lines(): - if not raw_line or raw_line == "\n": - continue - - if raw_line.startswith("event: ping"): - awaiting_ping_data = True - continue - if awaiting_ping_data: - awaiting_ping_data = False - continue - - try: - line = process_line(raw_line) - except StopStreaming: - # we are done! - break - - yield process_data(data=json.loads(line), cast_to=cast_to, response=response) - - -class AsyncStream(Generic[ResponseT]): - response: httpx.Response - - def __init__( - self, - *, - cast_to: type[ResponseT], - response: httpx.Response, - client: AsyncAPIClient, - ) -> None: - self.response = response - self._cast_to = cast_to - self._client = client - self._iterator = self.__iter() - - async def __anext__(self) -> ResponseT: - return await self._iterator.__anext__() - - async def __aiter__(self) -> AsyncIterator[ResponseT]: - async for item in self._iterator: - yield item - - async def __iter(self) -> AsyncIterator[ResponseT]: - cast_to = self._cast_to - response = self.response - process_line = self._client._process_stream_line - process_data = self._client._process_response_data - - awaiting_ping_data = False - async for raw_line in response.aiter_lines(): - if not raw_line or raw_line == "\n": - continue - - if raw_line.startswith("event: ping"): - awaiting_ping_data = True - continue - if awaiting_ping_data: - awaiting_ping_data = False - continue - - try: - line = process_line(raw_line) - except StopStreaming: - # we are done! - break - - yield process_data(data=json.loads(line), cast_to=cast_to, response=response) +class MissingStreamClassError(TypeError): + def __init__(self) -> None: + super().__init__( + "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `finch._streaming` for reference", + ) class PageInfo: @@ -635,16 +531,6 @@ def _process_response_data( return cast(ResponseT, construct_type(type_=cast_to, value=data)) - def _process_stream_line(self, contents: str) -> str: - """Pre-process an indiviudal line from a streaming response""" - if contents.startswith("data: [DONE]"): - raise StopStreaming() - - if contents.startswith("data: "): - return contents[6:] - - return contents - @property def qs(self) -> Querystring: return Querystring() @@ -756,6 +642,7 @@ def _idempotency_key(self) -> str: class SyncAPIClient(BaseClient): _client: httpx.Client + _default_stream_cls: type[Stream[Any]] | None = None def __init__( self, @@ -798,7 +685,8 @@ def request( remaining_retries: Optional[int] = None, *, stream: Literal[True], - ) -> Stream[ResponseT]: + stream_cls: Type[_StreamT], + ) -> _StreamT: ... @overload @@ -820,7 +708,8 @@ def request( remaining_retries: Optional[int] = None, *, stream: bool = False, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: Type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: ... def request( @@ -830,11 +719,13 @@ def request( remaining_retries: Optional[int] = None, *, stream: bool = False, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: return self._request( cast_to=cast_to, options=options, stream=stream, + stream_cls=stream_cls, remaining_retries=remaining_retries, ) @@ -845,7 +736,8 @@ def _request( options: FinalRequestOptions, remaining_retries: int | None, stream: bool, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: type[_StreamT] | None, + ) -> ResponseT | _StreamT: retries = self._remaining_retries(remaining_retries, options) request = self._build_request(options) @@ -854,7 +746,14 @@ def _request( response.raise_for_status() except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code if retries > 0 and self._should_retry(err.response): - return self._retry_request(options, cast_to, retries, err.response.headers, stream=stream) + return self._retry_request( + options, + cast_to, + retries, + err.response.headers, + stream=stream, + stream_cls=stream_cls, + ) # If the response is streamed then we need to explicitly read the response # to completion before attempting to access the response text. @@ -862,15 +761,18 @@ def _request( raise self._make_status_error_from_response(request, err.response) from None except httpx.TimeoutException as err: if retries > 0: - return self._retry_request(options, cast_to, retries, stream=stream) + return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) raise APITimeoutError(request=request) from err except Exception as err: if retries > 0: - return self._retry_request(options, cast_to, retries, stream=stream) + return self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) raise APIConnectionError(request=request) from err if stream: - return Stream(cast_to=cast_to, response=response, client=self) + stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + return stream_cls(cast_to=cast_to, response=response, client=self) try: rsp = self._process_response(cast_to=cast_to, options=options, response=response) @@ -887,7 +789,8 @@ def _retry_request( response_headers: Optional[httpx.Headers] = None, *, stream: bool, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: type[_StreamT] | None, + ) -> ResponseT | _StreamT: remaining = remaining_retries - 1 timeout = self._calculate_retry_timeout(remaining, options, response_headers) @@ -900,6 +803,7 @@ def _retry_request( cast_to=cast_to, remaining_retries=remaining, stream=stream, + stream_cls=stream_cls, ) def _request_api_list( @@ -951,7 +855,8 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: Literal[True], - ) -> Stream[ResponseT]: + stream_cls: type[_StreamT], + ) -> _StreamT: ... @overload @@ -964,7 +869,8 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: ... def post( @@ -976,9 +882,10 @@ def post( options: RequestOptions = {}, files: RequestFiles | None = None, stream: bool = False, - ) -> ResponseT | Stream[ResponseT]: + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options) - return cast(ResponseT, self.request(cast_to, opts, stream=stream)) + return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) def patch( self, @@ -1030,6 +937,7 @@ def get_api_list( class AsyncAPIClient(BaseClient): _client: httpx.AsyncClient + _default_stream_cls: type[AsyncStream[Any]] | None = None def __init__( self, @@ -1082,8 +990,9 @@ async def request( options: FinalRequestOptions, *, stream: Literal[True], + stream_cls: type[_AsyncStreamT], remaining_retries: Optional[int] = None, - ) -> AsyncStream[ResponseT]: + ) -> _AsyncStreamT: ... @overload @@ -1093,8 +1002,9 @@ async def request( options: FinalRequestOptions, *, stream: bool, + stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, - ) -> ResponseT | AsyncStream[ResponseT]: + ) -> ResponseT | _AsyncStreamT: ... async def request( @@ -1103,12 +1013,14 @@ async def request( options: FinalRequestOptions, *, stream: bool = False, + stream_cls: type[_AsyncStreamT] | None = None, remaining_retries: Optional[int] = None, - ) -> ResponseT | AsyncStream[ResponseT]: + ) -> ResponseT | _AsyncStreamT: return await self._request( cast_to=cast_to, options=options, stream=stream, + stream_cls=stream_cls, remaining_retries=remaining_retries, ) @@ -1118,8 +1030,9 @@ async def _request( options: FinalRequestOptions, *, stream: bool, + stream_cls: type[_AsyncStreamT] | None, remaining_retries: int | None, - ) -> ResponseT | AsyncStream[ResponseT]: + ) -> ResponseT | _AsyncStreamT: retries = self._remaining_retries(remaining_retries, options) request = self._build_request(options) @@ -1128,7 +1041,14 @@ async def _request( response.raise_for_status() except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code if retries > 0 and self._should_retry(err.response): - return await self._retry_request(options, cast_to, retries, err.response.headers, stream=stream) + return await self._retry_request( + options, + cast_to, + retries, + err.response.headers, + stream=stream, + stream_cls=stream_cls, + ) # If the response is streamed then we need to explicitly read the response # to completion before attempting to access the response text. @@ -1136,7 +1056,7 @@ async def _request( raise self._make_status_error_from_response(request, err.response) from None except httpx.ConnectTimeout as err: if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream) + return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) raise APITimeoutError(request=request) from err except httpx.ReadTimeout as err: # We explicitly do not retry on ReadTimeout errors as this means @@ -1146,15 +1066,18 @@ async def _request( raise except httpx.TimeoutException as err: if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream) + return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) raise APITimeoutError(request=request) from err except Exception as err: if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream) + return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) raise APIConnectionError(request=request) from err if stream: - return AsyncStream(cast_to=cast_to, response=response, client=self) + stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + return stream_cls(cast_to=cast_to, response=response, client=self) try: rsp = self._process_response(cast_to=cast_to, options=options, response=response) @@ -1171,7 +1094,8 @@ async def _retry_request( response_headers: Optional[httpx.Headers] = None, *, stream: bool, - ) -> ResponseT | AsyncStream[ResponseT]: + stream_cls: type[_AsyncStreamT] | None, + ) -> ResponseT | _AsyncStreamT: remaining = remaining_retries - 1 timeout = self._calculate_retry_timeout(remaining, options, response_headers) @@ -1182,6 +1106,7 @@ async def _retry_request( cast_to=cast_to, remaining_retries=remaining, stream=stream, + stream_cls=stream_cls, ) def _request_api_list( @@ -1225,7 +1150,8 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: Literal[True], - ) -> AsyncStream[ResponseT]: + stream_cls: type[_AsyncStreamT], + ) -> _AsyncStreamT: ... @overload @@ -1238,7 +1164,8 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: bool, - ) -> ResponseT | AsyncStream[ResponseT]: + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: ... async def post( @@ -1250,9 +1177,10 @@ async def post( files: RequestFiles | None = None, options: RequestOptions = {}, stream: bool = False, - ) -> ResponseT | AsyncStream[ResponseT]: + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: opts = FinalRequestOptions.construct(method="post", url=path, json_data=body, files=files, **options) - return await self.request(cast_to, opts, stream=stream) + return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) async def patch( self, diff --git a/src/finch/_client.py b/src/finch/_client.py index 4cdb3b70..b1f1cbf5 100644 --- a/src/finch/_client.py +++ b/src/finch/_client.py @@ -20,10 +20,15 @@ RequestOptions, ) from ._version import __version__ -from ._base_client import DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES -from ._base_client import Stream as Stream -from ._base_client import AsyncStream as AsyncStream -from ._base_client import SyncAPIClient, AsyncAPIClient +from ._streaming import Stream as Stream +from ._streaming import AsyncStream as AsyncStream +from ._base_client import ( + DEFAULT_LIMITS, + DEFAULT_TIMEOUT, + DEFAULT_MAX_RETRIES, + SyncAPIClient, + AsyncAPIClient, +) __all__ = [ "Timeout", diff --git a/src/finch/_streaming.py b/src/finch/_streaming.py new file mode 100644 index 00000000..18749b53 --- /dev/null +++ b/src/finch/_streaming.py @@ -0,0 +1,204 @@ +# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator + +import httpx + +from ._types import ResponseT + +if TYPE_CHECKING: + from ._base_client import SyncAPIClient, AsyncAPIClient + + +class Stream(Generic[ResponseT]): + """Provides the core interface to iterate over a synchronous stream response.""" + + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[ResponseT], + response: httpx.Response, + client: SyncAPIClient, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._decoder = SSEDecoder() + self._iterator = self.__stream__() + + def __next__(self) -> ResponseT: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[ResponseT]: + for item in self._iterator: + yield item + + def _iter_events(self) -> Iterator[ServerSentEvent]: + yield from self._decoder.iter(self.response.iter_lines()) + + def __stream__(self) -> Iterator[ResponseT]: + cast_to = self._cast_to + response = self.response + process_data = self._client._process_response_data + + for sse in self._iter_events(): + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + +class AsyncStream(Generic[ResponseT]): + """Provides the core interface to iterate over an asynchronous stream response.""" + + response: httpx.Response + + def __init__( + self, + *, + cast_to: type[ResponseT], + response: httpx.Response, + client: AsyncAPIClient, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._decoder = SSEDecoder() + self._iterator = self.__stream__() + + async def __anext__(self) -> ResponseT: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[ResponseT]: + async for item in self._iterator: + yield item + + async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: + async for sse in self._decoder.aiter(self.response.aiter_lines()): + yield sse + + async def __stream__(self) -> AsyncIterator[ResponseT]: + cast_to = self._cast_to + response = self.response + process_data = self._client._process_response_data + + async for sse in self._iter_events(): + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + +class ServerSentEvent: + def __init__( + self, + *, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None, + ) -> None: + if data is None: + data = "" + + self._id = id + self._data = data + self._event = event or None + self._retry = retry + + @property + def event(self) -> str | None: + return self._event + + @property + def id(self) -> str | None: + return self._id + + @property + def retry(self) -> int | None: + return self._retry + + @property + def data(self) -> str: + return self._data + + def json(self) -> Any: + return json.loads(self.data) + + def __repr__(self) -> str: + return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" + + +class SSEDecoder: + _data: list[str] + _event: str | None + _retry: int | None + _last_event_id: str | None + + def __init__(self) -> None: + self._event = None + self._data = [] + self._last_event_id = None + self._retry = None + + def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + for line in iterator: + line = line.rstrip("\n") + sse = self.decode(line) + if sse is not None: + yield sse + + async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + async for line in iterator: + line = line.rstrip("\n") + sse = self.decode(line) + if sse is not None: + yield sse + + def decode(self, line: str) -> ServerSentEvent | None: + # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 + + if not line: + if not self._event and not self._data and not self._last_event_id and self._retry is None: + return None + + sse = ServerSentEvent( + event=self._event, + data="\n".join(self._data), + id=self._last_event_id, + retry=self._retry, + ) + + # NOTE: as per the SSE spec, do not reset last_event_id. + self._event = None + self._data = [] + self._retry = None + + return sse + + if line.startswith(":"): + return None + + fieldname, _, value = line.partition(":") + + if value.startswith(" "): + value = value[1:] + + if fieldname == "event": + self._event = value + elif fieldname == "data": + self._data.append(value) + elif fieldname == "id": + if "\0" in value: + pass + else: + self._last_event_id = value + elif fieldname == "retry": + try: + self._retry = int(value) + except (TypeError, ValueError): + pass + else: + pass # Field is ignored. + + return None diff --git a/src/finch/_types.py b/src/finch/_types.py index c1ca74bf..3d0a265a 100644 --- a/src/finch/_types.py +++ b/src/finch/_types.py @@ -3,7 +3,9 @@ from typing import ( IO, TYPE_CHECKING, + Any, Dict, + List, Type, Tuple, Union, @@ -14,9 +16,13 @@ ) from typing_extensions import Literal, Protocol, TypedDict, runtime_checkable +import httpx import pydantic from httpx import Proxy, Timeout, Response, BaseTransport +if TYPE_CHECKING: + from ._models import BaseModel + Transport = BaseTransport Query = Mapping[str, object] Body = object @@ -143,3 +149,8 @@ def get(self, __key: str) -> str | None: HeadersLike = Union[Headers, HeadersLikeProtocol] + +ResponseT = TypeVar( + "ResponseT", + bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], httpx.Response, UnknownResponse, ModelBuilderProtocol]", +) diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 00000000..70eb81a3 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,104 @@ +from typing import Iterator, AsyncIterator + +import pytest + +from finch._streaming import SSEDecoder + + +@pytest.mark.asyncio +async def test_basic_async() -> None: + async def body() -> AsyncIterator[str]: + yield "event: completion" + yield 'data: {"foo":true}' + yield "" + + async for sse in SSEDecoder().aiter(body()): + assert sse.event == "completion" + assert sse.json() == {"foo": True} + + +def test_basic() -> None: + def body() -> Iterator[str]: + yield "event: completion" + yield 'data: {"foo":true}' + yield "" + + it = SSEDecoder().iter(body()) + sse = next(it) + assert sse.event == "completion" + assert sse.json() == {"foo": True} + + with pytest.raises(StopIteration): + next(it) + + +def test_data_missing_event() -> None: + def body() -> Iterator[str]: + yield 'data: {"foo":true}' + yield "" + + it = SSEDecoder().iter(body()) + sse = next(it) + assert sse.event is None + assert sse.json() == {"foo": True} + + with pytest.raises(StopIteration): + next(it) + + +def test_event_missing_data() -> None: + def body() -> Iterator[str]: + yield "event: ping" + yield "" + + it = SSEDecoder().iter(body()) + sse = next(it) + assert sse.event == "ping" + assert sse.data == "" + + with pytest.raises(StopIteration): + next(it) + + +def test_multiple_events() -> None: + def body() -> Iterator[str]: + yield "event: ping" + yield "" + yield "event: completion" + yield "" + + it = SSEDecoder().iter(body()) + + sse = next(it) + assert sse.event == "ping" + assert sse.data == "" + + sse = next(it) + assert sse.event == "completion" + assert sse.data == "" + + with pytest.raises(StopIteration): + next(it) + + +def test_multiple_events_with_data() -> None: + def body() -> Iterator[str]: + yield "event: ping" + yield 'data: {"foo":true}' + yield "" + yield "event: completion" + yield 'data: {"bar":false}' + yield "" + + it = SSEDecoder().iter(body()) + + sse = next(it) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + sse = next(it) + assert sse.event == "completion" + assert sse.json() == {"bar": False} + + with pytest.raises(StopIteration): + next(it)