Skip to content

Commit e083475

Browse files
committed
clean tests/indexing/common.py
1 parent 46d88c1 commit e083475

File tree

3 files changed

+50
-52
lines changed

3 files changed

+50
-52
lines changed

pandas/tests/indexing/common.py

+45-47
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
""" common utilities """
2-
32
import itertools
3+
from typing import Dict, Hashable, Union
44
from warnings import catch_warnings, filterwarnings
55

66
import numpy as np
@@ -29,7 +29,9 @@ def _axify(obj, key, axis):
2929
class Base:
3030
""" indexing comprehensive base class """
3131

32-
_objs = {"series", "frame"}
32+
frame = None # type: Dict[str, DataFrame]
33+
series = None # type: Dict[str, Series]
34+
_kinds = {"series", "frame"}
3335
_typs = {
3436
"ints",
3537
"uints",
@@ -101,13 +103,12 @@ def setup_method(self, method):
101103
self.series_empty = Series()
102104

103105
# form agglomerates
104-
for o in self._objs:
105-
106-
d = dict()
107-
for t in self._typs:
108-
d[t] = getattr(self, "{o}_{t}".format(o=o, t=t), None)
106+
for kind in self._kinds:
107+
d = dict() # type: Dict[str, Union[DataFrame, Series]]
108+
for typ in self._typs:
109+
d[typ] = getattr(self, "{kind}_{typ}".format(kind=kind, typ=typ))
109110

110-
setattr(self, o, d)
111+
setattr(self, kind, d)
111112

112113
def generate_indices(self, f, values=False):
113114
""" generate the indices
@@ -117,7 +118,7 @@ def generate_indices(self, f, values=False):
117118

118119
axes = f.axes
119120
if values:
120-
axes = (list(range(len(a))) for a in axes)
121+
axes = (list(range(len(ax))) for ax in axes)
121122

122123
return itertools.product(*axes)
123124

@@ -186,34 +187,42 @@ def check_result(
186187
method2,
187188
key2,
188189
typs=None,
189-
objs=None,
190+
kinds=None,
190191
axes=None,
191192
fails=None,
192193
):
193-
def _eq(t, o, a, obj, k1, k2):
194-
""" compare equal for these 2 keys """
195194

196-
if a is not None and a > obj.ndim - 1:
195+
def _eq(
196+
typ: str,
197+
kind: str,
198+
axis: int,
199+
obj: Union[DataFrame, Series],
200+
key1: Hashable,
201+
key2: Hashable,
202+
) -> None:
203+
""" compare equal for these 2 keys """
204+
if axis > obj.ndim - 1:
197205
return
198206

199207
def _print(result, error=None):
200-
if error is not None:
201-
error = str(error)
202-
v = (
208+
err = str(error) if error is not None else ""
209+
msg = (
203210
"%-16.16s [%-16.16s]: [typ->%-8.8s,obj->%-8.8s,"
204211
"key1->(%-4.4s),key2->(%-4.4s),axis->%s] %s"
205-
% (name, result, t, o, method1, method2, a, error or "")
212+
% (name, result, typ, kind, method1, method2, axis, err)
206213
)
207214
if _verbose:
208-
pprint_thing(v)
215+
pprint_thing(msg)
209216

210217
try:
211-
rs = getattr(obj, method1).__getitem__(_axify(obj, k1, a))
218+
rs = getattr(obj, method1).__getitem__(_axify(obj, key1, axis))
212219

213220
with catch_warnings(record=True):
214221
filterwarnings("ignore", "\\n.ix", FutureWarning)
215222
try:
216-
xp = self.get_result(obj, method2, k2, a)
223+
xp = self.get_result(
224+
obj=obj, method=method2, key=key2, axis=axis
225+
)
217226
except (KeyError, IndexError):
218227
# TODO: why is this allowed?
219228
result = "no comp"
@@ -228,8 +237,8 @@ def _print(result, error=None):
228237
else:
229238
tm.assert_equal(rs, xp)
230239
result = "ok"
231-
except AssertionError as e:
232-
detail = str(e)
240+
except AssertionError as exc:
241+
detail = str(exc)
233242
result = "fail"
234243

235244
# reverse the checks
@@ -258,36 +267,25 @@ def _print(result, error=None):
258267
if typs is None:
259268
typs = self._typs
260269

261-
if objs is None:
262-
objs = self._objs
270+
if kinds is None:
271+
kinds = self._kinds
263272

264-
if axes is not None:
265-
if not isinstance(axes, (tuple, list)):
266-
axes = [axes]
267-
else:
268-
axes = list(axes)
269-
else:
273+
if axes is None:
270274
axes = [0, 1]
275+
elif not isinstance(axes, (tuple, list)):
276+
assert isinstance(axes, int)
277+
axes = [axes]
271278

272279
# check
273-
for o in objs:
274-
if o not in self._objs:
280+
for kind in kinds: # type: str
281+
if kind not in self._kinds:
275282
continue
276283

277-
d = getattr(self, o)
278-
for a in axes:
279-
for t in typs:
280-
if t not in self._typs:
284+
d = getattr(self, kind) # type: Dict[str, Union[DataFrame, Series]]
285+
for ax in axes:
286+
for typ in typs:
287+
if typ not in self._typs:
281288
continue
282289

283-
obj = d[t]
284-
if obj is None:
285-
continue
286-
287-
def _call(obj=obj):
288-
obj = obj.copy()
289-
290-
k2 = key2
291-
_eq(t, o, a, obj, key1, k2)
292-
293-
_call()
290+
obj = d[typ]
291+
_eq(typ=typ, kind=kind, axis=ax, obj=obj, key1=key1, key2=key2)

pandas/tests/indexing/test_iloc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def test_iloc_getitem_dups(self):
284284
[0, 1, 1, 3],
285285
"ix",
286286
{0: [0, 2, 2, 6], 1: [0, 3, 3, 9]},
287-
objs=["series", "frame"],
287+
kinds=["series", "frame"],
288288
typs=["ints", "uints"],
289289
)
290290

pandas/tests/indexing/test_scalar.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def _check(f, func, values=False):
1919
expected = self.get_value(f, i, values)
2020
tm.assert_almost_equal(result, expected)
2121

22-
for o in self._objs:
22+
for kind in self._kinds:
2323

24-
d = getattr(self, o)
24+
d = getattr(self, kind)
2525

2626
# iat
2727
for f in [d["ints"], d["uints"]]:
@@ -47,9 +47,9 @@ def _check(f, func, values=False):
4747
expected = self.get_value(f, i, values)
4848
tm.assert_almost_equal(expected, 1)
4949

50-
for t in self._objs:
50+
for kind in self._kinds:
5151

52-
d = getattr(self, t)
52+
d = getattr(self, kind)
5353

5454
# iat
5555
for f in [d["ints"], d["uints"]]:

0 commit comments

Comments
 (0)