1
1
""" common utilities """
2
-
3
2
import itertools
3
+ from typing import Dict , Hashable , Union
4
4
from warnings import catch_warnings , filterwarnings
5
5
6
6
import numpy as np
@@ -29,7 +29,9 @@ def _axify(obj, key, axis):
29
29
class Base :
30
30
""" indexing comprehensive base class """
31
31
32
- _objs = {"series" , "frame" }
32
+ frame = None # type: Dict[str, DataFrame]
33
+ series = None # type: Dict[str, Series]
34
+ _kinds = {"series" , "frame" }
33
35
_typs = {
34
36
"ints" ,
35
37
"uints" ,
@@ -101,13 +103,12 @@ def setup_method(self, method):
101
103
self .series_empty = Series ()
102
104
103
105
# form agglomerates
104
- for o in self ._objs :
105
-
106
- d = dict ()
107
- for t in self ._typs :
108
- d [t ] = getattr (self , "{o}_{t}" .format (o = o , t = t ), None )
106
+ for kind in self ._kinds :
107
+ d = dict () # type: Dict[str, Union[DataFrame, Series]]
108
+ for typ in self ._typs :
109
+ d [typ ] = getattr (self , "{kind}_{typ}" .format (kind = kind , typ = typ ))
109
110
110
- setattr (self , o , d )
111
+ setattr (self , kind , d )
111
112
112
113
def generate_indices (self , f , values = False ):
113
114
""" generate the indices
@@ -117,7 +118,7 @@ def generate_indices(self, f, values=False):
117
118
118
119
axes = f .axes
119
120
if values :
120
- axes = (list (range (len (a ))) for a in axes )
121
+ axes = (list (range (len (ax ))) for ax in axes )
121
122
122
123
return itertools .product (* axes )
123
124
@@ -186,34 +187,42 @@ def check_result(
186
187
method2 ,
187
188
key2 ,
188
189
typs = None ,
189
- objs = None ,
190
+ kinds = None ,
190
191
axes = None ,
191
192
fails = None ,
192
193
):
193
- def _eq (t , o , a , obj , k1 , k2 ):
194
- """ compare equal for these 2 keys """
195
194
196
- if a is not None and a > obj .ndim - 1 :
195
+ def _eq (
196
+ typ : str ,
197
+ kind : str ,
198
+ axis : int ,
199
+ obj : Union [DataFrame , Series ],
200
+ key1 : Hashable ,
201
+ key2 : Hashable ,
202
+ ) -> None :
203
+ """ compare equal for these 2 keys """
204
+ if axis > obj .ndim - 1 :
197
205
return
198
206
199
207
def _print (result , error = None ):
200
- if error is not None :
201
- error = str (error )
202
- v = (
208
+ err = str (error ) if error is not None else ""
209
+ msg = (
203
210
"%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
204
211
"key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s"
205
- % (name , result , t , o , method1 , method2 , a , error or "" )
212
+ % (name , result , typ , kind , method1 , method2 , axis , err )
206
213
)
207
214
if _verbose :
208
- pprint_thing (v )
215
+ pprint_thing (msg )
209
216
210
217
try :
211
- rs = getattr (obj , method1 ).__getitem__ (_axify (obj , k1 , a ))
218
+ rs = getattr (obj , method1 ).__getitem__ (_axify (obj , key1 , axis ))
212
219
213
220
with catch_warnings (record = True ):
214
221
filterwarnings ("ignore" , "\\ n.ix" , FutureWarning )
215
222
try :
216
- xp = self .get_result (obj , method2 , k2 , a )
223
+ xp = self .get_result (
224
+ obj = obj , method = method2 , key = key2 , axis = axis
225
+ )
217
226
except (KeyError , IndexError ):
218
227
# TODO: why is this allowed?
219
228
result = "no comp"
@@ -228,8 +237,8 @@ def _print(result, error=None):
228
237
else :
229
238
tm .assert_equal (rs , xp )
230
239
result = "ok"
231
- except AssertionError as e :
232
- detail = str (e )
240
+ except AssertionError as exc :
241
+ detail = str (exc )
233
242
result = "fail"
234
243
235
244
# reverse the checks
@@ -258,36 +267,25 @@ def _print(result, error=None):
258
267
if typs is None :
259
268
typs = self ._typs
260
269
261
- if objs is None :
262
- objs = self ._objs
270
+ if kinds is None :
271
+ kinds = self ._kinds
263
272
264
- if axes is not None :
265
- if not isinstance (axes , (tuple , list )):
266
- axes = [axes ]
267
- else :
268
- axes = list (axes )
269
- else :
273
+ if axes is None :
270
274
axes = [0 , 1 ]
275
+ elif not isinstance (axes , (tuple , list )):
276
+ assert isinstance (axes , int )
277
+ axes = [axes ]
271
278
272
279
# check
273
- for o in objs :
274
- if o not in self ._objs :
280
+ for kind in kinds : # type: str
281
+ if kind not in self ._kinds :
275
282
continue
276
283
277
- d = getattr (self , o )
278
- for a in axes :
279
- for t in typs :
280
- if t not in self ._typs :
284
+ d = getattr (self , kind ) # type: Dict[str, Union[DataFrame, Series]]
285
+ for ax in axes :
286
+ for typ in typs :
287
+ if typ not in self ._typs :
281
288
continue
282
289
283
- obj = d [t ]
284
- if obj is None :
285
- continue
286
-
287
- def _call (obj = obj ):
288
- obj = obj .copy ()
289
-
290
- k2 = key2
291
- _eq (t , o , a , obj , key1 , k2 )
292
-
293
- _call ()
290
+ obj = d [typ ]
291
+ _eq (typ = typ , kind = kind , axis = ax , obj = obj , key1 = key1 , key2 = key2 )
0 commit comments