|
24 | 24 | overload,
|
25 | 25 | )
|
26 | 26 | from functools import lru_cache
|
27 |
| -from typing_extensions import Literal, get_origin |
| 27 | +from typing_extensions import Literal, get_args, get_origin |
28 | 28 |
|
29 | 29 | import anyio
|
30 | 30 | import httpx
|
@@ -458,6 +458,14 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o
|
458 | 458 | serialized[key] = value
|
459 | 459 | return serialized
|
460 | 460 |
|
| 461 | + def _extract_stream_chunk_type(self, stream_cls: type) -> type: |
| 462 | + args = get_args(stream_cls) |
| 463 | + if not args: |
| 464 | + raise TypeError( |
| 465 | + f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}", |
| 466 | + ) |
| 467 | + return cast(type, args[0]) |
| 468 | + |
461 | 469 | def _process_response(
|
462 | 470 | self,
|
463 | 471 | *,
|
@@ -793,7 +801,10 @@ def _request(
|
793 | 801 | raise APIConnectionError(request=request) from err
|
794 | 802 |
|
795 | 803 | if stream:
|
796 |
| - stream_cls = stream_cls or cast("type[_StreamT] | None", self._default_stream_cls) |
| 804 | + if stream_cls: |
| 805 | + return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self) |
| 806 | + |
| 807 | + stream_cls = cast("type[_StreamT] | None", self._default_stream_cls) |
797 | 808 | if stream_cls is None:
|
798 | 809 | raise MissingStreamClassError()
|
799 | 810 | return stream_cls(cast_to=cast_to, response=response, client=self)
|
@@ -1156,7 +1167,10 @@ async def _request(
|
1156 | 1167 | raise APIConnectionError(request=request) from err
|
1157 | 1168 |
|
1158 | 1169 | if stream:
|
1159 |
| - stream_cls = stream_cls or cast("type[_AsyncStreamT] | None", self._default_stream_cls) |
| 1170 | + if stream_cls: |
| 1171 | + return stream_cls(cast_to=self._extract_stream_chunk_type(stream_cls), response=response, client=self) |
| 1172 | + |
| 1173 | + stream_cls = cast("type[_AsyncStreamT] | None", self._default_stream_cls) |
1160 | 1174 | if stream_cls is None:
|
1161 | 1175 | raise MissingStreamClassError()
|
1162 | 1176 | return stream_cls(cast_to=cast_to, response=response, client=self)
|
|
0 commit comments