Skip to content

Commit 6c46013

Browse files
authored
TST/CLN: Use fixtures instead of setup_method in tests/indexing (#49157)
1 parent c667fc4 commit 6c46013

File tree

5 files changed

+230
-237
lines changed

5 files changed

+230
-237
lines changed

pandas/tests/indexing/common.py

+24-174
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,40 @@
11
""" common utilities """
2-
import itertools
2+
from __future__ import annotations
33

4-
import numpy as np
5-
6-
from pandas import (
7-
DataFrame,
8-
MultiIndex,
9-
Series,
10-
date_range,
11-
)
12-
import pandas._testing as tm
13-
from pandas.core.api import (
14-
Float64Index,
15-
UInt64Index,
4+
from typing import (
5+
Any,
6+
Literal,
167
)
178

189

19-
def _mklbl(prefix, n):
10+
def _mklbl(prefix: str, n: int):
2011
return [f"{prefix}{i}" for i in range(n)]
2112

2213

23-
def _axify(obj, key, axis):
24-
# create a tuple accessor
25-
axes = [slice(None)] * obj.ndim
26-
axes[axis] = key
27-
return tuple(axes)
28-
29-
30-
class Base:
31-
"""indexing comprehensive base class"""
32-
33-
_kinds = {"series", "frame"}
34-
_typs = {
35-
"ints",
36-
"uints",
37-
"labels",
38-
"mixed",
39-
"ts",
40-
"floats",
41-
"empty",
42-
"ts_rev",
43-
"multi",
44-
}
45-
46-
def setup_method(self):
47-
48-
self.series_ints = Series(np.random.rand(4), index=np.arange(0, 8, 2))
49-
self.frame_ints = DataFrame(
50-
np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3)
51-
)
52-
53-
self.series_uints = Series(
54-
np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2))
55-
)
56-
self.frame_uints = DataFrame(
57-
np.random.randn(4, 4),
58-
index=UInt64Index(range(0, 8, 2)),
59-
columns=UInt64Index(range(0, 12, 3)),
60-
)
61-
62-
self.series_floats = Series(
63-
np.random.rand(4), index=Float64Index(range(0, 8, 2))
64-
)
65-
self.frame_floats = DataFrame(
66-
np.random.randn(4, 4),
67-
index=Float64Index(range(0, 8, 2)),
68-
columns=Float64Index(range(0, 12, 3)),
69-
)
70-
71-
m_idces = [
72-
MultiIndex.from_product([[1, 2], [3, 4]]),
73-
MultiIndex.from_product([[5, 6], [7, 8]]),
74-
MultiIndex.from_product([[9, 10], [11, 12]]),
75-
]
76-
77-
self.series_multi = Series(np.random.rand(4), index=m_idces[0])
78-
self.frame_multi = DataFrame(
79-
np.random.randn(4, 4), index=m_idces[0], columns=m_idces[1]
80-
)
81-
82-
self.series_labels = Series(np.random.randn(4), index=list("abcd"))
83-
self.frame_labels = DataFrame(
84-
np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD")
85-
)
86-
87-
self.series_mixed = Series(np.random.randn(4), index=[2, 4, "null", 8])
88-
self.frame_mixed = DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8])
89-
90-
self.series_ts = Series(
91-
np.random.randn(4), index=date_range("20130101", periods=4)
92-
)
93-
self.frame_ts = DataFrame(
94-
np.random.randn(4, 4), index=date_range("20130101", periods=4)
95-
)
96-
97-
dates_rev = date_range("20130101", periods=4).sort_values(ascending=False)
98-
self.series_ts_rev = Series(np.random.randn(4), index=dates_rev)
99-
self.frame_ts_rev = DataFrame(np.random.randn(4, 4), index=dates_rev)
100-
101-
self.frame_empty = DataFrame()
102-
self.series_empty = Series(dtype=object)
103-
104-
# form agglomerates
105-
for kind in self._kinds:
106-
d = {}
107-
for typ in self._typs:
108-
d[typ] = getattr(self, f"{kind}_{typ}")
109-
110-
setattr(self, kind, d)
111-
112-
def generate_indices(self, f, values=False):
113-
"""
114-
generate the indices
115-
if values is True , use the axis values
116-
is False, use the range
117-
"""
118-
axes = f.axes
119-
if values:
120-
axes = (list(range(len(ax))) for ax in axes)
121-
122-
return itertools.product(*axes)
123-
124-
def get_value(self, name, f, i, values=False):
125-
"""return the value for the location i"""
126-
# check against values
127-
if values:
128-
return f.values[i]
129-
130-
elif name == "iat":
131-
return f.iloc[i]
132-
else:
133-
assert name == "at"
134-
return f.loc[i]
135-
136-
def check_values(self, f, func, values=False):
137-
138-
if f is None:
139-
return
140-
axes = f.axes
141-
indices = itertools.product(*axes)
142-
143-
for i in indices:
144-
result = getattr(f, func)[i]
145-
146-
# check against values
147-
if values:
148-
expected = f.values[i]
149-
else:
150-
expected = f
151-
for a in reversed(i):
152-
expected = expected.__getitem__(a)
153-
154-
tm.assert_almost_equal(result, expected)
155-
156-
def check_result(self, method, key, typs=None, axes=None, fails=None):
157-
def _eq(axis, obj, key):
158-
"""compare equal for these 2 keys"""
159-
axified = _axify(obj, key, axis)
14+
def check_indexing_smoketest_or_raises(
15+
obj,
16+
method: Literal["iloc", "loc"],
17+
key: Any,
18+
axes: Literal[0, 1] | None = None,
19+
fails=None,
20+
) -> None:
21+
if axes is None:
22+
axes_list = [0, 1]
23+
else:
24+
assert axes in [0, 1]
25+
axes_list = [axes]
26+
27+
for ax in axes_list:
28+
if ax < obj.ndim:
29+
# create a tuple accessor
30+
new_axes = [slice(None)] * obj.ndim
31+
new_axes[ax] = key
32+
axified = tuple(new_axes)
16033
try:
16134
getattr(obj, method).__getitem__(axified)
162-
16335
except (IndexError, TypeError, KeyError) as detail:
164-
16536
# if we are in fails, the ok, otherwise raise it
16637
if fails is not None:
16738
if isinstance(detail, fails):
16839
return
16940
raise
170-
171-
if typs is None:
172-
typs = self._typs
173-
174-
if axes is None:
175-
axes = [0, 1]
176-
else:
177-
assert axes in [0, 1]
178-
axes = [axes]
179-
180-
# check
181-
for kind in self._kinds:
182-
183-
d = getattr(self, kind)
184-
for ax in axes:
185-
for typ in typs:
186-
assert typ in self._typs
187-
188-
obj = d[typ]
189-
if ax < obj.ndim:
190-
_eq(axis=ax, obj=obj, key=key)

pandas/tests/indexing/conftest.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas import (
5+
DataFrame,
6+
MultiIndex,
7+
Series,
8+
date_range,
9+
)
10+
from pandas.core.api import (
11+
Float64Index,
12+
UInt64Index,
13+
)
14+
15+
16+
@pytest.fixture
17+
def series_ints():
18+
return Series(np.random.rand(4), index=np.arange(0, 8, 2))
19+
20+
21+
@pytest.fixture
22+
def frame_ints():
23+
return DataFrame(
24+
np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3)
25+
)
26+
27+
28+
@pytest.fixture
29+
def series_uints():
30+
return Series(np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2)))
31+
32+
33+
@pytest.fixture
34+
def frame_uints():
35+
return DataFrame(
36+
np.random.randn(4, 4),
37+
index=UInt64Index(range(0, 8, 2)),
38+
columns=UInt64Index(range(0, 12, 3)),
39+
)
40+
41+
42+
@pytest.fixture
43+
def series_labels():
44+
return Series(np.random.randn(4), index=list("abcd"))
45+
46+
47+
@pytest.fixture
48+
def frame_labels():
49+
return DataFrame(np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD"))
50+
51+
52+
@pytest.fixture
53+
def series_ts():
54+
return Series(np.random.randn(4), index=date_range("20130101", periods=4))
55+
56+
57+
@pytest.fixture
58+
def frame_ts():
59+
return DataFrame(np.random.randn(4, 4), index=date_range("20130101", periods=4))
60+
61+
62+
@pytest.fixture
63+
def series_floats():
64+
return Series(np.random.rand(4), index=Float64Index(range(0, 8, 2)))
65+
66+
67+
@pytest.fixture
68+
def frame_floats():
69+
return DataFrame(
70+
np.random.randn(4, 4),
71+
index=Float64Index(range(0, 8, 2)),
72+
columns=Float64Index(range(0, 12, 3)),
73+
)
74+
75+
76+
@pytest.fixture
77+
def series_mixed():
78+
return Series(np.random.randn(4), index=[2, 4, "null", 8])
79+
80+
81+
@pytest.fixture
82+
def frame_mixed():
83+
return DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8])
84+
85+
86+
@pytest.fixture
87+
def frame_empty():
88+
return DataFrame()
89+
90+
91+
@pytest.fixture
92+
def series_empty():
93+
return Series(dtype=object)
94+
95+
96+
@pytest.fixture
97+
def frame_multi():
98+
return DataFrame(
99+
np.random.randn(4, 4),
100+
index=MultiIndex.from_product([[1, 2], [3, 4]]),
101+
columns=MultiIndex.from_product([[5, 6], [7, 8]]),
102+
)
103+
104+
105+
@pytest.fixture
106+
def series_multi():
107+
return Series(np.random.rand(4), index=MultiIndex.from_product([[1, 2], [3, 4]]))

pandas/tests/indexing/test_iloc.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
import pandas._testing as tm
3434
from pandas.api.types import is_scalar
35-
from pandas.tests.indexing.common import Base
35+
from pandas.tests.indexing.common import check_indexing_smoketest_or_raises
3636

3737
# We pass through the error message from numpy
3838
_slice_iloc_msg = re.escape(
@@ -41,13 +41,19 @@
4141
)
4242

4343

44-
class TestiLoc(Base):
44+
class TestiLoc:
4545
@pytest.mark.parametrize("key", [2, -1, [0, 1, 2]])
46-
def test_iloc_getitem_int_and_list_int(self, key):
47-
self.check_result(
46+
@pytest.mark.parametrize("kind", ["series", "frame"])
47+
@pytest.mark.parametrize(
48+
"col",
49+
["labels", "mixed", "ts", "floats", "empty"],
50+
)
51+
def test_iloc_getitem_int_and_list_int(self, key, kind, col, request):
52+
obj = request.getfixturevalue(f"{kind}_{col}")
53+
check_indexing_smoketest_or_raises(
54+
obj,
4855
"iloc",
4956
key,
50-
typs=["labels", "mixed", "ts", "floats", "empty"],
5157
fails=IndexError,
5258
)
5359

0 commit comments

Comments
 (0)