Skip to content

Commit 92d1d6a

Browse files
authored
REF: de-duplicate check_reduce_frame (#54393)
1 parent 6ffa4b7 commit 92d1d6a

File tree

6 files changed

+70
-78
lines changed

6 files changed

+70
-78
lines changed

pandas/tests/extension/base/reduce.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import final
12
import warnings
23

34
import pytest
@@ -15,6 +16,9 @@ class BaseReduceTests(BaseExtensionTests):
1516
"""
1617

1718
def check_reduce(self, s, op_name, skipna):
19+
# We perform the same operation on the np.float64 data and check
20+
# that the results match. Override if you need to cast to something
21+
# other than float64.
1822
res_op = getattr(s, op_name)
1923
exp_op = getattr(s.astype("float64"), op_name)
2024
if op_name == "count":
@@ -25,6 +29,43 @@ def check_reduce(self, s, op_name, skipna):
2529
expected = exp_op(skipna=skipna)
2630
tm.assert_almost_equal(result, expected)
2731

32+
def _get_expected_reduction_dtype(self, arr, op_name: str):
33+
# Find the expected dtype when the given reduction is done on a DataFrame
34+
# column with this array. The default assumes float64-like behavior,
35+
# i.e. retains the dtype.
36+
return arr.dtype
37+
38+
# We anticipate that authors should not need to override check_reduce_frame,
39+
# but should be able to do any necessary overriding in
40+
# _get_expected_reduction_dtype. If you have a use case where this
41+
# does not hold, please let us know at github.com/pandas-dev/pandas/issues.
42+
@final
43+
def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
44+
# Check that the 2D reduction done in a DataFrame reduction "looks like"
45+
# a wrapped version of the 1D reduction done by Series.
46+
arr = ser.array
47+
df = pd.DataFrame({"a": arr})
48+
49+
kwargs = {"ddof": 1} if op_name in ["var", "std"] else {}
50+
51+
cmp_dtype = self._get_expected_reduction_dtype(arr, op_name)
52+
53+
# The DataFrame method just calls arr._reduce with keepdims=True,
54+
# so this first check is perfunctory.
55+
result1 = arr._reduce(op_name, skipna=skipna, keepdims=True, **kwargs)
56+
result2 = getattr(df, op_name)(skipna=skipna, **kwargs).array
57+
tm.assert_extension_array_equal(result1, result2)
58+
59+
# Check that the 2D reduction looks like a wrapped version of the
60+
# 1D reduction
61+
if not skipna and ser.isna().any():
62+
expected = pd.array([pd.NA], dtype=cmp_dtype)
63+
else:
64+
exp_value = getattr(ser.dropna(), op_name)()
65+
expected = pd.array([exp_value], dtype=cmp_dtype)
66+
67+
tm.assert_extension_array_equal(result1, expected)
68+
2869

2970
class BaseNoReduceTests(BaseReduceTests):
3071
"""we don't define any reductions"""
@@ -71,9 +112,12 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna):
71112
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
72113
op_name = all_numeric_reductions
73114
s = pd.Series(data)
74-
if not is_numeric_dtype(s):
115+
if not is_numeric_dtype(s.dtype):
75116
pytest.skip("not numeric dtype")
76117

118+
if op_name in ["count", "kurt", "sem"]:
119+
pytest.skip(f"{op_name} not an array method")
120+
77121
self.check_reduce_frame(s, op_name, skipna)
78122

79123

pandas/tests/extension/decimal/test_decimal.py

+8-22
Original file line numberDiff line numberDiff line change
@@ -160,27 +160,6 @@ def check_reduce(self, s, op_name, skipna):
160160
expected = getattr(np.asarray(s), op_name)()
161161
tm.assert_almost_equal(result, expected)
162162

163-
def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
164-
arr = ser.array
165-
df = pd.DataFrame({"a": arr})
166-
167-
if op_name in ["count", "kurt", "sem", "skew", "median"]:
168-
assert not hasattr(arr, op_name)
169-
pytest.skip(f"{op_name} not an array method")
170-
171-
result1 = arr._reduce(op_name, skipna=skipna, keepdims=True)
172-
result2 = getattr(df, op_name)(skipna=skipna).array
173-
174-
tm.assert_extension_array_equal(result1, result2)
175-
176-
if not skipna and ser.isna().any():
177-
expected = DecimalArray([pd.NA])
178-
else:
179-
exp_value = getattr(ser.dropna(), op_name)()
180-
expected = DecimalArray([exp_value])
181-
182-
tm.assert_extension_array_equal(result1, expected)
183-
184163
def test_reduction_without_keepdims(self):
185164
# GH52788
186165
# test _reduce without keepdims
@@ -205,7 +184,14 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
205184

206185

207186
class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
208-
pass
187+
@pytest.mark.parametrize("skipna", [True, False])
188+
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
189+
op_name = all_numeric_reductions
190+
if op_name in ["skew", "median"]:
191+
assert not hasattr(data, op_name)
192+
pytest.skip(f"{op_name} not an array method")
193+
194+
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
209195

210196

211197
class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):

pandas/tests/extension/test_arrow.py

+9-17
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,7 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
499499
request.node.add_marker(xfail_mark)
500500
super().test_reduce_series(data, all_numeric_reductions, skipna)
501501

502-
def check_reduce_frame(self, ser, op_name, skipna):
503-
arr = ser.array
504-
505-
if op_name in ["count", "kurt", "sem", "skew"]:
506-
assert not hasattr(arr, op_name)
507-
return
508-
509-
kwargs = {"ddof": 1} if op_name in ["var", "std"] else {}
510-
502+
def _get_expected_reduction_dtype(self, arr, op_name: str):
511503
if op_name in ["max", "min"]:
512504
cmp_dtype = arr.dtype
513505
elif arr.dtype.name == "decimal128(7, 3)[pyarrow]":
@@ -523,15 +515,15 @@ def check_reduce_frame(self, ser, op_name, skipna):
523515
"u": "uint64[pyarrow]",
524516
"f": "float64[pyarrow]",
525517
}[arr.dtype.kind]
526-
result = arr._reduce(op_name, skipna=skipna, keepdims=True, **kwargs)
518+
return cmp_dtype
527519

528-
if not skipna and ser.isna().any():
529-
expected = pd.array([pd.NA], dtype=cmp_dtype)
530-
else:
531-
exp_value = getattr(ser.dropna().astype(cmp_dtype), op_name)(**kwargs)
532-
expected = pd.array([exp_value], dtype=cmp_dtype)
533-
534-
tm.assert_extension_array_equal(result, expected)
520+
@pytest.mark.parametrize("skipna", [True, False])
521+
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
522+
op_name = all_numeric_reductions
523+
if op_name == "skew":
524+
assert not hasattr(data, op_name)
525+
return
526+
return super().test_reduce_frame(data, all_numeric_reductions, skipna)
535527

536528
@pytest.mark.parametrize("typ", ["int64", "uint64", "float64"])
537529
def test_median_not_approximate(self, typ):

pandas/tests/extension/test_boolean.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -235,13 +235,7 @@ def check_reduce(self, s, op_name, skipna):
235235
expected = bool(expected)
236236
tm.assert_almost_equal(result, expected)
237237

238-
def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
239-
arr = ser.array
240-
241-
if op_name in ["count", "kurt", "sem"]:
242-
assert not hasattr(arr, op_name)
243-
pytest.skip(f"{op_name} not an array method")
244-
238+
def _get_expected_reduction_dtype(self, arr, op_name: str):
245239
if op_name in ["mean", "median", "var", "std", "skew"]:
246240
cmp_dtype = "Float64"
247241
elif op_name in ["min", "max"]:
@@ -251,14 +245,7 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
251245
cmp_dtype = "Int32" if is_windows_or_32bit else "Int64"
252246
else:
253247
raise TypeError("not supposed to reach this")
254-
255-
result = arr._reduce(op_name, skipna=skipna, keepdims=True)
256-
if not skipna and ser.isna().any():
257-
expected = pd.array([pd.NA], dtype=cmp_dtype)
258-
else:
259-
exp_value = getattr(ser.dropna().astype(cmp_dtype), op_name)()
260-
expected = pd.array([exp_value], dtype=cmp_dtype)
261-
tm.assert_extension_array_equal(result, expected)
248+
return cmp_dtype
262249

263250

264251
class TestBooleanReduce(base.BaseBooleanReduceTests):

pandas/tests/extension/test_masked_numeric.py

+4-22
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
)
4040
from pandas.tests.extension import base
4141

42+
is_windows_or_32bit = is_platform_windows() or not IS64
43+
4244
pytestmark = [
4345
pytest.mark.filterwarnings(
4446
"ignore:invalid value encountered in divide:RuntimeWarning"
@@ -246,16 +248,7 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
246248
expected = pd.NA
247249
tm.assert_almost_equal(result, expected)
248250

249-
def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
250-
if op_name in ["count", "kurt", "sem"]:
251-
assert not hasattr(ser.array, op_name)
252-
pytest.skip(f"{op_name} not an array method")
253-
254-
arr = ser.array
255-
df = pd.DataFrame({"a": arr})
256-
257-
is_windows_or_32bit = is_platform_windows() or not IS64
258-
251+
def _get_expected_reduction_dtype(self, arr, op_name: str):
259252
if tm.is_float_dtype(arr.dtype):
260253
cmp_dtype = arr.dtype.name
261254
elif op_name in ["mean", "median", "var", "std", "skew"]:
@@ -270,18 +263,7 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
270263
cmp_dtype = "UInt32" if is_windows_or_32bit else "UInt64"
271264
else:
272265
raise TypeError("not supposed to reach this")
273-
274-
if not skipna and ser.isna().any():
275-
expected = pd.array([pd.NA], dtype=cmp_dtype)
276-
else:
277-
exp_value = getattr(ser.dropna().astype(cmp_dtype), op_name)()
278-
expected = pd.array([exp_value], dtype=cmp_dtype)
279-
280-
result1 = arr._reduce(op_name, skipna=skipna, keepdims=True)
281-
result2 = getattr(df, op_name)(skipna=skipna).array
282-
283-
tm.assert_extension_array_equal(result1, result2)
284-
tm.assert_extension_array_equal(result2, expected)
266+
return cmp_dtype
285267

286268

287269
@pytest.mark.skip(reason="Tested in tests/reductions/test_reductions.py")

pandas/tests/extension/test_numpy.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ def check_reduce(self, s, op_name, skipna):
311311
tm.assert_almost_equal(result, expected)
312312

313313
@pytest.mark.skip("tests not written yet")
314-
def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
314+
@pytest.mark.parametrize("skipna", [True, False])
315+
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
315316
pass
316317

317318
@pytest.mark.parametrize("skipna", [True, False])

0 commit comments

Comments
 (0)