2
2
from __future__ import annotations
3
3
4
4
import json
5
- from typing import TYPE_CHECKING , Any , Generic , Iterator , AsyncIterator
6
- from typing_extensions import override
5
+ from types import TracebackType
6
+ from typing import TYPE_CHECKING , Any , Generic , TypeVar , Iterator , AsyncIterator , cast
7
+ from typing_extensions import Self , override
7
8
8
9
import httpx
9
10
10
- from ._types import ResponseT
11
-
12
11
if TYPE_CHECKING :
13
12
from ._client import Finch , AsyncFinch
14
13
15
14
16
- class Stream (Generic [ResponseT ]):
15
+ _T = TypeVar ("_T" )
16
+
17
+
18
+ class Stream (Generic [_T ]):
17
19
"""Provides the core interface to iterate over a synchronous stream response."""
18
20
19
21
response : httpx .Response
20
22
21
23
def __init__ (
22
24
self ,
23
25
* ,
24
- cast_to : type [ResponseT ],
26
+ cast_to : type [_T ],
25
27
response : httpx .Response ,
26
28
client : Finch ,
27
29
) -> None :
@@ -31,18 +33,18 @@ def __init__(
31
33
self ._decoder = SSEDecoder ()
32
34
self ._iterator = self .__stream__ ()
33
35
34
- def __next__ (self ) -> ResponseT :
36
+ def __next__ (self ) -> _T :
35
37
return self ._iterator .__next__ ()
36
38
37
- def __iter__ (self ) -> Iterator [ResponseT ]:
39
+ def __iter__ (self ) -> Iterator [_T ]:
38
40
for item in self ._iterator :
39
41
yield item
40
42
41
43
def _iter_events (self ) -> Iterator [ServerSentEvent ]:
42
44
yield from self ._decoder .iter (self .response .iter_lines ())
43
45
44
- def __stream__ (self ) -> Iterator [ResponseT ]:
45
- cast_to = self ._cast_to
46
+ def __stream__ (self ) -> Iterator [_T ]:
47
+ cast_to = cast ( Any , self ._cast_to )
46
48
response = self .response
47
49
process_data = self ._client ._process_response_data
48
50
iterator = self ._iter_events ()
@@ -54,16 +56,35 @@ def __stream__(self) -> Iterator[ResponseT]:
54
56
for _sse in iterator :
55
57
...
56
58
59
+ def __enter__ (self ) -> Self :
60
+ return self
61
+
62
+ def __exit__ (
63
+ self ,
64
+ exc_type : type [BaseException ] | None ,
65
+ exc : BaseException | None ,
66
+ exc_tb : TracebackType | None ,
67
+ ) -> None :
68
+ self .close ()
69
+
70
+ def close (self ) -> None :
71
+ """
72
+ Close the response and release the connection.
73
+
74
+ Automatically called if the response body is read to completion.
75
+ """
76
+ self .response .close ()
57
77
58
- class AsyncStream (Generic [ResponseT ]):
78
+
79
+ class AsyncStream (Generic [_T ]):
59
80
"""Provides the core interface to iterate over an asynchronous stream response."""
60
81
61
82
response : httpx .Response
62
83
63
84
def __init__ (
64
85
self ,
65
86
* ,
66
- cast_to : type [ResponseT ],
87
+ cast_to : type [_T ],
67
88
response : httpx .Response ,
68
89
client : AsyncFinch ,
69
90
) -> None :
@@ -73,19 +94,19 @@ def __init__(
73
94
self ._decoder = SSEDecoder ()
74
95
self ._iterator = self .__stream__ ()
75
96
76
- async def __anext__ (self ) -> ResponseT :
97
+ async def __anext__ (self ) -> _T :
77
98
return await self ._iterator .__anext__ ()
78
99
79
- async def __aiter__ (self ) -> AsyncIterator [ResponseT ]:
100
+ async def __aiter__ (self ) -> AsyncIterator [_T ]:
80
101
async for item in self ._iterator :
81
102
yield item
82
103
83
104
async def _iter_events (self ) -> AsyncIterator [ServerSentEvent ]:
84
105
async for sse in self ._decoder .aiter (self .response .aiter_lines ()):
85
106
yield sse
86
107
87
- async def __stream__ (self ) -> AsyncIterator [ResponseT ]:
88
- cast_to = self ._cast_to
108
+ async def __stream__ (self ) -> AsyncIterator [_T ]:
109
+ cast_to = cast ( Any , self ._cast_to )
89
110
response = self .response
90
111
process_data = self ._client ._process_response_data
91
112
iterator = self ._iter_events ()
@@ -97,6 +118,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
97
118
async for _sse in iterator :
98
119
...
99
120
121
+ async def __aenter__ (self ) -> Self :
122
+ return self
123
+
124
+ async def __aexit__ (
125
+ self ,
126
+ exc_type : type [BaseException ] | None ,
127
+ exc : BaseException | None ,
128
+ exc_tb : TracebackType | None ,
129
+ ) -> None :
130
+ await self .close ()
131
+
132
+ async def close (self ) -> None :
133
+ """
134
+ Close the response and release the connection.
135
+
136
+ Automatically called if the response body is read to completion.
137
+ """
138
+ await self .response .aclose ()
139
+
100
140
101
141
class ServerSentEvent :
102
142
def __init__ (
0 commit comments