Skip to content

Commit 1dbd792

Browse files
Backport PR #54566 on branch 2.1.x (ENH: support Index.any/all with float, timedelta64 dtypes) (#54693)
Backport PR #54566: ENH: support Index.any/all with float, timedelta64 dtypes Co-authored-by: jbrockmendel <[email protected]>
1 parent 6c5e79b commit 1dbd792

File tree

5 files changed

+47
-23
lines changed

5 files changed

+47
-23
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ Other enhancements
265265
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to ``lzma.LZMAFile`` (:issue:`52979`)
266266
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
267267
- :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`)
268+
- :meth:`Index.all` and :meth:`Index.any` with floating dtypes and timedelta64 dtypes no longer raise ``TypeError``, matching the :meth:`Series.all` and :meth:`Series.any` behavior (:issue:`54566`)
268269
- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`)
269270
- Added support for the DataFrame Consortium Standard (:issue:`54383`)
270271
- Performance improvement in :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` (:issue:`51722`)

pandas/core/indexes/base.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -7215,11 +7215,12 @@ def any(self, *args, **kwargs):
72157215
"""
72167216
nv.validate_any(args, kwargs)
72177217
self._maybe_disable_logical_methods("any")
7218-
# error: Argument 1 to "any" has incompatible type "ArrayLike"; expected
7219-
# "Union[Union[int, float, complex, str, bytes, generic], Sequence[Union[int,
7220-
# float, complex, str, bytes, generic]], Sequence[Sequence[Any]],
7221-
# _SupportsArray]"
7222-
return np.any(self.values) # type: ignore[arg-type]
7218+
vals = self._values
7219+
if not isinstance(vals, np.ndarray):
7220+
# i.e. EA, call _reduce instead of "any" to get TypeError instead
7221+
# of AttributeError
7222+
return vals._reduce("any")
7223+
return np.any(vals)
72237224

72247225
def all(self, *args, **kwargs):
72257226
"""
@@ -7262,11 +7263,12 @@ def all(self, *args, **kwargs):
72627263
"""
72637264
nv.validate_all(args, kwargs)
72647265
self._maybe_disable_logical_methods("all")
7265-
# error: Argument 1 to "all" has incompatible type "ArrayLike"; expected
7266-
# "Union[Union[int, float, complex, str, bytes, generic], Sequence[Union[int,
7267-
# float, complex, str, bytes, generic]], Sequence[Sequence[Any]],
7268-
# _SupportsArray]"
7269-
return np.all(self.values) # type: ignore[arg-type]
7266+
vals = self._values
7267+
if not isinstance(vals, np.ndarray):
7268+
# i.e. EA, call _reduce instead of "all" to get TypeError instead
7269+
# of AttributeError
7270+
return vals._reduce("all")
7271+
return np.all(vals)
72707272

72717273
@final
72727274
def _maybe_disable_logical_methods(self, opname: str_t) -> None:
@@ -7275,9 +7277,9 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None:
72757277
"""
72767278
if (
72777279
isinstance(self, ABCMultiIndex)
7278-
or needs_i8_conversion(self.dtype)
7279-
or isinstance(self.dtype, (IntervalDtype, CategoricalDtype))
7280-
or is_float_dtype(self.dtype)
7280+
# TODO(3.0): PeriodArray and DatetimeArray any/all will raise,
7281+
# so checking needs_i8_conversion will be unnecessary
7282+
or (needs_i8_conversion(self.dtype) and self.dtype.kind != "m")
72817283
):
72827284
# This call will raise
72837285
make_invalid_op(opname)(self)

pandas/tests/indexes/numeric/test_numeric.py

+8
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,14 @@ def test_fillna_float64(self):
227227
exp = Index([1.0, "obj", 3.0], name="x")
228228
tm.assert_index_equal(idx.fillna("obj"), exp, exact=True)
229229

230+
def test_logical_compat(self, simple_index):
231+
idx = simple_index
232+
assert idx.all() == idx.values.all()
233+
assert idx.any() == idx.values.any()
234+
235+
assert idx.all() == idx.to_series().all()
236+
assert idx.any() == idx.to_series().any()
237+
230238

231239
class TestNumericInt:
232240
@pytest.fixture(params=[np.int64, np.int32, np.int16, np.int8, np.uint64])

pandas/tests/indexes/test_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,12 @@ def test_format_missing(self, vals, nulls_fixture):
692692
@pytest.mark.parametrize("op", ["any", "all"])
693693
def test_logical_compat(self, op, simple_index):
694694
index = simple_index
695-
assert getattr(index, op)() == getattr(index.values, op)()
695+
left = getattr(index, op)()
696+
assert left == getattr(index.values, op)()
697+
right = getattr(index.to_series(), op)()
698+
# left might not match right exactly in e.g. string cases where the
699+
# because we use np.any/all instead of .any/all
700+
assert bool(left) == bool(right)
696701

697702
@pytest.mark.parametrize(
698703
"index", ["string", "int64", "int32", "float64", "float32"], indirect=True

pandas/tests/indexes/test_old_base.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,25 @@ def test_numeric_compat(self, simple_index):
209209
1 // idx
210210

211211
def test_logical_compat(self, simple_index):
212-
if (
213-
isinstance(simple_index, RangeIndex)
214-
or is_numeric_dtype(simple_index.dtype)
215-
or simple_index.dtype == object
216-
):
212+
if simple_index.dtype == object:
217213
pytest.skip("Tested elsewhere.")
218214
idx = simple_index
219-
with pytest.raises(TypeError, match="cannot perform all"):
220-
idx.all()
221-
with pytest.raises(TypeError, match="cannot perform any"):
222-
idx.any()
215+
if idx.dtype.kind in "iufcbm":
216+
assert idx.all() == idx._values.all()
217+
assert idx.all() == idx.to_series().all()
218+
assert idx.any() == idx._values.any()
219+
assert idx.any() == idx.to_series().any()
220+
else:
221+
msg = "cannot perform (any|all)"
222+
if isinstance(idx, IntervalIndex):
223+
msg = (
224+
r"'IntervalArray' with dtype interval\[.*\] does "
225+
"not support reduction '(any|all)'"
226+
)
227+
with pytest.raises(TypeError, match=msg):
228+
idx.all()
229+
with pytest.raises(TypeError, match=msg):
230+
idx.any()
223231

224232
def test_repr_roundtrip(self, simple_index):
225233
if isinstance(simple_index, IntervalIndex):

0 commit comments

Comments
 (0)