3
3
import posixpath
4
4
import warnings
5
5
from abc import ABC
6
+ from typing import ItemsView , cast
6
7
7
8
from typing_extensions import (
8
9
TYPE_CHECKING ,
@@ -73,6 +74,8 @@ def __init__(self, ctx: Context, path: str, /, **attributes):
73
74
class Resource (Protocol ):
74
75
def __getitem__ (self , key : Hashable ) -> Any : ...
75
76
77
+ def items (self ) -> ItemsView : ...
78
+
76
79
77
80
class _Resource (dict , Resource ):
78
81
def __init__ (self , ctx : Context , path : str , ** attributes ):
@@ -92,6 +95,10 @@ def update(self, **attributes): # type: ignore[reportIncompatibleMethodOverride
92
95
T = TypeVar ("T" , bound = Resource )
93
96
94
97
98
+ class ResourceFactory (Protocol ):
99
+ def __call__ (self , ctx : Context , path : str , ** attributes ) -> Resource : ...
100
+
101
+
95
102
class ResourceSequence (Protocol [T ]):
96
103
@overload
97
104
def __getitem__ (self , index : SupportsIndex , / ) -> T : ...
@@ -109,10 +116,17 @@ def __repr__(self) -> str: ...
109
116
110
117
111
118
class _ResourceSequence (Sequence [T ], ResourceSequence [T ]):
112
- def __init__ (self , ctx : Context , path : str , * , uid : str = "guid" ):
119
+ def __init__ (
120
+ self ,
121
+ ctx : Context ,
122
+ path : str ,
123
+ factory : ResourceFactory = _Resource ,
124
+ uid : str = "guid" ,
125
+ ):
113
126
self ._ctx = ctx
114
127
self ._path = path
115
128
self ._uid = uid
129
+ self ._factory = factory
116
130
117
131
def __getitem__ (self , index ):
118
132
return list (self .fetch ())[index ]
@@ -129,32 +143,32 @@ def __str__(self) -> str:
129
143
def __repr__ (self ) -> str :
130
144
return repr (self .fetch ())
131
145
132
- def create (self , ** attributes : Any ) -> Any :
146
+ def create (self , ** attributes : Any ) -> T :
133
147
response = self ._ctx .client .post (self ._path , json = attributes )
134
148
result = response .json ()
135
149
uid = result [self ._uid ]
136
150
path = posixpath .join (self ._path , uid )
137
- return _Resource ( self ._ctx , path , ** result )
151
+ return cast ( T , self ._factory ( self . _ctx , path , ** result ) )
138
152
139
- def fetch (self , ** conditions ) -> Iterable [Any ]:
153
+ def fetch (self , ** conditions ) -> Iterable [T ]:
140
154
response = self ._ctx .client .get (self ._path , params = conditions )
141
155
results = response .json ()
142
- resources = []
156
+ resources : List [ T ] = []
143
157
for result in results :
144
158
uid = result [self ._uid ]
145
159
path = posixpath .join (self ._path , uid )
146
- resource = _Resource ( self ._ctx , path , ** result )
160
+ resource = cast ( T , self ._factory ( self . _ctx , path , ** result ) )
147
161
resources .append (resource )
148
162
149
163
return resources
150
164
151
- def find (self , * args : str ) -> Any :
165
+ def find (self , * args : str ) -> T :
152
166
path = posixpath .join (self ._path , * args )
153
167
response = self ._ctx .client .get (path )
154
168
result = response .json ()
155
- return _Resource ( self ._ctx , path , ** result )
169
+ return cast ( T , self ._factory ( self . _ctx , path , ** result ) )
156
170
157
- def find_by (self , ** conditions ) -> Any | None :
171
+ def find_by (self , ** conditions ) -> T | None :
158
172
"""
159
173
Find the first record matching the specified conditions.
160
174
@@ -169,19 +183,19 @@ def find_by(self, **conditions) -> Any | None:
169
183
Optional[T]
170
184
The first record matching the conditions, or `None` if no match is found.
171
185
"""
172
- collection = self .fetch (** conditions )
186
+ collection : Iterable [ T ] = self .fetch (** conditions )
173
187
return next ((v for v in collection if v .items () >= conditions .items ()), None )
174
188
175
189
176
- class _PaginatedResourceSequence (_ResourceSequence ):
177
- def fetch (self , ** conditions ):
190
+ class _PaginatedResourceSequence (_ResourceSequence [ T ] ):
191
+ def fetch (self , ** conditions ) -> Iterator [ T ] :
178
192
paginator = Paginator (self ._ctx , self ._path , dict (** conditions ))
179
193
for page in paginator .fetch_pages ():
180
194
resources = []
181
195
results = page .results
182
196
for result in results :
183
197
uid = result [self ._uid ]
184
198
path = posixpath .join (self ._path , uid )
185
- resource = _Resource ( self ._ctx , path , ** result )
199
+ resource = cast ( T , self ._factory ( self . _ctx , path , ** result ) )
186
200
resources .append (resource )
187
201
yield from resources
0 commit comments