Skip to content

Commit ee7284a

Browse files
chore(internal): update base client (#90)
1 parent 483ffef commit ee7284a

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/finch/_base_client.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
overload,
2525
)
2626
from functools import lru_cache
27-
from typing_extensions import Literal, get_origin
27+
from typing_extensions import Literal, get_args, get_origin
2828

2929
import anyio
3030
import httpx
@@ -458,6 +458,14 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
458458
serialized[key] = value
459459
return serialized
460460

461+
def _extract_stream_chunk_type(self, stream_cls: type) -> type:
462+
args = get_args(stream_cls)
463+
if not args:
464+
raise TypeError(
465+
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
466+
)
467+
return cast(type, args[0])
468+
461469
def _process_response(
462470
self,
463471
*,
@@ -793,7 +801,10 @@ def _request(
793801
raise APIConnectionError(request=request) from err
794802

795803
if stream:
796-
stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls)
804+
if stream_cls:
805+
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)
806+
807+
stream_cls = cast("type[_StreamT] | None", self._default_stream_cls)
797808
if stream_cls is None:
798809
raise MissingStreamClassError()
799810
return stream_cls(cast_to=cast_to, response=response, client=self)
@@ -1156,7 +1167,10 @@ async def _request(
11561167
raise APIConnectionError(request=request) from err
11571168

11581169
if stream:
1159-
stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls)
1170+
if stream_cls:
1171+
return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self)
1172+
1173+
stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls)
11601174
if stream_cls is None:
11611175
raise MissingStreamClassError()
11621176
return stream_cls(cast_to=cast_to, response=response, client=self)

0 commit comments

Comments
 (0)